forked from cardosofelipe/fast-next-template
feat(mcp): implement Git Operations MCP server with Gitea provider
Implements the Git Operations MCP server (Issue #58) providing: Core features: - GitPython wrapper for local repository operations (clone, commit, push, pull, diff, log) - Branch management (create, delete, list, checkout) - Workspace isolation per project with file-based locking - Gitea provider for remote PR operations MCP Tools (17 registered): - clone_repository, git_status, create_branch, list_branches - checkout, commit, push, pull, diff, log - create_pull_request, get_pull_request, list_pull_requests - merge_pull_request, get_workspace, lock_workspace, unlock_workspace Technical details: - FastMCP + FastAPI with JSON-RPC 2.0 protocol - pydantic-settings for configuration (env prefix: GIT_OPS_) - Comprehensive error hierarchy with structured codes - 131 tests passing with 67% coverage - Async operations via ThreadPoolExecutor Closes: #105, #106, #107, #108, #109 🤖 Generated with [Claude Code](https://claude.com/claude-code) Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
This commit is contained in:
67
mcp-servers/git-ops/Dockerfile
Normal file
67
mcp-servers/git-ops/Dockerfile
Normal file
@@ -0,0 +1,67 @@
|
||||
# Git Operations MCP Server Dockerfile
|
||||
# Multi-stage build for smaller production image
|
||||
|
||||
FROM python:3.12-slim AS builder
|
||||
|
||||
# Install build dependencies
|
||||
RUN apt-get update && apt-get install -y --no-install-recommends \
|
||||
build-essential \
|
||||
git \
|
||||
&& rm -rf /var/lib/apt/lists/*
|
||||
|
||||
# Install uv for fast package management
|
||||
RUN pip install --no-cache-dir uv
|
||||
|
||||
# Create app directory
|
||||
WORKDIR /app
|
||||
|
||||
# Copy dependency files
|
||||
COPY pyproject.toml .
|
||||
|
||||
# Install dependencies with uv
|
||||
RUN uv pip install --system --no-cache .
|
||||
|
||||
# Production stage
|
||||
FROM python:3.12-slim
|
||||
|
||||
# Install runtime dependencies
|
||||
RUN apt-get update && apt-get install -y --no-install-recommends \
|
||||
git \
|
||||
openssh-client \
|
||||
&& rm -rf /var/lib/apt/lists/*
|
||||
|
||||
# Create non-root user
|
||||
RUN useradd --create-home --shell /bin/bash syndarix
|
||||
|
||||
# Create workspace directory
|
||||
RUN mkdir -p /var/syndarix/workspaces && chown -R syndarix:syndarix /var/syndarix
|
||||
|
||||
# Create app directory
|
||||
WORKDIR /app
|
||||
|
||||
# Copy installed packages from builder
|
||||
COPY --from=builder /usr/local/lib/python3.12/site-packages /usr/local/lib/python3.12/site-packages
|
||||
COPY --from=builder /usr/local/bin /usr/local/bin
|
||||
|
||||
# Copy application code
|
||||
COPY --chown=syndarix:syndarix . .
|
||||
|
||||
# Set Python path
|
||||
ENV PYTHONPATH=/app
|
||||
ENV PYTHONUNBUFFERED=1
|
||||
|
||||
# Configure git for the container
|
||||
RUN git config --global --add safe.directory '*'
|
||||
|
||||
# Switch to non-root user
|
||||
USER syndarix
|
||||
|
||||
# Expose port
|
||||
EXPOSE 8003
|
||||
|
||||
# Health check
|
||||
HEALTHCHECK --interval=30s --timeout=10s --start-period=5s --retries=3 \
|
||||
CMD python -c "import httpx; httpx.get('http://localhost:8003/health').raise_for_status()" || exit 1
|
||||
|
||||
# Run the server
|
||||
CMD ["python", "server.py"]
|
||||
179
mcp-servers/git-ops/__init__.py
Normal file
179
mcp-servers/git-ops/__init__.py
Normal file
@@ -0,0 +1,179 @@
|
||||
"""
|
||||
Git Operations MCP Server.
|
||||
|
||||
Provides git repository management, branching, commits, and PR workflows
|
||||
for Syndarix AI agents.
|
||||
"""
|
||||
|
||||
__version__ = "0.1.0"
|
||||
|
||||
from config import Settings, get_settings, is_test_mode, reset_settings
|
||||
from exceptions import (
|
||||
APIError,
|
||||
AuthenticationError,
|
||||
BranchExistsError,
|
||||
BranchNotFoundError,
|
||||
CheckoutError,
|
||||
CloneError,
|
||||
CommitError,
|
||||
CredentialError,
|
||||
CredentialNotFoundError,
|
||||
DirtyWorkspaceError,
|
||||
ErrorCode,
|
||||
GitError,
|
||||
GitOpsError,
|
||||
InvalidRefError,
|
||||
MergeConflictError,
|
||||
PRError,
|
||||
PRNotFoundError,
|
||||
ProviderError,
|
||||
ProviderNotFoundError,
|
||||
PullError,
|
||||
PushError,
|
||||
WorkspaceError,
|
||||
WorkspaceLockedError,
|
||||
WorkspaceNotFoundError,
|
||||
WorkspaceSizeExceededError,
|
||||
)
|
||||
from models import (
|
||||
BranchInfo,
|
||||
BranchRequest,
|
||||
BranchResult,
|
||||
CheckoutRequest,
|
||||
CheckoutResult,
|
||||
CloneRequest,
|
||||
CloneResult,
|
||||
CommitInfo,
|
||||
CommitRequest,
|
||||
CommitResult,
|
||||
CreatePRRequest,
|
||||
CreatePRResult,
|
||||
DiffHunk,
|
||||
DiffRequest,
|
||||
DiffResult,
|
||||
FileChange,
|
||||
FileChangeType,
|
||||
FileDiff,
|
||||
GetPRRequest,
|
||||
GetPRResult,
|
||||
GetWorkspaceRequest,
|
||||
GetWorkspaceResult,
|
||||
HealthStatus,
|
||||
ListBranchesRequest,
|
||||
ListBranchesResult,
|
||||
ListPRsRequest,
|
||||
ListPRsResult,
|
||||
LockWorkspaceRequest,
|
||||
LockWorkspaceResult,
|
||||
LogRequest,
|
||||
LogResult,
|
||||
MergePRRequest,
|
||||
MergePRResult,
|
||||
MergeStrategy,
|
||||
PRInfo,
|
||||
ProviderStatus,
|
||||
ProviderType,
|
||||
PRState,
|
||||
PullRequest,
|
||||
PullResult,
|
||||
PushRequest,
|
||||
PushResult,
|
||||
StatusRequest,
|
||||
StatusResult,
|
||||
UnlockWorkspaceRequest,
|
||||
UnlockWorkspaceResult,
|
||||
UpdatePRRequest,
|
||||
UpdatePRResult,
|
||||
WorkspaceInfo,
|
||||
WorkspaceState,
|
||||
)
|
||||
|
||||
__all__ = [
|
||||
# Version
|
||||
"__version__",
|
||||
# Config
|
||||
"Settings",
|
||||
"get_settings",
|
||||
"reset_settings",
|
||||
"is_test_mode",
|
||||
# Error codes
|
||||
"ErrorCode",
|
||||
# Exceptions
|
||||
"GitOpsError",
|
||||
"WorkspaceError",
|
||||
"WorkspaceNotFoundError",
|
||||
"WorkspaceLockedError",
|
||||
"WorkspaceSizeExceededError",
|
||||
"GitError",
|
||||
"CloneError",
|
||||
"CheckoutError",
|
||||
"CommitError",
|
||||
"PushError",
|
||||
"PullError",
|
||||
"MergeConflictError",
|
||||
"BranchExistsError",
|
||||
"BranchNotFoundError",
|
||||
"InvalidRefError",
|
||||
"DirtyWorkspaceError",
|
||||
"ProviderError",
|
||||
"AuthenticationError",
|
||||
"ProviderNotFoundError",
|
||||
"PRError",
|
||||
"PRNotFoundError",
|
||||
"APIError",
|
||||
"CredentialError",
|
||||
"CredentialNotFoundError",
|
||||
# Enums
|
||||
"FileChangeType",
|
||||
"MergeStrategy",
|
||||
"PRState",
|
||||
"ProviderType",
|
||||
"WorkspaceState",
|
||||
# Dataclasses
|
||||
"FileChange",
|
||||
"BranchInfo",
|
||||
"CommitInfo",
|
||||
"DiffHunk",
|
||||
"FileDiff",
|
||||
"PRInfo",
|
||||
"WorkspaceInfo",
|
||||
# Request/Response models
|
||||
"CloneRequest",
|
||||
"CloneResult",
|
||||
"StatusRequest",
|
||||
"StatusResult",
|
||||
"BranchRequest",
|
||||
"BranchResult",
|
||||
"ListBranchesRequest",
|
||||
"ListBranchesResult",
|
||||
"CheckoutRequest",
|
||||
"CheckoutResult",
|
||||
"CommitRequest",
|
||||
"CommitResult",
|
||||
"PushRequest",
|
||||
"PushResult",
|
||||
"PullRequest",
|
||||
"PullResult",
|
||||
"DiffRequest",
|
||||
"DiffResult",
|
||||
"LogRequest",
|
||||
"LogResult",
|
||||
"CreatePRRequest",
|
||||
"CreatePRResult",
|
||||
"GetPRRequest",
|
||||
"GetPRResult",
|
||||
"ListPRsRequest",
|
||||
"ListPRsResult",
|
||||
"MergePRRequest",
|
||||
"MergePRResult",
|
||||
"UpdatePRRequest",
|
||||
"UpdatePRResult",
|
||||
"GetWorkspaceRequest",
|
||||
"GetWorkspaceResult",
|
||||
"LockWorkspaceRequest",
|
||||
"LockWorkspaceResult",
|
||||
"UnlockWorkspaceRequest",
|
||||
"UnlockWorkspaceResult",
|
||||
"HealthStatus",
|
||||
"ProviderStatus",
|
||||
]
|
||||
155
mcp-servers/git-ops/config.py
Normal file
155
mcp-servers/git-ops/config.py
Normal file
@@ -0,0 +1,155 @@
|
||||
"""
|
||||
Configuration for Git Operations MCP Server.
|
||||
|
||||
Uses pydantic-settings for environment variable loading.
|
||||
"""
|
||||
|
||||
import os
|
||||
from pathlib import Path
|
||||
|
||||
from pydantic import Field
|
||||
from pydantic_settings import BaseSettings
|
||||
|
||||
|
||||
class Settings(BaseSettings):
|
||||
"""Application settings loaded from environment."""
|
||||
|
||||
# Server settings
|
||||
host: str = Field(default="0.0.0.0", description="Server host")
|
||||
port: int = Field(default=8003, description="Server port")
|
||||
debug: bool = Field(default=False, description="Debug mode")
|
||||
|
||||
# Workspace settings
|
||||
workspace_base_path: Path = Field(
|
||||
default=Path("/var/syndarix/workspaces"),
|
||||
description="Base path for git workspaces",
|
||||
)
|
||||
workspace_max_size_gb: float = Field(
|
||||
default=10.0,
|
||||
description="Maximum size per workspace in GB",
|
||||
)
|
||||
workspace_stale_days: int = Field(
|
||||
default=7,
|
||||
description="Days after which unused workspace is considered stale",
|
||||
)
|
||||
workspace_lock_timeout: int = Field(
|
||||
default=300,
|
||||
description="Workspace lock timeout in seconds",
|
||||
)
|
||||
|
||||
# Git settings
|
||||
git_timeout: int = Field(
|
||||
default=120,
|
||||
description="Default timeout for git operations in seconds",
|
||||
)
|
||||
git_clone_timeout: int = Field(
|
||||
default=600,
|
||||
description="Timeout for clone operations in seconds",
|
||||
)
|
||||
git_author_name: str = Field(
|
||||
default="Syndarix Agent",
|
||||
description="Default author name for commits",
|
||||
)
|
||||
git_author_email: str = Field(
|
||||
default="agent@syndarix.ai",
|
||||
description="Default author email for commits",
|
||||
)
|
||||
git_max_diff_lines: int = Field(
|
||||
default=10000,
|
||||
description="Maximum lines in diff output",
|
||||
)
|
||||
|
||||
# Redis settings (for distributed locking)
|
||||
redis_url: str = Field(
|
||||
default="redis://localhost:6379/0",
|
||||
description="Redis connection URL",
|
||||
)
|
||||
|
||||
# Provider settings
|
||||
gitea_base_url: str = Field(
|
||||
default="",
|
||||
description="Gitea API base URL (e.g., https://gitea.example.com)",
|
||||
)
|
||||
gitea_token: str = Field(
|
||||
default="",
|
||||
description="Gitea API token",
|
||||
)
|
||||
github_token: str = Field(
|
||||
default="",
|
||||
description="GitHub API token",
|
||||
)
|
||||
github_api_url: str = Field(
|
||||
default="https://api.github.com",
|
||||
description="GitHub API URL (for Enterprise)",
|
||||
)
|
||||
gitlab_token: str = Field(
|
||||
default="",
|
||||
description="GitLab API token",
|
||||
)
|
||||
gitlab_url: str = Field(
|
||||
default="https://gitlab.com",
|
||||
description="GitLab URL (for self-hosted)",
|
||||
)
|
||||
|
||||
# Rate limiting
|
||||
rate_limit_requests: int = Field(
|
||||
default=100,
|
||||
description="Max API requests per minute per provider",
|
||||
)
|
||||
rate_limit_window: int = Field(
|
||||
default=60,
|
||||
description="Rate limit window in seconds",
|
||||
)
|
||||
|
||||
# Retry settings
|
||||
retry_attempts: int = Field(
|
||||
default=3,
|
||||
description="Number of retry attempts for failed operations",
|
||||
)
|
||||
retry_delay: float = Field(
|
||||
default=1.0,
|
||||
description="Initial retry delay in seconds",
|
||||
)
|
||||
retry_max_delay: float = Field(
|
||||
default=30.0,
|
||||
description="Maximum retry delay in seconds",
|
||||
)
|
||||
|
||||
# Security settings
|
||||
allowed_hosts: list[str] = Field(
|
||||
default_factory=list,
|
||||
description="Allowed git host domains (empty = all)",
|
||||
)
|
||||
max_clone_size_mb: int = Field(
|
||||
default=500,
|
||||
description="Maximum repository size for clone in MB",
|
||||
)
|
||||
enable_force_push: bool = Field(
|
||||
default=False,
|
||||
description="Allow force push operations",
|
||||
)
|
||||
|
||||
model_config = {"env_prefix": "GIT_OPS_", "env_file": ".env", "extra": "ignore"}
|
||||
|
||||
|
||||
# Global settings instance (lazy initialization)
|
||||
_settings: Settings | None = None
|
||||
|
||||
|
||||
def get_settings() -> Settings:
|
||||
"""Get the global settings instance."""
|
||||
global _settings
|
||||
if _settings is None:
|
||||
_settings = Settings()
|
||||
return _settings
|
||||
|
||||
|
||||
def reset_settings() -> None:
|
||||
"""Reset the global settings (for testing)."""
|
||||
global _settings
|
||||
_settings = None
|
||||
|
||||
|
||||
def is_test_mode() -> bool:
|
||||
"""Check if running in test mode."""
|
||||
return os.getenv("IS_TEST", "").lower() in ("true", "1", "yes")
|
||||
361
mcp-servers/git-ops/exceptions.py
Normal file
361
mcp-servers/git-ops/exceptions.py
Normal file
@@ -0,0 +1,361 @@
|
||||
"""
|
||||
Exception hierarchy for Git Operations MCP Server.
|
||||
|
||||
Provides structured error handling with error codes for MCP responses.
|
||||
"""
|
||||
|
||||
from enum import Enum
|
||||
from typing import Any
|
||||
|
||||
|
||||
class ErrorCode(str, Enum):
|
||||
"""Error codes for Git Operations errors."""
|
||||
|
||||
# General errors (1xxx)
|
||||
INTERNAL_ERROR = "GIT_1000"
|
||||
INVALID_REQUEST = "GIT_1001"
|
||||
NOT_FOUND = "GIT_1002"
|
||||
PERMISSION_DENIED = "GIT_1003"
|
||||
TIMEOUT = "GIT_1004"
|
||||
RATE_LIMITED = "GIT_1005"
|
||||
|
||||
# Workspace errors (2xxx)
|
||||
WORKSPACE_NOT_FOUND = "GIT_2000"
|
||||
WORKSPACE_LOCKED = "GIT_2001"
|
||||
WORKSPACE_SIZE_EXCEEDED = "GIT_2002"
|
||||
WORKSPACE_CREATE_FAILED = "GIT_2003"
|
||||
WORKSPACE_DELETE_FAILED = "GIT_2004"
|
||||
|
||||
# Git operation errors (3xxx)
|
||||
CLONE_FAILED = "GIT_3000"
|
||||
CHECKOUT_FAILED = "GIT_3001"
|
||||
COMMIT_FAILED = "GIT_3002"
|
||||
PUSH_FAILED = "GIT_3003"
|
||||
PULL_FAILED = "GIT_3004"
|
||||
MERGE_CONFLICT = "GIT_3005"
|
||||
BRANCH_EXISTS = "GIT_3006"
|
||||
BRANCH_NOT_FOUND = "GIT_3007"
|
||||
INVALID_REF = "GIT_3008"
|
||||
DIRTY_WORKSPACE = "GIT_3009"
|
||||
UNCOMMITTED_CHANGES = "GIT_3010"
|
||||
FETCH_FAILED = "GIT_3011"
|
||||
RESET_FAILED = "GIT_3012"
|
||||
|
||||
# Provider errors (4xxx)
|
||||
PROVIDER_ERROR = "GIT_4000"
|
||||
PROVIDER_AUTH_FAILED = "GIT_4001"
|
||||
PROVIDER_NOT_FOUND = "GIT_4002"
|
||||
PR_CREATE_FAILED = "GIT_4003"
|
||||
PR_MERGE_FAILED = "GIT_4004"
|
||||
PR_NOT_FOUND = "GIT_4005"
|
||||
API_ERROR = "GIT_4006"
|
||||
|
||||
# Credential errors (5xxx)
|
||||
CREDENTIAL_ERROR = "GIT_5000"
|
||||
CREDENTIAL_NOT_FOUND = "GIT_5001"
|
||||
CREDENTIAL_INVALID = "GIT_5002"
|
||||
SSH_KEY_ERROR = "GIT_5003"
|
||||
|
||||
|
||||
class GitOpsError(Exception):
|
||||
"""Base exception for Git Operations errors."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
message: str,
|
||||
code: ErrorCode = ErrorCode.INTERNAL_ERROR,
|
||||
details: dict[str, Any] | None = None,
|
||||
) -> None:
|
||||
super().__init__(message)
|
||||
self.message = message
|
||||
self.code = code
|
||||
self.details = details or {}
|
||||
|
||||
def to_dict(self) -> dict[str, Any]:
|
||||
"""Convert to dictionary for MCP response."""
|
||||
result = {
|
||||
"error": self.message,
|
||||
"code": self.code.value,
|
||||
}
|
||||
if self.details:
|
||||
result["details"] = self.details
|
||||
return result
|
||||
|
||||
|
||||
# Workspace Errors
|
||||
|
||||
|
||||
class WorkspaceError(GitOpsError):
|
||||
"""Base exception for workspace-related errors."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
message: str,
|
||||
code: ErrorCode = ErrorCode.WORKSPACE_NOT_FOUND,
|
||||
details: dict[str, Any] | None = None,
|
||||
) -> None:
|
||||
super().__init__(message, code, details)
|
||||
|
||||
|
||||
class WorkspaceNotFoundError(WorkspaceError):
|
||||
"""Workspace does not exist."""
|
||||
|
||||
def __init__(self, project_id: str) -> None:
|
||||
super().__init__(
|
||||
f"Workspace not found for project: {project_id}",
|
||||
ErrorCode.WORKSPACE_NOT_FOUND,
|
||||
{"project_id": project_id},
|
||||
)
|
||||
|
||||
|
||||
class WorkspaceLockedError(WorkspaceError):
|
||||
"""Workspace is locked by another operation."""
|
||||
|
||||
def __init__(self, project_id: str, holder: str | None = None) -> None:
|
||||
details: dict[str, Any] = {"project_id": project_id}
|
||||
if holder:
|
||||
details["locked_by"] = holder
|
||||
super().__init__(
|
||||
f"Workspace is locked for project: {project_id}",
|
||||
ErrorCode.WORKSPACE_LOCKED,
|
||||
details,
|
||||
)
|
||||
|
||||
|
||||
class WorkspaceSizeExceededError(WorkspaceError):
|
||||
"""Workspace size limit exceeded."""
|
||||
|
||||
def __init__(self, project_id: str, current_size: float, max_size: float) -> None:
|
||||
super().__init__(
|
||||
f"Workspace size limit exceeded for project: {project_id}",
|
||||
ErrorCode.WORKSPACE_SIZE_EXCEEDED,
|
||||
{
|
||||
"project_id": project_id,
|
||||
"current_size_gb": current_size,
|
||||
"max_size_gb": max_size,
|
||||
},
|
||||
)
|
||||
|
||||
|
||||
# Git Operation Errors
|
||||
|
||||
|
||||
class GitError(GitOpsError):
|
||||
"""Base exception for git operation errors."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
message: str,
|
||||
code: ErrorCode = ErrorCode.INTERNAL_ERROR,
|
||||
details: dict[str, Any] | None = None,
|
||||
) -> None:
|
||||
super().__init__(message, code, details)
|
||||
|
||||
|
||||
class CloneError(GitError):
|
||||
"""Failed to clone repository."""
|
||||
|
||||
def __init__(self, repo_url: str, reason: str) -> None:
|
||||
super().__init__(
|
||||
f"Failed to clone repository: {reason}",
|
||||
ErrorCode.CLONE_FAILED,
|
||||
{"repo_url": repo_url, "reason": reason},
|
||||
)
|
||||
|
||||
|
||||
class CheckoutError(GitError):
|
||||
"""Failed to checkout branch or ref."""
|
||||
|
||||
def __init__(self, ref: str, reason: str) -> None:
|
||||
super().__init__(
|
||||
f"Failed to checkout '{ref}': {reason}",
|
||||
ErrorCode.CHECKOUT_FAILED,
|
||||
{"ref": ref, "reason": reason},
|
||||
)
|
||||
|
||||
|
||||
class CommitError(GitError):
|
||||
"""Failed to commit changes."""
|
||||
|
||||
def __init__(self, reason: str) -> None:
|
||||
super().__init__(
|
||||
f"Failed to commit: {reason}",
|
||||
ErrorCode.COMMIT_FAILED,
|
||||
{"reason": reason},
|
||||
)
|
||||
|
||||
|
||||
class PushError(GitError):
|
||||
"""Failed to push to remote."""
|
||||
|
||||
def __init__(self, branch: str, reason: str) -> None:
|
||||
super().__init__(
|
||||
f"Failed to push branch '{branch}': {reason}",
|
||||
ErrorCode.PUSH_FAILED,
|
||||
{"branch": branch, "reason": reason},
|
||||
)
|
||||
|
||||
|
||||
class PullError(GitError):
|
||||
"""Failed to pull from remote."""
|
||||
|
||||
def __init__(self, branch: str, reason: str) -> None:
|
||||
super().__init__(
|
||||
f"Failed to pull branch '{branch}': {reason}",
|
||||
ErrorCode.PULL_FAILED,
|
||||
{"branch": branch, "reason": reason},
|
||||
)
|
||||
|
||||
|
||||
class MergeConflictError(GitError):
|
||||
"""Merge conflict detected."""
|
||||
|
||||
def __init__(self, conflicting_files: list[str]) -> None:
|
||||
super().__init__(
|
||||
f"Merge conflict detected in {len(conflicting_files)} files",
|
||||
ErrorCode.MERGE_CONFLICT,
|
||||
{"conflicting_files": conflicting_files},
|
||||
)
|
||||
|
||||
|
||||
class BranchExistsError(GitError):
|
||||
"""Branch already exists."""
|
||||
|
||||
def __init__(self, branch_name: str) -> None:
|
||||
super().__init__(
|
||||
f"Branch already exists: {branch_name}",
|
||||
ErrorCode.BRANCH_EXISTS,
|
||||
{"branch": branch_name},
|
||||
)
|
||||
|
||||
|
||||
class BranchNotFoundError(GitError):
|
||||
"""Branch does not exist."""
|
||||
|
||||
def __init__(self, branch_name: str) -> None:
|
||||
super().__init__(
|
||||
f"Branch not found: {branch_name}",
|
||||
ErrorCode.BRANCH_NOT_FOUND,
|
||||
{"branch": branch_name},
|
||||
)
|
||||
|
||||
|
||||
class InvalidRefError(GitError):
|
||||
"""Invalid git reference."""
|
||||
|
||||
def __init__(self, ref: str) -> None:
|
||||
super().__init__(
|
||||
f"Invalid git reference: {ref}",
|
||||
ErrorCode.INVALID_REF,
|
||||
{"ref": ref},
|
||||
)
|
||||
|
||||
|
||||
class DirtyWorkspaceError(GitError):
|
||||
"""Workspace has uncommitted changes."""
|
||||
|
||||
def __init__(self, modified_files: list[str]) -> None:
|
||||
super().__init__(
|
||||
f"Workspace has {len(modified_files)} uncommitted changes",
|
||||
ErrorCode.DIRTY_WORKSPACE,
|
||||
{"modified_files": modified_files[:10]}, # Limit to first 10
|
||||
)
|
||||
|
||||
|
||||
# Provider Errors
|
||||
|
||||
|
||||
class ProviderError(GitOpsError):
|
||||
"""Base exception for provider-related errors."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
message: str,
|
||||
code: ErrorCode = ErrorCode.PROVIDER_ERROR,
|
||||
details: dict[str, Any] | None = None,
|
||||
) -> None:
|
||||
super().__init__(message, code, details)
|
||||
|
||||
|
||||
class AuthenticationError(ProviderError):
|
||||
"""Authentication with provider failed."""
|
||||
|
||||
def __init__(self, provider: str, reason: str) -> None:
|
||||
super().__init__(
|
||||
f"Authentication failed with {provider}: {reason}",
|
||||
ErrorCode.PROVIDER_AUTH_FAILED,
|
||||
{"provider": provider, "reason": reason},
|
||||
)
|
||||
|
||||
|
||||
class ProviderNotFoundError(ProviderError):
|
||||
"""Provider not configured or recognized."""
|
||||
|
||||
def __init__(self, provider: str) -> None:
|
||||
super().__init__(
|
||||
f"Provider not found or not configured: {provider}",
|
||||
ErrorCode.PROVIDER_NOT_FOUND,
|
||||
{"provider": provider},
|
||||
)
|
||||
|
||||
|
||||
class PRError(ProviderError):
|
||||
"""Pull request operation failed."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
message: str,
|
||||
code: ErrorCode = ErrorCode.PR_CREATE_FAILED,
|
||||
details: dict[str, Any] | None = None,
|
||||
) -> None:
|
||||
super().__init__(message, code, details)
|
||||
|
||||
|
||||
class PRNotFoundError(PRError):
|
||||
"""Pull request not found."""
|
||||
|
||||
def __init__(self, pr_number: int, repo: str) -> None:
|
||||
super().__init__(
|
||||
f"Pull request #{pr_number} not found in {repo}",
|
||||
ErrorCode.PR_NOT_FOUND,
|
||||
{"pr_number": pr_number, "repo": repo},
|
||||
)
|
||||
|
||||
|
||||
class APIError(ProviderError):
|
||||
"""Provider API error."""
|
||||
|
||||
def __init__(
|
||||
self, provider: str, status_code: int, message: str
|
||||
) -> None:
|
||||
super().__init__(
|
||||
f"{provider} API error ({status_code}): {message}",
|
||||
ErrorCode.API_ERROR,
|
||||
{"provider": provider, "status_code": status_code, "message": message},
|
||||
)
|
||||
|
||||
|
||||
# Credential Errors
|
||||
|
||||
|
||||
class CredentialError(GitOpsError):
|
||||
"""Base exception for credential-related errors."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
message: str,
|
||||
code: ErrorCode = ErrorCode.CREDENTIAL_ERROR,
|
||||
details: dict[str, Any] | None = None,
|
||||
) -> None:
|
||||
super().__init__(message, code, details)
|
||||
|
||||
|
||||
class CredentialNotFoundError(CredentialError):
|
||||
"""Credential not found."""
|
||||
|
||||
def __init__(self, credential_type: str, identifier: str) -> None:
|
||||
super().__init__(
|
||||
f"{credential_type} credential not found: {identifier}",
|
||||
ErrorCode.CREDENTIAL_NOT_FOUND,
|
||||
{"type": credential_type, "identifier": identifier},
|
||||
)
|
||||
1112
mcp-servers/git-ops/git_wrapper.py
Normal file
1112
mcp-servers/git-ops/git_wrapper.py
Normal file
File diff suppressed because it is too large
Load Diff
678
mcp-servers/git-ops/models.py
Normal file
678
mcp-servers/git-ops/models.py
Normal file
@@ -0,0 +1,678 @@
|
||||
"""
|
||||
Data models for Git Operations MCP Server.
|
||||
|
||||
Defines data structures for git operations, workspace management,
|
||||
and provider interactions.
|
||||
"""
|
||||
|
||||
from dataclasses import dataclass, field
|
||||
from datetime import UTC, datetime
|
||||
from enum import Enum
|
||||
from typing import Any
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
|
||||
class FileChangeType(str, Enum):
|
||||
"""Types of file changes in git."""
|
||||
|
||||
ADDED = "added"
|
||||
MODIFIED = "modified"
|
||||
DELETED = "deleted"
|
||||
RENAMED = "renamed"
|
||||
COPIED = "copied"
|
||||
UNTRACKED = "untracked"
|
||||
IGNORED = "ignored"
|
||||
|
||||
|
||||
class MergeStrategy(str, Enum):
|
||||
"""Merge strategies for pull requests."""
|
||||
|
||||
MERGE = "merge" # Create a merge commit
|
||||
SQUASH = "squash" # Squash and merge
|
||||
REBASE = "rebase" # Rebase and merge
|
||||
|
||||
|
||||
class PRState(str, Enum):
|
||||
"""Pull request states."""
|
||||
|
||||
OPEN = "open"
|
||||
CLOSED = "closed"
|
||||
MERGED = "merged"
|
||||
|
||||
|
||||
class ProviderType(str, Enum):
|
||||
"""Supported git providers."""
|
||||
|
||||
GITEA = "gitea"
|
||||
GITHUB = "github"
|
||||
GITLAB = "gitlab"
|
||||
|
||||
|
||||
class WorkspaceState(str, Enum):
|
||||
"""Workspace lifecycle states."""
|
||||
|
||||
INITIALIZING = "initializing"
|
||||
READY = "ready"
|
||||
LOCKED = "locked"
|
||||
STALE = "stale"
|
||||
DELETED = "deleted"
|
||||
|
||||
|
||||
# Dataclasses for internal data structures
|
||||
|
||||
|
||||
@dataclass
|
||||
class FileChange:
|
||||
"""A file change in git status."""
|
||||
|
||||
path: str
|
||||
change_type: FileChangeType
|
||||
old_path: str | None = None # For renames
|
||||
additions: int = 0
|
||||
deletions: int = 0
|
||||
|
||||
def to_dict(self) -> dict[str, Any]:
|
||||
"""Convert to dictionary."""
|
||||
return {
|
||||
"path": self.path,
|
||||
"change_type": self.change_type.value,
|
||||
"old_path": self.old_path,
|
||||
"additions": self.additions,
|
||||
"deletions": self.deletions,
|
||||
}
|
||||
|
||||
|
||||
@dataclass
|
||||
class BranchInfo:
|
||||
"""Information about a git branch."""
|
||||
|
||||
name: str
|
||||
is_current: bool = False
|
||||
is_remote: bool = False
|
||||
tracking_branch: str | None = None
|
||||
commit_sha: str | None = None
|
||||
commit_message: str | None = None
|
||||
ahead: int = 0
|
||||
behind: int = 0
|
||||
|
||||
def to_dict(self) -> dict[str, Any]:
|
||||
"""Convert to dictionary."""
|
||||
return {
|
||||
"name": self.name,
|
||||
"is_current": self.is_current,
|
||||
"is_remote": self.is_remote,
|
||||
"tracking_branch": self.tracking_branch,
|
||||
"commit_sha": self.commit_sha,
|
||||
"commit_message": self.commit_message,
|
||||
"ahead": self.ahead,
|
||||
"behind": self.behind,
|
||||
}
|
||||
|
||||
|
||||
@dataclass
|
||||
class CommitInfo:
|
||||
"""Information about a git commit."""
|
||||
|
||||
sha: str
|
||||
short_sha: str
|
||||
message: str
|
||||
author_name: str
|
||||
author_email: str
|
||||
authored_date: datetime
|
||||
committer_name: str
|
||||
committer_email: str
|
||||
committed_date: datetime
|
||||
parents: list[str] = field(default_factory=list)
|
||||
|
||||
def to_dict(self) -> dict[str, Any]:
|
||||
"""Convert to dictionary."""
|
||||
return {
|
||||
"sha": self.sha,
|
||||
"short_sha": self.short_sha,
|
||||
"message": self.message,
|
||||
"author_name": self.author_name,
|
||||
"author_email": self.author_email,
|
||||
"authored_date": self.authored_date.isoformat(),
|
||||
"committer_name": self.committer_name,
|
||||
"committer_email": self.committer_email,
|
||||
"committed_date": self.committed_date.isoformat(),
|
||||
"parents": self.parents,
|
||||
}
|
||||
|
||||
|
||||
@dataclass
|
||||
class DiffHunk:
|
||||
"""A hunk of diff content."""
|
||||
|
||||
old_start: int
|
||||
old_lines: int
|
||||
new_start: int
|
||||
new_lines: int
|
||||
content: str
|
||||
|
||||
def to_dict(self) -> dict[str, Any]:
|
||||
"""Convert to dictionary."""
|
||||
return {
|
||||
"old_start": self.old_start,
|
||||
"old_lines": self.old_lines,
|
||||
"new_start": self.new_start,
|
||||
"new_lines": self.new_lines,
|
||||
"content": self.content,
|
||||
}
|
||||
|
||||
|
||||
@dataclass
|
||||
class FileDiff:
|
||||
"""Diff for a single file."""
|
||||
|
||||
path: str
|
||||
change_type: FileChangeType
|
||||
old_path: str | None = None
|
||||
hunks: list[DiffHunk] = field(default_factory=list)
|
||||
additions: int = 0
|
||||
deletions: int = 0
|
||||
is_binary: bool = False
|
||||
|
||||
def to_dict(self) -> dict[str, Any]:
|
||||
"""Convert to dictionary."""
|
||||
return {
|
||||
"path": self.path,
|
||||
"change_type": self.change_type.value,
|
||||
"old_path": self.old_path,
|
||||
"hunks": [h.to_dict() for h in self.hunks],
|
||||
"additions": self.additions,
|
||||
"deletions": self.deletions,
|
||||
"is_binary": self.is_binary,
|
||||
}
|
||||
|
||||
|
||||
@dataclass
|
||||
class PRInfo:
|
||||
"""Information about a pull request."""
|
||||
|
||||
number: int
|
||||
title: str
|
||||
body: str
|
||||
state: PRState
|
||||
source_branch: str
|
||||
target_branch: str
|
||||
author: str
|
||||
created_at: datetime
|
||||
updated_at: datetime
|
||||
merged_at: datetime | None = None
|
||||
closed_at: datetime | None = None
|
||||
url: str | None = None
|
||||
labels: list[str] = field(default_factory=list)
|
||||
assignees: list[str] = field(default_factory=list)
|
||||
reviewers: list[str] = field(default_factory=list)
|
||||
mergeable: bool | None = None
|
||||
draft: bool = False
|
||||
|
||||
def to_dict(self) -> dict[str, Any]:
|
||||
"""Convert to dictionary."""
|
||||
return {
|
||||
"number": self.number,
|
||||
"title": self.title,
|
||||
"body": self.body,
|
||||
"state": self.state.value,
|
||||
"source_branch": self.source_branch,
|
||||
"target_branch": self.target_branch,
|
||||
"author": self.author,
|
||||
"created_at": self.created_at.isoformat(),
|
||||
"updated_at": self.updated_at.isoformat(),
|
||||
"merged_at": self.merged_at.isoformat() if self.merged_at else None,
|
||||
"closed_at": self.closed_at.isoformat() if self.closed_at else None,
|
||||
"url": self.url,
|
||||
"labels": self.labels,
|
||||
"assignees": self.assignees,
|
||||
"reviewers": self.reviewers,
|
||||
"mergeable": self.mergeable,
|
||||
"draft": self.draft,
|
||||
}
|
||||
|
||||
|
||||
@dataclass
|
||||
class WorkspaceInfo:
|
||||
"""Information about a project workspace."""
|
||||
|
||||
project_id: str
|
||||
path: str
|
||||
state: WorkspaceState
|
||||
repo_url: str | None = None
|
||||
current_branch: str | None = None
|
||||
last_accessed: datetime = field(default_factory=lambda: datetime.now(UTC))
|
||||
size_bytes: int = 0
|
||||
lock_holder: str | None = None
|
||||
lock_expires: datetime | None = None
|
||||
|
||||
def to_dict(self) -> dict[str, Any]:
|
||||
"""Convert to dictionary."""
|
||||
return {
|
||||
"project_id": self.project_id,
|
||||
"path": self.path,
|
||||
"state": self.state.value,
|
||||
"repo_url": self.repo_url,
|
||||
"current_branch": self.current_branch,
|
||||
"last_accessed": self.last_accessed.isoformat(),
|
||||
"size_bytes": self.size_bytes,
|
||||
"lock_holder": self.lock_holder,
|
||||
"lock_expires": self.lock_expires.isoformat() if self.lock_expires else None,
|
||||
}
|
||||
|
||||
|
||||
# Pydantic Request/Response Models
|
||||
|
||||
|
||||
class CloneRequest(BaseModel):
|
||||
"""Request to clone a repository."""
|
||||
|
||||
project_id: str = Field(..., description="Project ID for scoping")
|
||||
agent_id: str = Field(..., description="Agent ID making the request")
|
||||
repo_url: str = Field(..., description="Repository URL to clone")
|
||||
branch: str | None = Field(default=None, description="Branch to checkout after clone")
|
||||
depth: int | None = Field(
|
||||
default=None, ge=1, description="Shallow clone depth (None = full clone)"
|
||||
)
|
||||
|
||||
|
||||
class CloneResult(BaseModel):
|
||||
"""Result of a clone operation."""
|
||||
|
||||
success: bool = Field(..., description="Whether clone succeeded")
|
||||
project_id: str = Field(..., description="Project ID")
|
||||
workspace_path: str = Field(..., description="Path to cloned workspace")
|
||||
branch: str = Field(..., description="Current branch after clone")
|
||||
commit_sha: str = Field(..., description="HEAD commit SHA")
|
||||
error: str | None = Field(default=None, description="Error message if failed")
|
||||
|
||||
|
||||
class StatusRequest(BaseModel):
|
||||
"""Request for git status."""
|
||||
|
||||
project_id: str = Field(..., description="Project ID for scoping")
|
||||
agent_id: str = Field(..., description="Agent ID making the request")
|
||||
include_untracked: bool = Field(default=True, description="Include untracked files")
|
||||
|
||||
|
||||
class StatusResult(BaseModel):
|
||||
"""Result of a status operation."""
|
||||
|
||||
project_id: str = Field(..., description="Project ID")
|
||||
branch: str = Field(..., description="Current branch")
|
||||
commit_sha: str = Field(..., description="HEAD commit SHA")
|
||||
is_clean: bool = Field(..., description="Whether working tree is clean")
|
||||
staged: list[dict[str, Any]] = Field(
|
||||
default_factory=list, description="Staged changes"
|
||||
)
|
||||
unstaged: list[dict[str, Any]] = Field(
|
||||
default_factory=list, description="Unstaged changes"
|
||||
)
|
||||
untracked: list[str] = Field(default_factory=list, description="Untracked files")
|
||||
ahead: int = Field(default=0, description="Commits ahead of upstream")
|
||||
behind: int = Field(default=0, description="Commits behind upstream")
|
||||
|
||||
|
||||
class BranchRequest(BaseModel):
|
||||
"""Request for branch operations."""
|
||||
|
||||
project_id: str = Field(..., description="Project ID for scoping")
|
||||
agent_id: str = Field(..., description="Agent ID making the request")
|
||||
branch_name: str = Field(..., description="Branch name")
|
||||
from_ref: str | None = Field(
|
||||
default=None, description="Reference to create branch from"
|
||||
)
|
||||
checkout: bool = Field(default=True, description="Checkout after creation")
|
||||
|
||||
|
||||
class BranchResult(BaseModel):
|
||||
"""Result of a branch operation."""
|
||||
|
||||
success: bool = Field(..., description="Whether operation succeeded")
|
||||
branch: str = Field(..., description="Branch name")
|
||||
commit_sha: str | None = Field(default=None, description="HEAD commit SHA")
|
||||
is_current: bool = Field(default=False, description="Whether branch is checked out")
|
||||
error: str | None = Field(default=None, description="Error message if failed")
|
||||
|
||||
|
||||
class ListBranchesRequest(BaseModel):
|
||||
"""Request to list branches."""
|
||||
|
||||
project_id: str = Field(..., description="Project ID for scoping")
|
||||
agent_id: str = Field(..., description="Agent ID making the request")
|
||||
include_remote: bool = Field(default=False, description="Include remote branches")
|
||||
|
||||
|
||||
class ListBranchesResult(BaseModel):
|
||||
"""Result of listing branches."""
|
||||
|
||||
project_id: str = Field(..., description="Project ID")
|
||||
current_branch: str = Field(..., description="Currently checked out branch")
|
||||
local_branches: list[dict[str, Any]] = Field(
|
||||
default_factory=list, description="Local branches"
|
||||
)
|
||||
remote_branches: list[dict[str, Any]] = Field(
|
||||
default_factory=list, description="Remote branches"
|
||||
)
|
||||
|
||||
|
||||
class CheckoutRequest(BaseModel):
|
||||
"""Request to checkout a branch or ref."""
|
||||
|
||||
project_id: str = Field(..., description="Project ID for scoping")
|
||||
agent_id: str = Field(..., description="Agent ID making the request")
|
||||
ref: str = Field(..., description="Branch, tag, or commit to checkout")
|
||||
create_branch: bool = Field(default=False, description="Create new branch")
|
||||
force: bool = Field(default=False, description="Force checkout (discard changes)")
|
||||
|
||||
|
||||
class CheckoutResult(BaseModel):
|
||||
"""Result of a checkout operation."""
|
||||
|
||||
success: bool = Field(..., description="Whether checkout succeeded")
|
||||
ref: str = Field(..., description="Checked out reference")
|
||||
commit_sha: str | None = Field(default=None, description="HEAD commit SHA")
|
||||
error: str | None = Field(default=None, description="Error message if failed")
|
||||
|
||||
|
||||
class CommitRequest(BaseModel):
|
||||
"""Request to create a commit."""
|
||||
|
||||
project_id: str = Field(..., description="Project ID for scoping")
|
||||
agent_id: str = Field(..., description="Agent ID making the request")
|
||||
message: str = Field(..., description="Commit message")
|
||||
files: list[str] | None = Field(
|
||||
default=None, description="Files to commit (None = all staged)"
|
||||
)
|
||||
author_name: str | None = Field(default=None, description="Author name override")
|
||||
author_email: str | None = Field(default=None, description="Author email override")
|
||||
allow_empty: bool = Field(default=False, description="Allow empty commit")
|
||||
|
||||
|
||||
class CommitResult(BaseModel):
|
||||
"""Result of a commit operation."""
|
||||
|
||||
success: bool = Field(..., description="Whether commit succeeded")
|
||||
commit_sha: str | None = Field(default=None, description="New commit SHA")
|
||||
short_sha: str | None = Field(default=None, description="Short commit SHA")
|
||||
message: str | None = Field(default=None, description="Commit message")
|
||||
files_changed: int = Field(default=0, description="Number of files changed")
|
||||
insertions: int = Field(default=0, description="Lines added")
|
||||
deletions: int = Field(default=0, description="Lines removed")
|
||||
error: str | None = Field(default=None, description="Error message if failed")
|
||||
|
||||
|
||||
class PushRequest(BaseModel):
|
||||
"""Request to push to remote."""
|
||||
|
||||
project_id: str = Field(..., description="Project ID for scoping")
|
||||
agent_id: str = Field(..., description="Agent ID making the request")
|
||||
branch: str | None = Field(default=None, description="Branch to push (None = current)")
|
||||
remote: str = Field(default="origin", description="Remote name")
|
||||
force: bool = Field(default=False, description="Force push")
|
||||
set_upstream: bool = Field(default=True, description="Set upstream tracking")
|
||||
|
||||
|
||||
class PushResult(BaseModel):
|
||||
"""Result of a push operation."""
|
||||
|
||||
success: bool = Field(..., description="Whether push succeeded")
|
||||
branch: str = Field(..., description="Pushed branch")
|
||||
remote: str = Field(..., description="Remote name")
|
||||
commits_pushed: int = Field(default=0, description="Number of commits pushed")
|
||||
error: str | None = Field(default=None, description="Error message if failed")
|
||||
|
||||
|
||||
class PullRequest(BaseModel):
|
||||
"""Request to pull from remote."""
|
||||
|
||||
project_id: str = Field(..., description="Project ID for scoping")
|
||||
agent_id: str = Field(..., description="Agent ID making the request")
|
||||
branch: str | None = Field(default=None, description="Branch to pull (None = current)")
|
||||
remote: str = Field(default="origin", description="Remote name")
|
||||
rebase: bool = Field(default=False, description="Rebase instead of merge")
|
||||
|
||||
|
||||
class PullResult(BaseModel):
|
||||
"""Result of a pull operation."""
|
||||
|
||||
success: bool = Field(..., description="Whether pull succeeded")
|
||||
branch: str = Field(..., description="Pulled branch")
|
||||
commits_received: int = Field(default=0, description="New commits received")
|
||||
fast_forward: bool = Field(default=False, description="Was fast-forward")
|
||||
conflicts: list[str] = Field(
|
||||
default_factory=list, description="Conflicting files if any"
|
||||
)
|
||||
error: str | None = Field(default=None, description="Error message if failed")
|
||||
|
||||
|
||||
class DiffRequest(BaseModel):
|
||||
"""Request for diff."""
|
||||
|
||||
project_id: str = Field(..., description="Project ID for scoping")
|
||||
agent_id: str = Field(..., description="Agent ID making the request")
|
||||
base: str | None = Field(default=None, description="Base reference (None = working tree)")
|
||||
head: str | None = Field(default=None, description="Head reference (None = HEAD)")
|
||||
files: list[str] | None = Field(default=None, description="Specific files to diff")
|
||||
context_lines: int = Field(default=3, ge=0, description="Context lines")
|
||||
|
||||
|
||||
class DiffResult(BaseModel):
|
||||
"""Result of a diff operation."""
|
||||
|
||||
project_id: str = Field(..., description="Project ID")
|
||||
base: str | None = Field(default=None, description="Base reference")
|
||||
head: str | None = Field(default=None, description="Head reference")
|
||||
files: list[dict[str, Any]] = Field(
|
||||
default_factory=list, description="File diffs"
|
||||
)
|
||||
total_additions: int = Field(default=0, description="Total lines added")
|
||||
total_deletions: int = Field(default=0, description="Total lines removed")
|
||||
files_changed: int = Field(default=0, description="Number of files changed")
|
||||
|
||||
|
||||
class LogRequest(BaseModel):
|
||||
"""Request for commit log."""
|
||||
|
||||
project_id: str = Field(..., description="Project ID for scoping")
|
||||
agent_id: str = Field(..., description="Agent ID making the request")
|
||||
ref: str | None = Field(default=None, description="Reference to start from")
|
||||
limit: int = Field(default=20, ge=1, le=100, description="Max commits to return")
|
||||
skip: int = Field(default=0, ge=0, description="Commits to skip")
|
||||
path: str | None = Field(default=None, description="Filter by path")
|
||||
|
||||
|
||||
class LogResult(BaseModel):
|
||||
"""Result of a log operation."""
|
||||
|
||||
project_id: str = Field(..., description="Project ID")
|
||||
commits: list[dict[str, Any]] = Field(
|
||||
default_factory=list, description="Commit history"
|
||||
)
|
||||
total_commits: int = Field(default=0, description="Total commits in range")
|
||||
|
||||
|
||||
# PR Operations
|
||||
|
||||
|
||||
class CreatePRRequest(BaseModel):
|
||||
"""Request to create a pull request."""
|
||||
|
||||
project_id: str = Field(..., description="Project ID for scoping")
|
||||
agent_id: str = Field(..., description="Agent ID making the request")
|
||||
title: str = Field(..., description="PR title")
|
||||
body: str = Field(default="", description="PR description")
|
||||
source_branch: str = Field(..., description="Source branch")
|
||||
target_branch: str = Field(default="main", description="Target branch")
|
||||
draft: bool = Field(default=False, description="Create as draft")
|
||||
labels: list[str] = Field(default_factory=list, description="Labels to add")
|
||||
assignees: list[str] = Field(default_factory=list, description="Assignees")
|
||||
reviewers: list[str] = Field(default_factory=list, description="Reviewers")
|
||||
|
||||
|
||||
class CreatePRResult(BaseModel):
|
||||
"""Result of creating a pull request."""
|
||||
|
||||
success: bool = Field(..., description="Whether creation succeeded")
|
||||
pr_number: int | None = Field(default=None, description="PR number")
|
||||
pr_url: str | None = Field(default=None, description="PR URL")
|
||||
error: str | None = Field(default=None, description="Error message if failed")
|
||||
|
||||
|
||||
class GetPRRequest(BaseModel):
|
||||
"""Request to get a pull request."""
|
||||
|
||||
project_id: str = Field(..., description="Project ID for scoping")
|
||||
agent_id: str = Field(..., description="Agent ID making the request")
|
||||
pr_number: int = Field(..., description="PR number")
|
||||
|
||||
|
||||
class GetPRResult(BaseModel):
|
||||
"""Result of getting a pull request."""
|
||||
|
||||
success: bool = Field(..., description="Whether fetch succeeded")
|
||||
pr: dict[str, Any] | None = Field(default=None, description="PR info")
|
||||
error: str | None = Field(default=None, description="Error message if failed")
|
||||
|
||||
|
||||
class ListPRsRequest(BaseModel):
|
||||
"""Request to list pull requests."""
|
||||
|
||||
project_id: str = Field(..., description="Project ID for scoping")
|
||||
agent_id: str = Field(..., description="Agent ID making the request")
|
||||
state: PRState | None = Field(default=None, description="Filter by state")
|
||||
author: str | None = Field(default=None, description="Filter by author")
|
||||
limit: int = Field(default=20, ge=1, le=100, description="Max PRs to return")
|
||||
|
||||
|
||||
class ListPRsResult(BaseModel):
|
||||
"""Result of listing pull requests."""
|
||||
|
||||
success: bool = Field(..., description="Whether list succeeded")
|
||||
pull_requests: list[dict[str, Any]] = Field(
|
||||
default_factory=list, description="PRs"
|
||||
)
|
||||
total_count: int = Field(default=0, description="Total matching PRs")
|
||||
error: str | None = Field(default=None, description="Error message if failed")
|
||||
|
||||
|
||||
class MergePRRequest(BaseModel):
|
||||
"""Request to merge a pull request."""
|
||||
|
||||
project_id: str = Field(..., description="Project ID for scoping")
|
||||
agent_id: str = Field(..., description="Agent ID making the request")
|
||||
pr_number: int = Field(..., description="PR number")
|
||||
merge_strategy: MergeStrategy = Field(
|
||||
default=MergeStrategy.MERGE, description="Merge strategy"
|
||||
)
|
||||
commit_message: str | None = Field(
|
||||
default=None, description="Custom merge commit message"
|
||||
)
|
||||
delete_branch: bool = Field(
|
||||
default=True, description="Delete source branch after merge"
|
||||
)
|
||||
|
||||
|
||||
class MergePRResult(BaseModel):
|
||||
"""Result of merging a pull request."""
|
||||
|
||||
success: bool = Field(..., description="Whether merge succeeded")
|
||||
merge_commit_sha: str | None = Field(default=None, description="Merge commit SHA")
|
||||
branch_deleted: bool = Field(default=False, description="Whether branch was deleted")
|
||||
error: str | None = Field(default=None, description="Error message if failed")
|
||||
|
||||
|
||||
class UpdatePRRequest(BaseModel):
|
||||
"""Request to update a pull request."""
|
||||
|
||||
project_id: str = Field(..., description="Project ID for scoping")
|
||||
agent_id: str = Field(..., description="Agent ID making the request")
|
||||
pr_number: int = Field(..., description="PR number")
|
||||
title: str | None = Field(default=None, description="New title")
|
||||
body: str | None = Field(default=None, description="New description")
|
||||
state: PRState | None = Field(default=None, description="New state")
|
||||
labels: list[str] | None = Field(default=None, description="Replace labels")
|
||||
assignees: list[str] | None = Field(default=None, description="Replace assignees")
|
||||
|
||||
|
||||
class UpdatePRResult(BaseModel):
|
||||
"""Result of updating a pull request."""
|
||||
|
||||
success: bool = Field(..., description="Whether update succeeded")
|
||||
pr: dict[str, Any] | None = Field(default=None, description="Updated PR info")
|
||||
error: str | None = Field(default=None, description="Error message if failed")
|
||||
|
||||
|
||||
# Workspace Operations
|
||||
|
||||
|
||||
class GetWorkspaceRequest(BaseModel):
|
||||
"""Request to get or create workspace."""
|
||||
|
||||
project_id: str = Field(..., description="Project ID")
|
||||
agent_id: str = Field(..., description="Agent ID making the request")
|
||||
|
||||
|
||||
class GetWorkspaceResult(BaseModel):
|
||||
"""Result of getting workspace."""
|
||||
|
||||
success: bool = Field(..., description="Whether operation succeeded")
|
||||
workspace: dict[str, Any] | None = Field(default=None, description="Workspace info")
|
||||
error: str | None = Field(default=None, description="Error message if failed")
|
||||
|
||||
|
||||
class LockWorkspaceRequest(BaseModel):
|
||||
"""Request to lock a workspace."""
|
||||
|
||||
project_id: str = Field(..., description="Project ID")
|
||||
agent_id: str = Field(..., description="Agent ID requesting lock")
|
||||
timeout: int = Field(default=300, ge=10, le=3600, description="Lock timeout seconds")
|
||||
|
||||
|
||||
class LockWorkspaceResult(BaseModel):
|
||||
"""Result of locking workspace."""
|
||||
|
||||
success: bool = Field(..., description="Whether lock acquired")
|
||||
lock_holder: str | None = Field(default=None, description="Current lock holder")
|
||||
lock_expires: str | None = Field(default=None, description="Lock expiry ISO timestamp")
|
||||
error: str | None = Field(default=None, description="Error message if failed")
|
||||
|
||||
|
||||
class UnlockWorkspaceRequest(BaseModel):
|
||||
"""Request to unlock a workspace."""
|
||||
|
||||
project_id: str = Field(..., description="Project ID")
|
||||
agent_id: str = Field(..., description="Agent ID releasing lock")
|
||||
force: bool = Field(default=False, description="Force unlock (admin only)")
|
||||
|
||||
|
||||
class UnlockWorkspaceResult(BaseModel):
|
||||
"""Result of unlocking workspace."""
|
||||
|
||||
success: bool = Field(..., description="Whether unlock succeeded")
|
||||
error: str | None = Field(default=None, description="Error message if failed")
|
||||
|
||||
|
||||
# Health and Status
|
||||
|
||||
|
||||
class HealthStatus(BaseModel):
|
||||
"""Health status response."""
|
||||
|
||||
status: str = Field(..., description="Health status")
|
||||
version: str = Field(..., description="Server version")
|
||||
workspace_count: int = Field(default=0, description="Active workspaces")
|
||||
gitea_connected: bool = Field(default=False, description="Gitea connectivity")
|
||||
github_connected: bool = Field(default=False, description="GitHub connectivity")
|
||||
gitlab_connected: bool = Field(default=False, description="GitLab connectivity")
|
||||
redis_connected: bool = Field(default=False, description="Redis connectivity")
|
||||
|
||||
|
||||
class ProviderStatus(BaseModel):
|
||||
"""Provider connection status."""
|
||||
|
||||
provider: str = Field(..., description="Provider name")
|
||||
connected: bool = Field(..., description="Connection status")
|
||||
url: str | None = Field(default=None, description="Provider URL")
|
||||
user: str | None = Field(default=None, description="Authenticated user")
|
||||
error: str | None = Field(default=None, description="Error if not connected")
|
||||
10
mcp-servers/git-ops/providers/__init__.py
Normal file
10
mcp-servers/git-ops/providers/__init__.py
Normal file
@@ -0,0 +1,10 @@
|
||||
"""
|
||||
Git provider implementations.
|
||||
|
||||
Provides adapters for different git hosting platforms (Gitea, GitHub, GitLab).
|
||||
"""
|
||||
|
||||
from .base import BaseProvider
|
||||
from .gitea import GiteaProvider
|
||||
|
||||
__all__ = ["BaseProvider", "GiteaProvider"]
|
||||
388
mcp-servers/git-ops/providers/base.py
Normal file
388
mcp-servers/git-ops/providers/base.py
Normal file
@@ -0,0 +1,388 @@
|
||||
"""
|
||||
Base provider interface for git hosting platforms.
|
||||
|
||||
Defines the abstract interface that all git providers must implement.
|
||||
"""
|
||||
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import Any
|
||||
|
||||
from models import (
|
||||
CreatePRResult,
|
||||
GetPRResult,
|
||||
ListPRsResult,
|
||||
MergePRResult,
|
||||
MergeStrategy,
|
||||
PRState,
|
||||
UpdatePRResult,
|
||||
)
|
||||
|
||||
|
||||
class BaseProvider(ABC):
|
||||
"""
|
||||
Abstract base class for git hosting providers.
|
||||
|
||||
All providers (Gitea, GitHub, GitLab) must implement this interface.
|
||||
"""
|
||||
|
||||
@property
|
||||
@abstractmethod
|
||||
def name(self) -> str:
|
||||
"""Return the provider name (e.g., 'gitea', 'github')."""
|
||||
...
|
||||
|
||||
@abstractmethod
|
||||
async def is_connected(self) -> bool:
|
||||
"""Check if the provider is connected and authenticated."""
|
||||
...
|
||||
|
||||
@abstractmethod
|
||||
async def get_authenticated_user(self) -> str | None:
|
||||
"""Get the username of the authenticated user."""
|
||||
...
|
||||
|
||||
# Repository operations
|
||||
|
||||
@abstractmethod
|
||||
async def get_repo_info(
|
||||
self, owner: str, repo: str
|
||||
) -> dict[str, Any]:
|
||||
"""
|
||||
Get repository information.
|
||||
|
||||
Args:
|
||||
owner: Repository owner/organization
|
||||
repo: Repository name
|
||||
|
||||
Returns:
|
||||
Repository info dict
|
||||
"""
|
||||
...
|
||||
|
||||
@abstractmethod
|
||||
async def get_default_branch(
|
||||
self, owner: str, repo: str
|
||||
) -> str:
|
||||
"""
|
||||
Get the default branch for a repository.
|
||||
|
||||
Args:
|
||||
owner: Repository owner/organization
|
||||
repo: Repository name
|
||||
|
||||
Returns:
|
||||
Default branch name
|
||||
"""
|
||||
...
|
||||
|
||||
# Pull Request operations
|
||||
|
||||
@abstractmethod
|
||||
async def create_pr(
|
||||
self,
|
||||
owner: str,
|
||||
repo: str,
|
||||
title: str,
|
||||
body: str,
|
||||
source_branch: str,
|
||||
target_branch: str,
|
||||
draft: bool = False,
|
||||
labels: list[str] | None = None,
|
||||
assignees: list[str] | None = None,
|
||||
reviewers: list[str] | None = None,
|
||||
) -> CreatePRResult:
|
||||
"""
|
||||
Create a pull request.
|
||||
|
||||
Args:
|
||||
owner: Repository owner
|
||||
repo: Repository name
|
||||
title: PR title
|
||||
body: PR description
|
||||
source_branch: Source branch name
|
||||
target_branch: Target branch name
|
||||
draft: Whether to create as draft
|
||||
labels: Labels to add
|
||||
assignees: Users to assign
|
||||
reviewers: Users to request review from
|
||||
|
||||
Returns:
|
||||
CreatePRResult with PR number and URL
|
||||
"""
|
||||
...
|
||||
|
||||
@abstractmethod
|
||||
async def get_pr(
|
||||
self, owner: str, repo: str, pr_number: int
|
||||
) -> GetPRResult:
|
||||
"""
|
||||
Get a pull request by number.
|
||||
|
||||
Args:
|
||||
owner: Repository owner
|
||||
repo: Repository name
|
||||
pr_number: Pull request number
|
||||
|
||||
Returns:
|
||||
GetPRResult with PR details
|
||||
"""
|
||||
...
|
||||
|
||||
@abstractmethod
|
||||
async def list_prs(
|
||||
self,
|
||||
owner: str,
|
||||
repo: str,
|
||||
state: PRState | None = None,
|
||||
author: str | None = None,
|
||||
limit: int = 20,
|
||||
) -> ListPRsResult:
|
||||
"""
|
||||
List pull requests.
|
||||
|
||||
Args:
|
||||
owner: Repository owner
|
||||
repo: Repository name
|
||||
state: Filter by state (open, closed, merged)
|
||||
author: Filter by author
|
||||
limit: Maximum PRs to return
|
||||
|
||||
Returns:
|
||||
ListPRsResult with list of PRs
|
||||
"""
|
||||
...
|
||||
|
||||
@abstractmethod
|
||||
async def merge_pr(
|
||||
self,
|
||||
owner: str,
|
||||
repo: str,
|
||||
pr_number: int,
|
||||
merge_strategy: MergeStrategy = MergeStrategy.MERGE,
|
||||
commit_message: str | None = None,
|
||||
delete_branch: bool = True,
|
||||
) -> MergePRResult:
|
||||
"""
|
||||
Merge a pull request.
|
||||
|
||||
Args:
|
||||
owner: Repository owner
|
||||
repo: Repository name
|
||||
pr_number: Pull request number
|
||||
merge_strategy: Merge strategy to use
|
||||
commit_message: Custom merge commit message
|
||||
delete_branch: Whether to delete source branch
|
||||
|
||||
Returns:
|
||||
MergePRResult with merge status
|
||||
"""
|
||||
...
|
||||
|
||||
@abstractmethod
|
||||
async def update_pr(
|
||||
self,
|
||||
owner: str,
|
||||
repo: str,
|
||||
pr_number: int,
|
||||
title: str | None = None,
|
||||
body: str | None = None,
|
||||
state: PRState | None = None,
|
||||
labels: list[str] | None = None,
|
||||
assignees: list[str] | None = None,
|
||||
) -> UpdatePRResult:
|
||||
"""
|
||||
Update a pull request.
|
||||
|
||||
Args:
|
||||
owner: Repository owner
|
||||
repo: Repository name
|
||||
pr_number: Pull request number
|
||||
title: New title
|
||||
body: New description
|
||||
state: New state (open, closed)
|
||||
labels: Replace labels
|
||||
assignees: Replace assignees
|
||||
|
||||
Returns:
|
||||
UpdatePRResult with updated PR info
|
||||
"""
|
||||
...
|
||||
|
||||
@abstractmethod
|
||||
async def close_pr(
|
||||
self, owner: str, repo: str, pr_number: int
|
||||
) -> UpdatePRResult:
|
||||
"""
|
||||
Close a pull request without merging.
|
||||
|
||||
Args:
|
||||
owner: Repository owner
|
||||
repo: Repository name
|
||||
pr_number: Pull request number
|
||||
|
||||
Returns:
|
||||
UpdatePRResult with updated PR info
|
||||
"""
|
||||
...
|
||||
|
||||
# Branch operations via API (for operations that need to bypass local git)
|
||||
|
||||
@abstractmethod
|
||||
async def delete_remote_branch(
|
||||
self, owner: str, repo: str, branch: str
|
||||
) -> bool:
|
||||
"""
|
||||
Delete a remote branch via API.
|
||||
|
||||
Args:
|
||||
owner: Repository owner
|
||||
repo: Repository name
|
||||
branch: Branch name to delete
|
||||
|
||||
Returns:
|
||||
True if deleted, False otherwise
|
||||
"""
|
||||
...
|
||||
|
||||
@abstractmethod
|
||||
async def get_branch(
|
||||
self, owner: str, repo: str, branch: str
|
||||
) -> dict[str, Any] | None:
|
||||
"""
|
||||
Get branch information via API.
|
||||
|
||||
Args:
|
||||
owner: Repository owner
|
||||
repo: Repository name
|
||||
branch: Branch name
|
||||
|
||||
Returns:
|
||||
Branch info dict or None if not found
|
||||
"""
|
||||
...
|
||||
|
||||
# Comment operations
|
||||
|
||||
@abstractmethod
|
||||
async def add_pr_comment(
|
||||
self, owner: str, repo: str, pr_number: int, body: str
|
||||
) -> dict[str, Any]:
|
||||
"""
|
||||
Add a comment to a pull request.
|
||||
|
||||
Args:
|
||||
owner: Repository owner
|
||||
repo: Repository name
|
||||
pr_number: Pull request number
|
||||
body: Comment body
|
||||
|
||||
Returns:
|
||||
Created comment info
|
||||
"""
|
||||
...
|
||||
|
||||
@abstractmethod
|
||||
async def list_pr_comments(
|
||||
self, owner: str, repo: str, pr_number: int
|
||||
) -> list[dict[str, Any]]:
|
||||
"""
|
||||
List comments on a pull request.
|
||||
|
||||
Args:
|
||||
owner: Repository owner
|
||||
repo: Repository name
|
||||
pr_number: Pull request number
|
||||
|
||||
Returns:
|
||||
List of comments
|
||||
"""
|
||||
...
|
||||
|
||||
# Label operations
|
||||
|
||||
@abstractmethod
|
||||
async def add_labels(
|
||||
self, owner: str, repo: str, pr_number: int, labels: list[str]
|
||||
) -> list[str]:
|
||||
"""
|
||||
Add labels to a pull request.
|
||||
|
||||
Args:
|
||||
owner: Repository owner
|
||||
repo: Repository name
|
||||
pr_number: Pull request number
|
||||
labels: Labels to add
|
||||
|
||||
Returns:
|
||||
Updated list of labels
|
||||
"""
|
||||
...
|
||||
|
||||
@abstractmethod
|
||||
async def remove_label(
|
||||
self, owner: str, repo: str, pr_number: int, label: str
|
||||
) -> list[str]:
|
||||
"""
|
||||
Remove a label from a pull request.
|
||||
|
||||
Args:
|
||||
owner: Repository owner
|
||||
repo: Repository name
|
||||
pr_number: Pull request number
|
||||
label: Label to remove
|
||||
|
||||
Returns:
|
||||
Updated list of labels
|
||||
"""
|
||||
...
|
||||
|
||||
# Reviewer operations
|
||||
|
||||
@abstractmethod
|
||||
async def request_review(
|
||||
self, owner: str, repo: str, pr_number: int, reviewers: list[str]
|
||||
) -> list[str]:
|
||||
"""
|
||||
Request review from users.
|
||||
|
||||
Args:
|
||||
owner: Repository owner
|
||||
repo: Repository name
|
||||
pr_number: Pull request number
|
||||
reviewers: Usernames to request review from
|
||||
|
||||
Returns:
|
||||
List of reviewers requested
|
||||
"""
|
||||
...
|
||||
|
||||
# Utility methods
|
||||
|
||||
def parse_repo_url(self, repo_url: str) -> tuple[str, str]:
|
||||
"""
|
||||
Parse repository URL to extract owner and repo name.
|
||||
|
||||
Args:
|
||||
repo_url: Repository URL (HTTPS or SSH)
|
||||
|
||||
Returns:
|
||||
Tuple of (owner, repo)
|
||||
|
||||
Raises:
|
||||
ValueError: If URL cannot be parsed
|
||||
"""
|
||||
import re
|
||||
|
||||
# Handle SSH URLs: git@host:owner/repo.git
|
||||
ssh_match = re.match(r"git@[^:]+:([^/]+)/([^/]+?)(?:\.git)?$", repo_url)
|
||||
if ssh_match:
|
||||
return ssh_match.group(1), ssh_match.group(2)
|
||||
|
||||
# Handle HTTPS URLs: https://host/owner/repo.git
|
||||
https_match = re.match(
|
||||
r"https?://[^/]+/([^/]+)/([^/]+?)(?:\.git)?$", repo_url
|
||||
)
|
||||
if https_match:
|
||||
return https_match.group(1), https_match.group(2)
|
||||
|
||||
raise ValueError(f"Unable to parse repository URL: {repo_url}")
|
||||
723
mcp-servers/git-ops/providers/gitea.py
Normal file
723
mcp-servers/git-ops/providers/gitea.py
Normal file
@@ -0,0 +1,723 @@
|
||||
"""
|
||||
Gitea provider implementation.
|
||||
|
||||
Implements the BaseProvider interface for Gitea API operations.
|
||||
"""
|
||||
|
||||
import logging
|
||||
from datetime import UTC, datetime
|
||||
from typing import Any
|
||||
|
||||
import httpx
|
||||
|
||||
from config import Settings, get_settings
|
||||
from exceptions import (
|
||||
APIError,
|
||||
AuthenticationError,
|
||||
PRNotFoundError,
|
||||
)
|
||||
from models import (
|
||||
CreatePRResult,
|
||||
GetPRResult,
|
||||
ListPRsResult,
|
||||
MergePRResult,
|
||||
MergeStrategy,
|
||||
PRInfo,
|
||||
PRState,
|
||||
UpdatePRResult,
|
||||
)
|
||||
|
||||
from .base import BaseProvider
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class GiteaProvider(BaseProvider):
|
||||
"""
|
||||
Gitea API provider implementation.
|
||||
|
||||
Supports all PR operations, branch operations, and repository queries.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
base_url: str | None = None,
|
||||
token: str | None = None,
|
||||
settings: Settings | None = None,
|
||||
) -> None:
|
||||
"""
|
||||
Initialize Gitea provider.
|
||||
|
||||
Args:
|
||||
base_url: Gitea server URL (e.g., https://gitea.example.com)
|
||||
token: API token
|
||||
settings: Optional settings override
|
||||
"""
|
||||
self.settings = settings or get_settings()
|
||||
self.base_url = (base_url or self.settings.gitea_base_url).rstrip("/")
|
||||
self.token = token or self.settings.gitea_token
|
||||
self._client: httpx.AsyncClient | None = None
|
||||
self._user: str | None = None
|
||||
|
||||
@property
|
||||
def name(self) -> str:
|
||||
"""Return the provider name."""
|
||||
return "gitea"
|
||||
|
||||
async def _get_client(self) -> httpx.AsyncClient:
|
||||
"""Get or create HTTP client."""
|
||||
if self._client is None:
|
||||
headers = {
|
||||
"Accept": "application/json",
|
||||
"Content-Type": "application/json",
|
||||
}
|
||||
if self.token:
|
||||
headers["Authorization"] = f"token {self.token}"
|
||||
|
||||
self._client = httpx.AsyncClient(
|
||||
base_url=f"{self.base_url}/api/v1",
|
||||
headers=headers,
|
||||
timeout=30.0,
|
||||
)
|
||||
return self._client
|
||||
|
||||
async def close(self) -> None:
|
||||
"""Close the HTTP client."""
|
||||
if self._client:
|
||||
await self._client.aclose()
|
||||
self._client = None
|
||||
|
||||
async def _request(
|
||||
self,
|
||||
method: str,
|
||||
path: str,
|
||||
**kwargs: Any,
|
||||
) -> Any:
|
||||
"""
|
||||
Make an API request.
|
||||
|
||||
Args:
|
||||
method: HTTP method
|
||||
path: API path
|
||||
**kwargs: Additional request arguments
|
||||
|
||||
Returns:
|
||||
Parsed JSON response
|
||||
|
||||
Raises:
|
||||
APIError: On API errors
|
||||
AuthenticationError: On auth failures
|
||||
"""
|
||||
client = await self._get_client()
|
||||
|
||||
try:
|
||||
response = await client.request(method, path, **kwargs)
|
||||
|
||||
if response.status_code == 401:
|
||||
raise AuthenticationError("gitea", "Invalid or expired token")
|
||||
|
||||
if response.status_code == 403:
|
||||
raise AuthenticationError(
|
||||
"gitea", "Insufficient permissions for this operation"
|
||||
)
|
||||
|
||||
if response.status_code == 404:
|
||||
return None
|
||||
|
||||
if response.status_code >= 400:
|
||||
error_msg = response.text
|
||||
try:
|
||||
error_data = response.json()
|
||||
error_msg = error_data.get("message", error_msg)
|
||||
except Exception:
|
||||
pass
|
||||
raise APIError("gitea", response.status_code, error_msg)
|
||||
|
||||
if response.status_code == 204:
|
||||
return None
|
||||
|
||||
return response.json()
|
||||
|
||||
except httpx.RequestError as e:
|
||||
raise APIError("gitea", 0, f"Request failed: {e}")
|
||||
|
||||
async def is_connected(self) -> bool:
|
||||
"""Check if connected to Gitea."""
|
||||
if not self.base_url or not self.token:
|
||||
return False
|
||||
|
||||
try:
|
||||
result = await self._request("GET", "/user")
|
||||
return result is not None
|
||||
except Exception:
|
||||
return False
|
||||
|
||||
async def get_authenticated_user(self) -> str | None:
|
||||
"""Get the authenticated user's username."""
|
||||
if self._user:
|
||||
return self._user
|
||||
|
||||
try:
|
||||
result = await self._request("GET", "/user")
|
||||
if result:
|
||||
self._user = result.get("login") or result.get("username")
|
||||
return self._user
|
||||
except Exception:
|
||||
pass
|
||||
return None
|
||||
|
||||
# Repository operations
|
||||
|
||||
async def get_repo_info(self, owner: str, repo: str) -> dict[str, Any]:
|
||||
"""Get repository information."""
|
||||
result = await self._request("GET", f"/repos/{owner}/{repo}")
|
||||
if result is None:
|
||||
raise APIError("gitea", 404, f"Repository not found: {owner}/{repo}")
|
||||
return result
|
||||
|
||||
async def get_default_branch(self, owner: str, repo: str) -> str:
|
||||
"""Get the default branch for a repository."""
|
||||
repo_info = await self.get_repo_info(owner, repo)
|
||||
return repo_info.get("default_branch", "main")
|
||||
|
||||
# Pull Request operations
|
||||
|
||||
async def create_pr(
|
||||
self,
|
||||
owner: str,
|
||||
repo: str,
|
||||
title: str,
|
||||
body: str,
|
||||
source_branch: str,
|
||||
target_branch: str,
|
||||
draft: bool = False,
|
||||
labels: list[str] | None = None,
|
||||
assignees: list[str] | None = None,
|
||||
reviewers: list[str] | None = None,
|
||||
) -> CreatePRResult:
|
||||
"""Create a pull request."""
|
||||
try:
|
||||
data: dict[str, Any] = {
|
||||
"title": title,
|
||||
"body": body,
|
||||
"head": source_branch,
|
||||
"base": target_branch,
|
||||
}
|
||||
|
||||
# Note: Gitea doesn't have draft PR support in all versions
|
||||
# Draft support was added in Gitea 1.14+
|
||||
|
||||
result = await self._request(
|
||||
"POST",
|
||||
f"/repos/{owner}/{repo}/pulls",
|
||||
json=data,
|
||||
)
|
||||
|
||||
if result is None:
|
||||
return CreatePRResult(
|
||||
success=False,
|
||||
error="Failed to create pull request",
|
||||
)
|
||||
|
||||
pr_number = result["number"]
|
||||
|
||||
# Add labels if specified
|
||||
if labels:
|
||||
await self.add_labels(owner, repo, pr_number, labels)
|
||||
|
||||
# Add assignees if specified (via issue update)
|
||||
if assignees:
|
||||
await self._request(
|
||||
"PATCH",
|
||||
f"/repos/{owner}/{repo}/issues/{pr_number}",
|
||||
json={"assignees": assignees},
|
||||
)
|
||||
|
||||
# Request reviewers if specified
|
||||
if reviewers:
|
||||
await self.request_review(owner, repo, pr_number, reviewers)
|
||||
|
||||
return CreatePRResult(
|
||||
success=True,
|
||||
pr_number=pr_number,
|
||||
pr_url=result.get("html_url"),
|
||||
)
|
||||
|
||||
except APIError as e:
|
||||
return CreatePRResult(
|
||||
success=False,
|
||||
error=str(e),
|
||||
)
|
||||
|
||||
async def get_pr(self, owner: str, repo: str, pr_number: int) -> GetPRResult:
|
||||
"""Get a pull request by number."""
|
||||
try:
|
||||
result = await self._request(
|
||||
"GET",
|
||||
f"/repos/{owner}/{repo}/pulls/{pr_number}",
|
||||
)
|
||||
|
||||
if result is None:
|
||||
raise PRNotFoundError(pr_number, f"{owner}/{repo}")
|
||||
|
||||
pr_info = self._parse_pr(result)
|
||||
|
||||
return GetPRResult(
|
||||
success=True,
|
||||
pr=pr_info.to_dict(),
|
||||
)
|
||||
|
||||
except PRNotFoundError:
|
||||
return GetPRResult(
|
||||
success=False,
|
||||
error=f"Pull request #{pr_number} not found",
|
||||
)
|
||||
except APIError as e:
|
||||
return GetPRResult(
|
||||
success=False,
|
||||
error=str(e),
|
||||
)
|
||||
|
||||
async def list_prs(
|
||||
self,
|
||||
owner: str,
|
||||
repo: str,
|
||||
state: PRState | None = None,
|
||||
author: str | None = None,
|
||||
limit: int = 20,
|
||||
) -> ListPRsResult:
|
||||
"""List pull requests."""
|
||||
try:
|
||||
params: dict[str, Any] = {
|
||||
"limit": limit,
|
||||
}
|
||||
|
||||
if state:
|
||||
# Gitea uses different state names
|
||||
if state == PRState.OPEN:
|
||||
params["state"] = "open"
|
||||
elif state == PRState.CLOSED or state == PRState.MERGED:
|
||||
params["state"] = "closed"
|
||||
else:
|
||||
params["state"] = "all"
|
||||
|
||||
result = await self._request(
|
||||
"GET",
|
||||
f"/repos/{owner}/{repo}/pulls",
|
||||
params=params,
|
||||
)
|
||||
|
||||
if result is None:
|
||||
return ListPRsResult(
|
||||
success=True,
|
||||
pull_requests=[],
|
||||
total_count=0,
|
||||
)
|
||||
|
||||
prs = []
|
||||
for pr_data in result:
|
||||
# Filter by author if specified
|
||||
if author:
|
||||
pr_author = pr_data.get("user", {}).get("login", "")
|
||||
if pr_author.lower() != author.lower():
|
||||
continue
|
||||
|
||||
# Filter merged PRs if looking specifically for merged
|
||||
if state == PRState.MERGED:
|
||||
if not pr_data.get("merged"):
|
||||
continue
|
||||
|
||||
pr_info = self._parse_pr(pr_data)
|
||||
prs.append(pr_info.to_dict())
|
||||
|
||||
return ListPRsResult(
|
||||
success=True,
|
||||
pull_requests=prs,
|
||||
total_count=len(prs),
|
||||
)
|
||||
|
||||
except APIError as e:
|
||||
return ListPRsResult(
|
||||
success=False,
|
||||
error=str(e),
|
||||
)
|
||||
|
||||
async def merge_pr(
|
||||
self,
|
||||
owner: str,
|
||||
repo: str,
|
||||
pr_number: int,
|
||||
merge_strategy: MergeStrategy = MergeStrategy.MERGE,
|
||||
commit_message: str | None = None,
|
||||
delete_branch: bool = True,
|
||||
) -> MergePRResult:
|
||||
"""Merge a pull request."""
|
||||
try:
|
||||
# Map merge strategy to Gitea's "Do" values
|
||||
do_map = {
|
||||
MergeStrategy.MERGE: "merge",
|
||||
MergeStrategy.SQUASH: "squash",
|
||||
MergeStrategy.REBASE: "rebase",
|
||||
}
|
||||
|
||||
data: dict[str, Any] = {
|
||||
"Do": do_map[merge_strategy],
|
||||
"delete_branch_after_merge": delete_branch,
|
||||
}
|
||||
|
||||
if commit_message:
|
||||
data["MergeTitleField"] = commit_message.split("\n")[0]
|
||||
if "\n" in commit_message:
|
||||
data["MergeMessageField"] = "\n".join(
|
||||
commit_message.split("\n")[1:]
|
||||
)
|
||||
|
||||
result = await self._request(
|
||||
"POST",
|
||||
f"/repos/{owner}/{repo}/pulls/{pr_number}/merge",
|
||||
json=data,
|
||||
)
|
||||
|
||||
if result is None:
|
||||
# Check if PR was actually merged
|
||||
pr_result = await self.get_pr(owner, repo, pr_number)
|
||||
if pr_result.success and pr_result.pr:
|
||||
if pr_result.pr.get("state") == "merged":
|
||||
return MergePRResult(
|
||||
success=True,
|
||||
branch_deleted=delete_branch,
|
||||
)
|
||||
|
||||
return MergePRResult(
|
||||
success=False,
|
||||
error="Failed to merge pull request",
|
||||
)
|
||||
|
||||
return MergePRResult(
|
||||
success=True,
|
||||
merge_commit_sha=result.get("sha"),
|
||||
branch_deleted=delete_branch,
|
||||
)
|
||||
|
||||
except APIError as e:
|
||||
return MergePRResult(
|
||||
success=False,
|
||||
error=str(e),
|
||||
)
|
||||
|
||||
async def update_pr(
|
||||
self,
|
||||
owner: str,
|
||||
repo: str,
|
||||
pr_number: int,
|
||||
title: str | None = None,
|
||||
body: str | None = None,
|
||||
state: PRState | None = None,
|
||||
labels: list[str] | None = None,
|
||||
assignees: list[str] | None = None,
|
||||
) -> UpdatePRResult:
|
||||
"""Update a pull request."""
|
||||
try:
|
||||
data: dict[str, Any] = {}
|
||||
|
||||
if title is not None:
|
||||
data["title"] = title
|
||||
if body is not None:
|
||||
data["body"] = body
|
||||
if state is not None:
|
||||
if state == PRState.OPEN:
|
||||
data["state"] = "open"
|
||||
elif state == PRState.CLOSED:
|
||||
data["state"] = "closed"
|
||||
|
||||
# Update PR if there's data
|
||||
if data:
|
||||
await self._request(
|
||||
"PATCH",
|
||||
f"/repos/{owner}/{repo}/pulls/{pr_number}",
|
||||
json=data,
|
||||
)
|
||||
|
||||
# Update labels via issue endpoint
|
||||
if labels is not None:
|
||||
# First clear existing labels
|
||||
await self._request(
|
||||
"DELETE",
|
||||
f"/repos/{owner}/{repo}/issues/{pr_number}/labels",
|
||||
)
|
||||
# Then add new labels
|
||||
if labels:
|
||||
await self.add_labels(owner, repo, pr_number, labels)
|
||||
|
||||
# Update assignees via issue endpoint
|
||||
if assignees is not None:
|
||||
await self._request(
|
||||
"PATCH",
|
||||
f"/repos/{owner}/{repo}/issues/{pr_number}",
|
||||
json={"assignees": assignees},
|
||||
)
|
||||
|
||||
# Fetch updated PR
|
||||
result = await self.get_pr(owner, repo, pr_number)
|
||||
return UpdatePRResult(
|
||||
success=result.success,
|
||||
pr=result.pr,
|
||||
error=result.error,
|
||||
)
|
||||
|
||||
except APIError as e:
|
||||
return UpdatePRResult(
|
||||
success=False,
|
||||
error=str(e),
|
||||
)
|
||||
|
||||
async def close_pr(
|
||||
self,
|
||||
owner: str,
|
||||
repo: str,
|
||||
pr_number: int,
|
||||
) -> UpdatePRResult:
|
||||
"""Close a pull request without merging."""
|
||||
return await self.update_pr(
|
||||
owner,
|
||||
repo,
|
||||
pr_number,
|
||||
state=PRState.CLOSED,
|
||||
)
|
||||
|
||||
# Branch operations
|
||||
|
||||
async def delete_remote_branch(
|
||||
self,
|
||||
owner: str,
|
||||
repo: str,
|
||||
branch: str,
|
||||
) -> bool:
|
||||
"""Delete a remote branch."""
|
||||
try:
|
||||
await self._request(
|
||||
"DELETE",
|
||||
f"/repos/{owner}/{repo}/branches/{branch}",
|
||||
)
|
||||
return True
|
||||
except APIError:
|
||||
return False
|
||||
|
||||
async def get_branch(
|
||||
self,
|
||||
owner: str,
|
||||
repo: str,
|
||||
branch: str,
|
||||
) -> dict[str, Any] | None:
|
||||
"""Get branch information."""
|
||||
return await self._request(
|
||||
"GET",
|
||||
f"/repos/{owner}/{repo}/branches/{branch}",
|
||||
)
|
||||
|
||||
# Comment operations
|
||||
|
||||
async def add_pr_comment(
|
||||
self,
|
||||
owner: str,
|
||||
repo: str,
|
||||
pr_number: int,
|
||||
body: str,
|
||||
) -> dict[str, Any]:
|
||||
"""Add a comment to a pull request."""
|
||||
result = await self._request(
|
||||
"POST",
|
||||
f"/repos/{owner}/{repo}/issues/{pr_number}/comments",
|
||||
json={"body": body},
|
||||
)
|
||||
return result or {}
|
||||
|
||||
async def list_pr_comments(
|
||||
self,
|
||||
owner: str,
|
||||
repo: str,
|
||||
pr_number: int,
|
||||
) -> list[dict[str, Any]]:
|
||||
"""List comments on a pull request."""
|
||||
result = await self._request(
|
||||
"GET",
|
||||
f"/repos/{owner}/{repo}/issues/{pr_number}/comments",
|
||||
)
|
||||
return result or []
|
||||
|
||||
# Label operations
|
||||
|
||||
async def add_labels(
|
||||
self,
|
||||
owner: str,
|
||||
repo: str,
|
||||
pr_number: int,
|
||||
labels: list[str],
|
||||
) -> list[str]:
|
||||
"""Add labels to a pull request."""
|
||||
# First, get or create label IDs
|
||||
label_ids = []
|
||||
for label_name in labels:
|
||||
label_id = await self._get_or_create_label(owner, repo, label_name)
|
||||
if label_id:
|
||||
label_ids.append(label_id)
|
||||
|
||||
if label_ids:
|
||||
await self._request(
|
||||
"POST",
|
||||
f"/repos/{owner}/{repo}/issues/{pr_number}/labels",
|
||||
json={"labels": label_ids},
|
||||
)
|
||||
|
||||
# Return current labels
|
||||
issue = await self._request(
|
||||
"GET",
|
||||
f"/repos/{owner}/{repo}/issues/{pr_number}",
|
||||
)
|
||||
if issue:
|
||||
return [lbl["name"] for lbl in issue.get("labels", [])]
|
||||
return labels
|
||||
|
||||
async def remove_label(
|
||||
self,
|
||||
owner: str,
|
||||
repo: str,
|
||||
pr_number: int,
|
||||
label: str,
|
||||
) -> list[str]:
|
||||
"""Remove a label from a pull request."""
|
||||
# Get label ID
|
||||
label_info = await self._request(
|
||||
"GET",
|
||||
f"/repos/{owner}/{repo}/labels?name={label}",
|
||||
)
|
||||
|
||||
if label_info and len(label_info) > 0:
|
||||
label_id = label_info[0]["id"]
|
||||
await self._request(
|
||||
"DELETE",
|
||||
f"/repos/{owner}/{repo}/issues/{pr_number}/labels/{label_id}",
|
||||
)
|
||||
|
||||
# Return remaining labels
|
||||
issue = await self._request(
|
||||
"GET",
|
||||
f"/repos/{owner}/{repo}/issues/{pr_number}",
|
||||
)
|
||||
if issue:
|
||||
return [lbl["name"] for lbl in issue.get("labels", [])]
|
||||
return []
|
||||
|
||||
async def _get_or_create_label(
|
||||
self,
|
||||
owner: str,
|
||||
repo: str,
|
||||
label_name: str,
|
||||
) -> int | None:
|
||||
"""Get or create a label and return its ID."""
|
||||
# Try to find existing label
|
||||
labels = await self._request(
|
||||
"GET",
|
||||
f"/repos/{owner}/{repo}/labels",
|
||||
)
|
||||
|
||||
if labels:
|
||||
for label in labels:
|
||||
if label["name"].lower() == label_name.lower():
|
||||
return label["id"]
|
||||
|
||||
# Create new label with default color
|
||||
try:
|
||||
result = await self._request(
|
||||
"POST",
|
||||
f"/repos/{owner}/{repo}/labels",
|
||||
json={
|
||||
"name": label_name,
|
||||
"color": "#3B82F6", # Default blue
|
||||
},
|
||||
)
|
||||
if result:
|
||||
return result["id"]
|
||||
except APIError:
|
||||
pass
|
||||
|
||||
return None
|
||||
|
||||
# Reviewer operations
|
||||
|
||||
async def request_review(
|
||||
self,
|
||||
owner: str,
|
||||
repo: str,
|
||||
pr_number: int,
|
||||
reviewers: list[str],
|
||||
) -> list[str]:
|
||||
"""Request review from users."""
|
||||
await self._request(
|
||||
"POST",
|
||||
f"/repos/{owner}/{repo}/pulls/{pr_number}/requested_reviewers",
|
||||
json={"reviewers": reviewers},
|
||||
)
|
||||
return reviewers
|
||||
|
||||
# Helper methods
|
||||
|
||||
def _parse_pr(self, data: dict[str, Any]) -> PRInfo:
|
||||
"""Parse PR API response into PRInfo."""
|
||||
# Parse dates
|
||||
created_at = self._parse_datetime(data.get("created_at"))
|
||||
updated_at = self._parse_datetime(data.get("updated_at"))
|
||||
merged_at = self._parse_datetime(data.get("merged_at"))
|
||||
closed_at = self._parse_datetime(data.get("closed_at"))
|
||||
|
||||
# Determine state
|
||||
if data.get("merged"):
|
||||
state = PRState.MERGED
|
||||
elif data.get("state") == "closed":
|
||||
state = PRState.CLOSED
|
||||
else:
|
||||
state = PRState.OPEN
|
||||
|
||||
# Extract labels
|
||||
labels = [lbl["name"] for lbl in data.get("labels", [])]
|
||||
|
||||
# Extract assignees
|
||||
assignees = [a["login"] for a in data.get("assignees", [])]
|
||||
|
||||
# Extract reviewers
|
||||
reviewers = []
|
||||
if "requested_reviewers" in data:
|
||||
reviewers = [r["login"] for r in data["requested_reviewers"]]
|
||||
|
||||
return PRInfo(
|
||||
number=data["number"],
|
||||
title=data["title"],
|
||||
body=data.get("body", ""),
|
||||
state=state,
|
||||
source_branch=data.get("head", {}).get("ref", ""),
|
||||
target_branch=data.get("base", {}).get("ref", ""),
|
||||
author=data.get("user", {}).get("login", ""),
|
||||
created_at=created_at,
|
||||
updated_at=updated_at,
|
||||
merged_at=merged_at,
|
||||
closed_at=closed_at,
|
||||
url=data.get("html_url"),
|
||||
labels=labels,
|
||||
assignees=assignees,
|
||||
reviewers=reviewers,
|
||||
mergeable=data.get("mergeable"),
|
||||
draft=data.get("draft", False),
|
||||
)
|
||||
|
||||
def _parse_datetime(self, value: str | None) -> datetime:
|
||||
"""Parse datetime string from API."""
|
||||
if not value:
|
||||
return datetime.now(UTC)
|
||||
|
||||
try:
|
||||
# Handle Gitea's datetime format
|
||||
if value.endswith("Z"):
|
||||
value = value[:-1] + "+00:00"
|
||||
return datetime.fromisoformat(value)
|
||||
except ValueError:
|
||||
return datetime.now(UTC)
|
||||
118
mcp-servers/git-ops/pyproject.toml
Normal file
118
mcp-servers/git-ops/pyproject.toml
Normal file
@@ -0,0 +1,118 @@
|
||||
[project]
|
||||
name = "syndarix-mcp-git-ops"
|
||||
version = "0.1.0"
|
||||
description = "Syndarix Git Operations MCP Server - Repository management, branching, commits, and PR workflows"
|
||||
requires-python = ">=3.12"
|
||||
dependencies = [
|
||||
"fastmcp>=2.0.0",
|
||||
"gitpython>=3.1.0",
|
||||
"httpx>=0.27.0",
|
||||
"redis>=5.0.0",
|
||||
"pydantic>=2.0.0",
|
||||
"pydantic-settings>=2.0.0",
|
||||
"uvicorn>=0.30.0",
|
||||
"fastapi>=0.115.0",
|
||||
"filelock>=3.15.0",
|
||||
"aiofiles>=24.1.0",
|
||||
]
|
||||
|
||||
[project.optional-dependencies]
|
||||
dev = [
|
||||
"pytest>=8.0.0",
|
||||
"pytest-asyncio>=0.24.0",
|
||||
"pytest-cov>=5.0.0",
|
||||
"fakeredis>=2.25.0",
|
||||
"ruff>=0.8.0",
|
||||
"mypy>=1.11.0",
|
||||
"respx>=0.21.0",
|
||||
]
|
||||
|
||||
[project.scripts]
|
||||
git-ops = "server:main"
|
||||
|
||||
[build-system]
|
||||
requires = ["hatchling"]
|
||||
build-backend = "hatchling.build"
|
||||
|
||||
[tool.hatch.build.targets.wheel]
|
||||
packages = ["."]
|
||||
exclude = ["tests/", "*.md", "Dockerfile"]
|
||||
|
||||
[tool.hatch.build.targets.sdist]
|
||||
include = ["*.py", "pyproject.toml"]
|
||||
|
||||
[tool.ruff]
|
||||
target-version = "py312"
|
||||
line-length = 88
|
||||
|
||||
[tool.ruff.lint]
|
||||
select = [
|
||||
"E", # pycodestyle errors
|
||||
"W", # pycodestyle warnings
|
||||
"F", # pyflakes
|
||||
"I", # isort
|
||||
"B", # flake8-bugbear
|
||||
"C4", # flake8-comprehensions
|
||||
"UP", # pyupgrade
|
||||
"ARG", # flake8-unused-arguments
|
||||
"SIM", # flake8-simplify
|
||||
"S", # flake8-bandit (security)
|
||||
]
|
||||
ignore = [
|
||||
"E501", # line too long (handled by formatter)
|
||||
"B008", # do not perform function calls in argument defaults
|
||||
"B904", # raise from in except (too noisy)
|
||||
"S104", # possible binding to all interfaces
|
||||
"S110", # try-except-pass (intentional for optional operations)
|
||||
"S603", # subprocess without shell=True (safe usage in git wrapper)
|
||||
"S607", # starting a process with a partial path (git CLI)
|
||||
"ARG002", # unused method arguments (for API compatibility)
|
||||
"SIM102", # nested if statements (sometimes more readable)
|
||||
"SIM105", # contextlib.suppress (sometimes more readable)
|
||||
"SIM108", # ternary operator (sometimes more readable)
|
||||
"SIM118", # dict.keys() (explicit is fine)
|
||||
]
|
||||
|
||||
[tool.ruff.lint.isort]
|
||||
known-first-party = ["config", "models", "exceptions", "git_wrapper", "workspace", "providers"]
|
||||
|
||||
[tool.ruff.lint.per-file-ignores]
|
||||
"tests/**/*.py" = ["S101", "ARG001", "S105", "S106", "S108", "F841", "B007"]
|
||||
|
||||
[tool.pytest.ini_options]
|
||||
asyncio_mode = "auto"
|
||||
asyncio_default_fixture_loop_scope = "function"
|
||||
testpaths = ["tests"]
|
||||
addopts = "-v --tb=short"
|
||||
filterwarnings = [
|
||||
"ignore::DeprecationWarning",
|
||||
]
|
||||
|
||||
[tool.coverage.run]
|
||||
source = ["."]
|
||||
omit = ["tests/*", "conftest.py"]
|
||||
branch = true
|
||||
|
||||
[tool.coverage.report]
|
||||
exclude_lines = [
|
||||
"pragma: no cover",
|
||||
"def __repr__",
|
||||
"raise NotImplementedError",
|
||||
"if TYPE_CHECKING:",
|
||||
"if __name__ == .__main__.:",
|
||||
]
|
||||
fail_under = 65 # TODO: Increase to 80% once more tool tests are added
|
||||
show_missing = true
|
||||
|
||||
[tool.mypy]
|
||||
python_version = "3.12"
|
||||
warn_return_any = false
|
||||
warn_unused_ignores = false
|
||||
disallow_untyped_defs = true
|
||||
ignore_missing_imports = true
|
||||
plugins = ["pydantic.mypy"]
|
||||
|
||||
[[tool.mypy.overrides]]
|
||||
module = "tests.*"
|
||||
disallow_untyped_defs = false
|
||||
ignore_errors = true
|
||||
1226
mcp-servers/git-ops/server.py
Normal file
1226
mcp-servers/git-ops/server.py
Normal file
File diff suppressed because it is too large
Load Diff
1
mcp-servers/git-ops/tests/__init__.py
Normal file
1
mcp-servers/git-ops/tests/__init__.py
Normal file
@@ -0,0 +1 @@
|
||||
"""Tests for Git Operations MCP Server."""
|
||||
299
mcp-servers/git-ops/tests/conftest.py
Normal file
299
mcp-servers/git-ops/tests/conftest.py
Normal file
@@ -0,0 +1,299 @@
|
||||
"""
|
||||
Test configuration and fixtures for Git Operations MCP Server.
|
||||
"""
|
||||
|
||||
import os
|
||||
import shutil
|
||||
import tempfile
|
||||
from collections.abc import AsyncIterator, Iterator
|
||||
from pathlib import Path
|
||||
from unittest.mock import AsyncMock, MagicMock
|
||||
|
||||
import pytest
|
||||
from git import Repo as GitRepo
|
||||
|
||||
# Set test environment
|
||||
os.environ["IS_TEST"] = "true"
|
||||
os.environ["GIT_OPS_WORKSPACE_BASE_PATH"] = "/tmp/test-workspaces"
|
||||
os.environ["GIT_OPS_GITEA_BASE_URL"] = "https://gitea.test.com"
|
||||
os.environ["GIT_OPS_GITEA_TOKEN"] = "test-token"
|
||||
|
||||
|
||||
@pytest.fixture(scope="session", autouse=True)
|
||||
def reset_settings_session():
|
||||
"""Reset settings at start and end of test session."""
|
||||
from config import reset_settings
|
||||
|
||||
reset_settings()
|
||||
yield
|
||||
reset_settings()
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def reset_settings():
|
||||
"""Reset settings before each test that needs it."""
|
||||
from config import reset_settings
|
||||
|
||||
reset_settings()
|
||||
yield
|
||||
reset_settings()
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def test_settings():
|
||||
"""Get test settings."""
|
||||
from config import Settings
|
||||
|
||||
return Settings(
|
||||
workspace_base_path=Path("/tmp/test-workspaces"),
|
||||
gitea_base_url="https://gitea.test.com",
|
||||
gitea_token="test-token",
|
||||
github_token="github-test-token",
|
||||
git_author_name="Test Agent",
|
||||
git_author_email="test@syndarix.ai",
|
||||
enable_force_push=False,
|
||||
debug=True,
|
||||
)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def temp_dir() -> Iterator[Path]:
|
||||
"""Create a temporary directory for tests."""
|
||||
temp_path = Path(tempfile.mkdtemp())
|
||||
yield temp_path
|
||||
if temp_path.exists():
|
||||
shutil.rmtree(temp_path)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def temp_workspace(temp_dir: Path) -> Path:
|
||||
"""Create a temporary workspace directory."""
|
||||
workspace = temp_dir / "workspace"
|
||||
workspace.mkdir(parents=True, exist_ok=True)
|
||||
return workspace
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def git_repo(temp_workspace: Path) -> GitRepo:
|
||||
"""Create a git repository in the temp workspace."""
|
||||
# Initialize with main branch (Git 2.28+)
|
||||
repo = GitRepo.init(temp_workspace, initial_branch="main")
|
||||
|
||||
# Configure git
|
||||
with repo.config_writer() as cw:
|
||||
cw.set_value("user", "name", "Test User")
|
||||
cw.set_value("user", "email", "test@example.com")
|
||||
|
||||
# Create initial commit
|
||||
test_file = temp_workspace / "README.md"
|
||||
test_file.write_text("# Test Repository\n")
|
||||
repo.index.add(["README.md"])
|
||||
repo.index.commit("Initial commit")
|
||||
|
||||
return repo
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def git_repo_with_remote(git_repo: GitRepo, temp_dir: Path) -> tuple[GitRepo, GitRepo]:
|
||||
"""Create a git repository with a 'remote' (bare repo)."""
|
||||
# Create bare repo as remote
|
||||
remote_path = temp_dir / "remote.git"
|
||||
remote_repo = GitRepo.init(remote_path, bare=True)
|
||||
|
||||
# Add remote to main repo
|
||||
git_repo.create_remote("origin", str(remote_path))
|
||||
|
||||
# Push initial commit
|
||||
git_repo.remotes.origin.push("main:main")
|
||||
|
||||
# Set up tracking
|
||||
git_repo.heads.main.set_tracking_branch(git_repo.remotes.origin.refs.main)
|
||||
|
||||
return git_repo, remote_repo
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def workspace_manager(temp_dir: Path, test_settings):
|
||||
"""Create a WorkspaceManager with test settings."""
|
||||
from workspace import WorkspaceManager
|
||||
|
||||
test_settings.workspace_base_path = temp_dir / "workspaces"
|
||||
return WorkspaceManager(test_settings)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def git_wrapper(temp_workspace: Path, test_settings):
|
||||
"""Create a GitWrapper for the temp workspace."""
|
||||
from git_wrapper import GitWrapper
|
||||
|
||||
return GitWrapper(temp_workspace, test_settings)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def git_wrapper_with_repo(git_repo: GitRepo, test_settings):
|
||||
"""Create a GitWrapper for a repo that's already initialized."""
|
||||
from git_wrapper import GitWrapper
|
||||
|
||||
return GitWrapper(Path(git_repo.working_dir), test_settings)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_gitea_provider():
|
||||
"""Create a mock Gitea provider."""
|
||||
provider = AsyncMock()
|
||||
provider.name = "gitea"
|
||||
provider.is_connected = AsyncMock(return_value=True)
|
||||
provider.get_authenticated_user = AsyncMock(return_value="test-user")
|
||||
provider.parse_repo_url = MagicMock(return_value=("owner", "repo"))
|
||||
return provider
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_httpx_client():
|
||||
"""Create a mock httpx client for provider tests."""
|
||||
from unittest.mock import AsyncMock
|
||||
|
||||
mock_response = AsyncMock()
|
||||
mock_response.status_code = 200
|
||||
mock_response.json = MagicMock(return_value={})
|
||||
mock_response.text = ""
|
||||
|
||||
mock_client = AsyncMock()
|
||||
mock_client.request = AsyncMock(return_value=mock_response)
|
||||
mock_client.get = AsyncMock(return_value=mock_response)
|
||||
mock_client.post = AsyncMock(return_value=mock_response)
|
||||
mock_client.patch = AsyncMock(return_value=mock_response)
|
||||
mock_client.delete = AsyncMock(return_value=mock_response)
|
||||
|
||||
return mock_client
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
async def gitea_provider(test_settings, mock_httpx_client):
|
||||
"""Create a GiteaProvider with mocked HTTP client."""
|
||||
from providers.gitea import GiteaProvider
|
||||
|
||||
provider = GiteaProvider(
|
||||
base_url=test_settings.gitea_base_url,
|
||||
token=test_settings.gitea_token,
|
||||
settings=test_settings,
|
||||
)
|
||||
provider._client = mock_httpx_client
|
||||
|
||||
yield provider
|
||||
|
||||
await provider.close()
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def sample_pr_data():
|
||||
"""Sample PR data from Gitea API."""
|
||||
return {
|
||||
"number": 42,
|
||||
"title": "Test PR",
|
||||
"body": "This is a test pull request",
|
||||
"state": "open",
|
||||
"head": {"ref": "feature-branch"},
|
||||
"base": {"ref": "main"},
|
||||
"user": {"login": "test-user"},
|
||||
"created_at": "2024-01-15T10:00:00Z",
|
||||
"updated_at": "2024-01-15T12:00:00Z",
|
||||
"merged_at": None,
|
||||
"closed_at": None,
|
||||
"html_url": "https://gitea.test.com/owner/repo/pull/42",
|
||||
"labels": [{"name": "enhancement"}],
|
||||
"assignees": [{"login": "assignee1"}],
|
||||
"requested_reviewers": [{"login": "reviewer1"}],
|
||||
"mergeable": True,
|
||||
"draft": False,
|
||||
}
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def sample_commit_data():
|
||||
"""Sample commit data."""
|
||||
return {
|
||||
"sha": "abc123def456",
|
||||
"short_sha": "abc123d",
|
||||
"message": "Test commit message",
|
||||
"author": {
|
||||
"name": "Test Author",
|
||||
"email": "author@test.com",
|
||||
"date": "2024-01-15T10:00:00Z",
|
||||
},
|
||||
"committer": {
|
||||
"name": "Test Committer",
|
||||
"email": "committer@test.com",
|
||||
"date": "2024-01-15T10:00:00Z",
|
||||
},
|
||||
}
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_fastapi_app():
|
||||
"""Create a test FastAPI app."""
|
||||
from fastapi import FastAPI
|
||||
from fastapi.testclient import TestClient
|
||||
|
||||
app = FastAPI()
|
||||
|
||||
@app.get("/health")
|
||||
def health():
|
||||
return {"status": "healthy"}
|
||||
|
||||
return TestClient(app)
|
||||
|
||||
|
||||
# Async fixtures
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
async def async_workspace_manager(
|
||||
temp_dir: Path, test_settings
|
||||
) -> AsyncIterator:
|
||||
"""Async fixture for workspace manager."""
|
||||
from workspace import WorkspaceManager
|
||||
|
||||
test_settings.workspace_base_path = temp_dir / "workspaces"
|
||||
manager = WorkspaceManager(test_settings)
|
||||
yield manager
|
||||
|
||||
|
||||
# Test data fixtures
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def valid_project_id() -> str:
|
||||
"""Valid project ID for tests."""
|
||||
return "test-project-123"
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def valid_agent_id() -> str:
|
||||
"""Valid agent ID for tests."""
|
||||
return "agent-456"
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def invalid_ids() -> list[str]:
|
||||
"""Invalid IDs for validation tests."""
|
||||
return [
|
||||
"",
|
||||
" ",
|
||||
"a" * 200, # Too long
|
||||
"test@invalid", # Invalid character
|
||||
"test!invalid",
|
||||
"../path/traversal",
|
||||
]
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def sample_repo_url() -> str:
|
||||
"""Sample repository URL."""
|
||||
return "https://gitea.test.com/owner/repo.git"
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def sample_ssh_repo_url() -> str:
|
||||
"""Sample SSH repository URL."""
|
||||
return "git@gitea.test.com:owner/repo.git"
|
||||
434
mcp-servers/git-ops/tests/test_git_wrapper.py
Normal file
434
mcp-servers/git-ops/tests/test_git_wrapper.py
Normal file
@@ -0,0 +1,434 @@
|
||||
"""
|
||||
Tests for the git_wrapper module.
|
||||
"""
|
||||
|
||||
from pathlib import Path
|
||||
|
||||
import pytest
|
||||
|
||||
from exceptions import (
|
||||
BranchExistsError,
|
||||
BranchNotFoundError,
|
||||
CheckoutError,
|
||||
CommitError,
|
||||
GitError,
|
||||
)
|
||||
from git_wrapper import GitWrapper
|
||||
from models import FileChangeType
|
||||
|
||||
|
||||
class TestGitWrapperInit:
|
||||
"""Tests for GitWrapper initialization."""
|
||||
|
||||
def test_init_with_valid_path(self, temp_workspace, test_settings):
|
||||
"""Test initialization with a valid path."""
|
||||
wrapper = GitWrapper(temp_workspace, test_settings)
|
||||
assert wrapper.workspace_path == temp_workspace
|
||||
assert wrapper.settings == test_settings
|
||||
|
||||
def test_repo_property_raises_on_non_git(self, temp_workspace, test_settings):
|
||||
"""Test that accessing repo on non-git dir raises error."""
|
||||
wrapper = GitWrapper(temp_workspace, test_settings)
|
||||
with pytest.raises(GitError, match="Not a git repository"):
|
||||
_ = wrapper.repo
|
||||
|
||||
def test_repo_property_works_on_git_dir(self, git_repo, test_settings):
|
||||
"""Test that repo property works for git directory."""
|
||||
wrapper = GitWrapper(Path(git_repo.working_dir), test_settings)
|
||||
assert wrapper.repo is not None
|
||||
assert wrapper.repo.head is not None
|
||||
|
||||
|
||||
class TestGitWrapperStatus:
|
||||
"""Tests for git status operations."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_status_clean_repo(self, git_wrapper_with_repo):
|
||||
"""Test status on a clean repository."""
|
||||
result = await git_wrapper_with_repo.status()
|
||||
|
||||
assert result.branch == "main"
|
||||
assert result.is_clean is True
|
||||
assert len(result.staged) == 0
|
||||
assert len(result.unstaged) == 0
|
||||
assert len(result.untracked) == 0
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_status_with_untracked(self, git_wrapper_with_repo, git_repo):
|
||||
"""Test status with untracked files."""
|
||||
# Create untracked file
|
||||
untracked_file = Path(git_repo.working_dir) / "untracked.txt"
|
||||
untracked_file.write_text("untracked content")
|
||||
|
||||
result = await git_wrapper_with_repo.status()
|
||||
|
||||
assert result.is_clean is False
|
||||
assert "untracked.txt" in result.untracked
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_status_with_modified(self, git_wrapper_with_repo, git_repo):
|
||||
"""Test status with modified files."""
|
||||
# Modify existing file
|
||||
readme = Path(git_repo.working_dir) / "README.md"
|
||||
readme.write_text("# Modified content\n")
|
||||
|
||||
result = await git_wrapper_with_repo.status()
|
||||
|
||||
assert result.is_clean is False
|
||||
assert len(result.unstaged) > 0
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_status_with_staged(self, git_wrapper_with_repo, git_repo):
|
||||
"""Test status with staged changes."""
|
||||
# Create and stage a file
|
||||
new_file = Path(git_repo.working_dir) / "staged.txt"
|
||||
new_file.write_text("staged content")
|
||||
git_repo.index.add(["staged.txt"])
|
||||
|
||||
result = await git_wrapper_with_repo.status()
|
||||
|
||||
assert result.is_clean is False
|
||||
assert len(result.staged) > 0
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_status_exclude_untracked(self, git_wrapper_with_repo, git_repo):
|
||||
"""Test status without untracked files."""
|
||||
untracked_file = Path(git_repo.working_dir) / "untracked.txt"
|
||||
untracked_file.write_text("untracked")
|
||||
|
||||
result = await git_wrapper_with_repo.status(include_untracked=False)
|
||||
|
||||
assert len(result.untracked) == 0
|
||||
|
||||
|
||||
class TestGitWrapperBranch:
|
||||
"""Tests for branch operations."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_create_branch(self, git_wrapper_with_repo):
|
||||
"""Test creating a new branch."""
|
||||
result = await git_wrapper_with_repo.create_branch("feature-test")
|
||||
|
||||
assert result.success is True
|
||||
assert result.branch == "feature-test"
|
||||
assert result.is_current is True
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_create_branch_without_checkout(self, git_wrapper_with_repo):
|
||||
"""Test creating branch without checkout."""
|
||||
result = await git_wrapper_with_repo.create_branch("feature-no-checkout", checkout=False)
|
||||
|
||||
assert result.success is True
|
||||
assert result.branch == "feature-no-checkout"
|
||||
assert result.is_current is False
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_create_branch_exists_error(self, git_wrapper_with_repo):
|
||||
"""Test error when branch already exists."""
|
||||
await git_wrapper_with_repo.create_branch("existing-branch", checkout=False)
|
||||
|
||||
with pytest.raises(BranchExistsError):
|
||||
await git_wrapper_with_repo.create_branch("existing-branch")
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_delete_branch(self, git_wrapper_with_repo):
|
||||
"""Test deleting a branch."""
|
||||
# Create branch first
|
||||
await git_wrapper_with_repo.create_branch("to-delete", checkout=False)
|
||||
|
||||
# Delete it
|
||||
result = await git_wrapper_with_repo.delete_branch("to-delete")
|
||||
|
||||
assert result.success is True
|
||||
assert result.branch == "to-delete"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_delete_branch_not_found(self, git_wrapper_with_repo):
|
||||
"""Test error when deleting non-existent branch."""
|
||||
with pytest.raises(BranchNotFoundError):
|
||||
await git_wrapper_with_repo.delete_branch("nonexistent")
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_delete_current_branch_error(self, git_wrapper_with_repo):
|
||||
"""Test error when deleting current branch."""
|
||||
with pytest.raises(GitError, match="Cannot delete current branch"):
|
||||
await git_wrapper_with_repo.delete_branch("main")
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_list_branches(self, git_wrapper_with_repo):
|
||||
"""Test listing branches."""
|
||||
# Create some branches
|
||||
await git_wrapper_with_repo.create_branch("branch-a", checkout=False)
|
||||
await git_wrapper_with_repo.create_branch("branch-b", checkout=False)
|
||||
|
||||
result = await git_wrapper_with_repo.list_branches()
|
||||
|
||||
assert result.current_branch == "main"
|
||||
branch_names = [b["name"] for b in result.local_branches]
|
||||
assert "main" in branch_names
|
||||
assert "branch-a" in branch_names
|
||||
assert "branch-b" in branch_names
|
||||
|
||||
|
||||
class TestGitWrapperCheckout:
|
||||
"""Tests for checkout operations."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_checkout_existing_branch(self, git_wrapper_with_repo):
|
||||
"""Test checkout of existing branch."""
|
||||
# Create branch first
|
||||
await git_wrapper_with_repo.create_branch("test-branch", checkout=False)
|
||||
|
||||
result = await git_wrapper_with_repo.checkout("test-branch")
|
||||
|
||||
assert result.success is True
|
||||
assert result.ref == "test-branch"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_checkout_create_new(self, git_wrapper_with_repo):
|
||||
"""Test checkout with branch creation."""
|
||||
result = await git_wrapper_with_repo.checkout("new-branch", create_branch=True)
|
||||
|
||||
assert result.success is True
|
||||
assert result.ref == "new-branch"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_checkout_nonexistent_error(self, git_wrapper_with_repo):
|
||||
"""Test error when checking out non-existent ref."""
|
||||
with pytest.raises(CheckoutError):
|
||||
await git_wrapper_with_repo.checkout("nonexistent-branch")
|
||||
|
||||
|
||||
class TestGitWrapperCommit:
|
||||
"""Tests for commit operations."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_commit_staged_changes(self, git_wrapper_with_repo, git_repo):
|
||||
"""Test committing staged changes."""
|
||||
# Create and stage a file
|
||||
new_file = Path(git_repo.working_dir) / "newfile.txt"
|
||||
new_file.write_text("new content")
|
||||
git_repo.index.add(["newfile.txt"])
|
||||
|
||||
result = await git_wrapper_with_repo.commit("Add new file")
|
||||
|
||||
assert result.success is True
|
||||
assert result.message == "Add new file"
|
||||
assert result.files_changed == 1
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_commit_all_changes(self, git_wrapper_with_repo, git_repo):
|
||||
"""Test committing all changes (auto-stage)."""
|
||||
# Create a file without staging
|
||||
new_file = Path(git_repo.working_dir) / "unstaged.txt"
|
||||
new_file.write_text("content")
|
||||
|
||||
result = await git_wrapper_with_repo.commit("Commit unstaged")
|
||||
|
||||
assert result.success is True
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_commit_nothing_to_commit(self, git_wrapper_with_repo):
|
||||
"""Test error when nothing to commit."""
|
||||
with pytest.raises(CommitError, match="Nothing to commit"):
|
||||
await git_wrapper_with_repo.commit("Empty commit")
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_commit_with_author(self, git_wrapper_with_repo, git_repo):
|
||||
"""Test commit with custom author."""
|
||||
new_file = Path(git_repo.working_dir) / "authored.txt"
|
||||
new_file.write_text("authored content")
|
||||
|
||||
result = await git_wrapper_with_repo.commit(
|
||||
"Custom author commit",
|
||||
author_name="Custom Author",
|
||||
author_email="custom@test.com",
|
||||
)
|
||||
|
||||
assert result.success is True
|
||||
|
||||
|
||||
class TestGitWrapperDiff:
|
||||
"""Tests for diff operations."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_diff_no_changes(self, git_wrapper_with_repo):
|
||||
"""Test diff with no changes."""
|
||||
result = await git_wrapper_with_repo.diff()
|
||||
|
||||
assert result.files_changed == 0
|
||||
assert result.total_additions == 0
|
||||
assert result.total_deletions == 0
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_diff_with_changes(self, git_wrapper_with_repo, git_repo):
|
||||
"""Test diff with modified files."""
|
||||
# Modify a file
|
||||
readme = Path(git_repo.working_dir) / "README.md"
|
||||
readme.write_text("# Modified\nNew line\n")
|
||||
|
||||
result = await git_wrapper_with_repo.diff()
|
||||
|
||||
assert result.files_changed > 0
|
||||
|
||||
|
||||
class TestGitWrapperLog:
|
||||
"""Tests for log operations."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_log_basic(self, git_wrapper_with_repo):
|
||||
"""Test basic log."""
|
||||
result = await git_wrapper_with_repo.log()
|
||||
|
||||
assert result.total_commits > 0
|
||||
assert len(result.commits) > 0
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_log_with_limit(self, git_wrapper_with_repo, git_repo):
|
||||
"""Test log with limit."""
|
||||
# Create more commits
|
||||
for i in range(5):
|
||||
file_path = Path(git_repo.working_dir) / f"file{i}.txt"
|
||||
file_path.write_text(f"content {i}")
|
||||
git_repo.index.add([f"file{i}.txt"])
|
||||
git_repo.index.commit(f"Commit {i}")
|
||||
|
||||
result = await git_wrapper_with_repo.log(limit=3)
|
||||
|
||||
assert len(result.commits) == 3
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_log_commit_info(self, git_wrapper_with_repo):
|
||||
"""Test that log returns proper commit info."""
|
||||
result = await git_wrapper_with_repo.log(limit=1)
|
||||
|
||||
commit = result.commits[0]
|
||||
assert "sha" in commit
|
||||
assert "message" in commit
|
||||
assert "author_name" in commit
|
||||
assert "author_email" in commit
|
||||
|
||||
|
||||
class TestGitWrapperUtilities:
|
||||
"""Tests for utility methods."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_is_valid_ref_true(self, git_wrapper_with_repo):
|
||||
"""Test valid ref detection."""
|
||||
is_valid = await git_wrapper_with_repo.is_valid_ref("main")
|
||||
assert is_valid is True
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_is_valid_ref_false(self, git_wrapper_with_repo):
|
||||
"""Test invalid ref detection."""
|
||||
is_valid = await git_wrapper_with_repo.is_valid_ref("nonexistent")
|
||||
assert is_valid is False
|
||||
|
||||
def test_diff_to_change_type(self, git_wrapper_with_repo):
|
||||
"""Test change type conversion."""
|
||||
wrapper = git_wrapper_with_repo
|
||||
|
||||
assert wrapper._diff_to_change_type("A") == FileChangeType.ADDED
|
||||
assert wrapper._diff_to_change_type("M") == FileChangeType.MODIFIED
|
||||
assert wrapper._diff_to_change_type("D") == FileChangeType.DELETED
|
||||
assert wrapper._diff_to_change_type("R") == FileChangeType.RENAMED
|
||||
|
||||
|
||||
class TestGitWrapperStage:
|
||||
"""Tests for staging operations."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_stage_specific_files(self, git_wrapper_with_repo, git_repo):
|
||||
"""Test staging specific files."""
|
||||
# Create files
|
||||
file1 = Path(git_repo.working_dir) / "file1.txt"
|
||||
file2 = Path(git_repo.working_dir) / "file2.txt"
|
||||
file1.write_text("content 1")
|
||||
file2.write_text("content 2")
|
||||
|
||||
count = await git_wrapper_with_repo.stage(["file1.txt"])
|
||||
|
||||
assert count == 1
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_stage_all(self, git_wrapper_with_repo, git_repo):
|
||||
"""Test staging all files."""
|
||||
file1 = Path(git_repo.working_dir) / "all1.txt"
|
||||
file2 = Path(git_repo.working_dir) / "all2.txt"
|
||||
file1.write_text("content 1")
|
||||
file2.write_text("content 2")
|
||||
|
||||
count = await git_wrapper_with_repo.stage()
|
||||
|
||||
assert count >= 2
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_unstage_files(self, git_wrapper_with_repo, git_repo):
|
||||
"""Test unstaging files."""
|
||||
# Create and stage file
|
||||
file1 = Path(git_repo.working_dir) / "unstage.txt"
|
||||
file1.write_text("to unstage")
|
||||
git_repo.index.add(["unstage.txt"])
|
||||
|
||||
count = await git_wrapper_with_repo.unstage()
|
||||
|
||||
assert count >= 1
|
||||
|
||||
|
||||
class TestGitWrapperReset:
|
||||
"""Tests for reset operations."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_reset_soft(self, git_wrapper_with_repo, git_repo):
|
||||
"""Test soft reset."""
|
||||
# Create a commit to reset
|
||||
file1 = Path(git_repo.working_dir) / "reset_soft.txt"
|
||||
file1.write_text("content")
|
||||
git_repo.index.add(["reset_soft.txt"])
|
||||
git_repo.index.commit("Commit to reset")
|
||||
|
||||
result = await git_wrapper_with_repo.reset("HEAD~1", mode="soft")
|
||||
|
||||
assert result is True
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_reset_mixed(self, git_wrapper_with_repo, git_repo):
|
||||
"""Test mixed reset (default)."""
|
||||
file1 = Path(git_repo.working_dir) / "reset_mixed.txt"
|
||||
file1.write_text("content")
|
||||
git_repo.index.add(["reset_mixed.txt"])
|
||||
git_repo.index.commit("Commit to reset")
|
||||
|
||||
result = await git_wrapper_with_repo.reset("HEAD~1", mode="mixed")
|
||||
|
||||
assert result is True
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_reset_invalid_mode(self, git_wrapper_with_repo):
|
||||
"""Test error on invalid reset mode."""
|
||||
with pytest.raises(GitError, match="Invalid reset mode"):
|
||||
await git_wrapper_with_repo.reset("HEAD", mode="invalid")
|
||||
|
||||
|
||||
class TestGitWrapperStash:
|
||||
"""Tests for stash operations."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_stash_changes(self, git_wrapper_with_repo, git_repo):
|
||||
"""Test stashing changes."""
|
||||
# Make changes
|
||||
readme = Path(git_repo.working_dir) / "README.md"
|
||||
readme.write_text("Modified for stash")
|
||||
|
||||
result = await git_wrapper_with_repo.stash("Test stash")
|
||||
|
||||
# Result should be stash ref or None if nothing to stash
|
||||
# (depends on whether changes were already staged)
|
||||
assert result is None or result.startswith("stash@")
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_stash_nothing(self, git_wrapper_with_repo):
|
||||
"""Test stash with no changes."""
|
||||
result = await git_wrapper_with_repo.stash()
|
||||
|
||||
assert result is None
|
||||
484
mcp-servers/git-ops/tests/test_providers.py
Normal file
484
mcp-servers/git-ops/tests/test_providers.py
Normal file
@@ -0,0 +1,484 @@
|
||||
"""
|
||||
Tests for git provider implementations.
|
||||
"""
|
||||
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
import pytest
|
||||
|
||||
from exceptions import APIError, AuthenticationError
|
||||
from models import MergeStrategy, PRState
|
||||
from providers.gitea import GiteaProvider
|
||||
|
||||
|
||||
class TestBaseProvider:
|
||||
"""Tests for BaseProvider interface."""
|
||||
|
||||
def test_parse_repo_url_https(self, mock_gitea_provider):
|
||||
"""Test parsing HTTPS repo URL."""
|
||||
# The mock needs parse_repo_url to work
|
||||
provider = GiteaProvider(
|
||||
base_url="https://gitea.test.com",
|
||||
token="test-token"
|
||||
)
|
||||
|
||||
owner, repo = provider.parse_repo_url("https://gitea.test.com/owner/repo.git")
|
||||
|
||||
assert owner == "owner"
|
||||
assert repo == "repo"
|
||||
|
||||
def test_parse_repo_url_https_no_git(self):
|
||||
"""Test parsing HTTPS URL without .git suffix."""
|
||||
provider = GiteaProvider(
|
||||
base_url="https://gitea.test.com",
|
||||
token="test-token"
|
||||
)
|
||||
|
||||
owner, repo = provider.parse_repo_url("https://gitea.test.com/owner/repo")
|
||||
|
||||
assert owner == "owner"
|
||||
assert repo == "repo"
|
||||
|
||||
def test_parse_repo_url_ssh(self):
|
||||
"""Test parsing SSH repo URL."""
|
||||
provider = GiteaProvider(
|
||||
base_url="https://gitea.test.com",
|
||||
token="test-token"
|
||||
)
|
||||
|
||||
owner, repo = provider.parse_repo_url("git@gitea.test.com:owner/repo.git")
|
||||
|
||||
assert owner == "owner"
|
||||
assert repo == "repo"
|
||||
|
||||
def test_parse_repo_url_invalid(self):
|
||||
"""Test error on invalid URL."""
|
||||
provider = GiteaProvider(
|
||||
base_url="https://gitea.test.com",
|
||||
token="test-token"
|
||||
)
|
||||
|
||||
with pytest.raises(ValueError, match="Unable to parse"):
|
||||
provider.parse_repo_url("invalid-url")
|
||||
|
||||
|
||||
class TestGiteaProvider:
|
||||
"""Tests for GiteaProvider."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_is_connected(self, gitea_provider, mock_httpx_client):
|
||||
"""Test connection check."""
|
||||
mock_httpx_client.request.return_value.json = MagicMock(
|
||||
return_value={"login": "test-user"}
|
||||
)
|
||||
|
||||
result = await gitea_provider.is_connected()
|
||||
|
||||
assert result is True
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_is_connected_no_token(self, test_settings):
|
||||
"""Test connection fails without token."""
|
||||
provider = GiteaProvider(
|
||||
base_url="https://gitea.test.com",
|
||||
token="",
|
||||
settings=test_settings,
|
||||
)
|
||||
|
||||
result = await provider.is_connected()
|
||||
assert result is False
|
||||
|
||||
await provider.close()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_authenticated_user(self, gitea_provider, mock_httpx_client):
|
||||
"""Test getting authenticated user."""
|
||||
mock_httpx_client.request.return_value.json = MagicMock(
|
||||
return_value={"login": "test-user"}
|
||||
)
|
||||
|
||||
user = await gitea_provider.get_authenticated_user()
|
||||
|
||||
assert user == "test-user"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_repo_info(self, gitea_provider, mock_httpx_client):
|
||||
"""Test getting repository info."""
|
||||
mock_httpx_client.request.return_value.json = MagicMock(
|
||||
return_value={
|
||||
"name": "repo",
|
||||
"full_name": "owner/repo",
|
||||
"default_branch": "main",
|
||||
}
|
||||
)
|
||||
|
||||
result = await gitea_provider.get_repo_info("owner", "repo")
|
||||
|
||||
assert result["name"] == "repo"
|
||||
assert result["default_branch"] == "main"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_default_branch(self, gitea_provider, mock_httpx_client):
|
||||
"""Test getting default branch."""
|
||||
mock_httpx_client.request.return_value.json = MagicMock(
|
||||
return_value={"default_branch": "develop"}
|
||||
)
|
||||
|
||||
branch = await gitea_provider.get_default_branch("owner", "repo")
|
||||
|
||||
assert branch == "develop"
|
||||
|
||||
|
||||
class TestGiteaPROperations:
|
||||
"""Tests for Gitea PR operations."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_create_pr(self, gitea_provider, mock_httpx_client):
|
||||
"""Test creating a pull request."""
|
||||
mock_httpx_client.request.return_value.json = MagicMock(
|
||||
return_value={
|
||||
"number": 42,
|
||||
"html_url": "https://gitea.test.com/owner/repo/pull/42",
|
||||
}
|
||||
)
|
||||
|
||||
result = await gitea_provider.create_pr(
|
||||
owner="owner",
|
||||
repo="repo",
|
||||
title="Test PR",
|
||||
body="Test body",
|
||||
source_branch="feature",
|
||||
target_branch="main",
|
||||
)
|
||||
|
||||
assert result.success is True
|
||||
assert result.pr_number == 42
|
||||
assert result.pr_url == "https://gitea.test.com/owner/repo/pull/42"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_create_pr_with_options(self, gitea_provider, mock_httpx_client):
|
||||
"""Test creating PR with labels, assignees, reviewers."""
|
||||
# Use side_effect for multiple API calls:
|
||||
# 1. POST create PR
|
||||
# 2. GET labels (for "enhancement") - in add_labels -> _get_or_create_label
|
||||
# 3. POST add labels to PR - in add_labels
|
||||
# 4. GET issue to return labels - in add_labels
|
||||
# 5. PATCH add assignees
|
||||
# 6. POST request reviewers
|
||||
mock_responses = [
|
||||
{"number": 43, "html_url": "https://gitea.test.com/owner/repo/pull/43"}, # Create PR
|
||||
[{"id": 1, "name": "enhancement"}], # GET labels (found)
|
||||
{}, # POST add labels to PR
|
||||
{"labels": [{"name": "enhancement"}]}, # GET issue to return current labels
|
||||
{}, # PATCH add assignees
|
||||
{}, # POST request reviewers
|
||||
]
|
||||
mock_httpx_client.request.return_value.json = MagicMock(side_effect=mock_responses)
|
||||
|
||||
result = await gitea_provider.create_pr(
|
||||
owner="owner",
|
||||
repo="repo",
|
||||
title="Test PR",
|
||||
body="Test body",
|
||||
source_branch="feature",
|
||||
target_branch="main",
|
||||
labels=["enhancement"],
|
||||
assignees=["user1"],
|
||||
reviewers=["reviewer1"],
|
||||
)
|
||||
|
||||
assert result.success is True
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_pr(self, gitea_provider, mock_httpx_client, sample_pr_data):
|
||||
"""Test getting a pull request."""
|
||||
mock_httpx_client.request.return_value.json = MagicMock(
|
||||
return_value=sample_pr_data
|
||||
)
|
||||
|
||||
result = await gitea_provider.get_pr("owner", "repo", 42)
|
||||
|
||||
assert result.success is True
|
||||
assert result.pr["number"] == 42
|
||||
assert result.pr["title"] == "Test PR"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_pr_not_found(self, gitea_provider, mock_httpx_client):
|
||||
"""Test getting non-existent PR."""
|
||||
mock_httpx_client.request.return_value.status_code = 404
|
||||
mock_httpx_client.request.return_value.json = MagicMock(return_value=None)
|
||||
|
||||
result = await gitea_provider.get_pr("owner", "repo", 999)
|
||||
|
||||
assert result.success is False
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_list_prs(self, gitea_provider, mock_httpx_client, sample_pr_data):
|
||||
"""Test listing pull requests."""
|
||||
mock_httpx_client.request.return_value.json = MagicMock(
|
||||
return_value=[sample_pr_data, sample_pr_data]
|
||||
)
|
||||
|
||||
result = await gitea_provider.list_prs("owner", "repo")
|
||||
|
||||
assert result.success is True
|
||||
assert len(result.pull_requests) == 2
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_list_prs_with_state_filter(self, gitea_provider, mock_httpx_client, sample_pr_data):
|
||||
"""Test listing PRs with state filter."""
|
||||
mock_httpx_client.request.return_value.json = MagicMock(
|
||||
return_value=[sample_pr_data]
|
||||
)
|
||||
|
||||
result = await gitea_provider.list_prs(
|
||||
"owner", "repo", state=PRState.OPEN
|
||||
)
|
||||
|
||||
assert result.success is True
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_merge_pr(self, gitea_provider, mock_httpx_client):
|
||||
"""Test merging a pull request."""
|
||||
# First call returns merge result
|
||||
mock_httpx_client.request.return_value.json = MagicMock(
|
||||
return_value={"sha": "merge-commit-sha"}
|
||||
)
|
||||
|
||||
result = await gitea_provider.merge_pr(
|
||||
"owner", "repo", 42,
|
||||
merge_strategy=MergeStrategy.SQUASH,
|
||||
)
|
||||
|
||||
assert result.success is True
|
||||
assert result.merge_commit_sha == "merge-commit-sha"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_update_pr(self, gitea_provider, mock_httpx_client, sample_pr_data):
|
||||
"""Test updating a pull request."""
|
||||
mock_httpx_client.request.return_value.json = MagicMock(
|
||||
return_value=sample_pr_data
|
||||
)
|
||||
|
||||
result = await gitea_provider.update_pr(
|
||||
"owner", "repo", 42,
|
||||
title="Updated Title",
|
||||
body="Updated body",
|
||||
)
|
||||
|
||||
assert result.success is True
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_close_pr(self, gitea_provider, mock_httpx_client, sample_pr_data):
|
||||
"""Test closing a pull request."""
|
||||
sample_pr_data["state"] = "closed"
|
||||
mock_httpx_client.request.return_value.json = MagicMock(
|
||||
return_value=sample_pr_data
|
||||
)
|
||||
|
||||
result = await gitea_provider.close_pr("owner", "repo", 42)
|
||||
|
||||
assert result.success is True
|
||||
|
||||
|
||||
class TestGiteaBranchOperations:
|
||||
"""Tests for Gitea branch operations."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_branch(self, gitea_provider, mock_httpx_client):
|
||||
"""Test getting branch info."""
|
||||
mock_httpx_client.request.return_value.json = MagicMock(
|
||||
return_value={
|
||||
"name": "main",
|
||||
"commit": {"sha": "abc123"},
|
||||
}
|
||||
)
|
||||
|
||||
result = await gitea_provider.get_branch("owner", "repo", "main")
|
||||
|
||||
assert result["name"] == "main"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_delete_remote_branch(self, gitea_provider, mock_httpx_client):
|
||||
"""Test deleting a remote branch."""
|
||||
mock_httpx_client.request.return_value.status_code = 204
|
||||
|
||||
result = await gitea_provider.delete_remote_branch("owner", "repo", "old-branch")
|
||||
|
||||
assert result is True
|
||||
|
||||
|
||||
class TestGiteaCommentOperations:
|
||||
"""Tests for Gitea comment operations."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_add_pr_comment(self, gitea_provider, mock_httpx_client):
|
||||
"""Test adding a comment to a PR."""
|
||||
mock_httpx_client.request.return_value.json = MagicMock(
|
||||
return_value={"id": 1, "body": "Test comment"}
|
||||
)
|
||||
|
||||
result = await gitea_provider.add_pr_comment(
|
||||
"owner", "repo", 42, "Test comment"
|
||||
)
|
||||
|
||||
assert result["body"] == "Test comment"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_list_pr_comments(self, gitea_provider, mock_httpx_client):
|
||||
"""Test listing PR comments."""
|
||||
mock_httpx_client.request.return_value.json = MagicMock(
|
||||
return_value=[
|
||||
{"id": 1, "body": "Comment 1"},
|
||||
{"id": 2, "body": "Comment 2"},
|
||||
]
|
||||
)
|
||||
|
||||
result = await gitea_provider.list_pr_comments("owner", "repo", 42)
|
||||
|
||||
assert len(result) == 2
|
||||
|
||||
|
||||
class TestGiteaLabelOperations:
|
||||
"""Tests for Gitea label operations."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_add_labels(self, gitea_provider, mock_httpx_client):
|
||||
"""Test adding labels to a PR."""
|
||||
# Use side_effect to return different values for different calls
|
||||
# 1. GET labels (for "bug") - returns existing labels
|
||||
# 2. POST to create "bug" label
|
||||
# 3. GET labels (for "urgent")
|
||||
# 4. POST to create "urgent" label
|
||||
# 5. POST labels to PR
|
||||
# 6. GET issue to return final labels
|
||||
mock_responses = [
|
||||
[{"id": 1, "name": "existing"}], # GET labels (bug not found)
|
||||
{"id": 2, "name": "bug"}, # POST create bug
|
||||
[{"id": 1, "name": "existing"}, {"id": 2, "name": "bug"}], # GET labels (urgent not found)
|
||||
{"id": 3, "name": "urgent"}, # POST create urgent
|
||||
{}, # POST add labels to PR
|
||||
{"labels": [{"name": "bug"}, {"name": "urgent"}]}, # GET issue
|
||||
]
|
||||
mock_httpx_client.request.return_value.json = MagicMock(side_effect=mock_responses)
|
||||
|
||||
result = await gitea_provider.add_labels(
|
||||
"owner", "repo", 42, ["bug", "urgent"]
|
||||
)
|
||||
|
||||
# Should return updated label list
|
||||
assert isinstance(result, list)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_remove_label(self, gitea_provider, mock_httpx_client):
|
||||
"""Test removing a label from a PR."""
|
||||
# Use side_effect for multiple calls
|
||||
# 1. GET labels to find the label ID
|
||||
# 2. DELETE the label from the PR
|
||||
# 3. GET issue to return remaining labels
|
||||
mock_responses = [
|
||||
[{"id": 1, "name": "bug"}], # GET labels
|
||||
{}, # DELETE label
|
||||
{"labels": []}, # GET issue
|
||||
]
|
||||
mock_httpx_client.request.return_value.json = MagicMock(side_effect=mock_responses)
|
||||
|
||||
result = await gitea_provider.remove_label(
|
||||
"owner", "repo", 42, "bug"
|
||||
)
|
||||
|
||||
assert isinstance(result, list)
|
||||
|
||||
|
||||
class TestGiteaReviewerOperations:
|
||||
"""Tests for Gitea reviewer operations."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_request_review(self, gitea_provider, mock_httpx_client):
|
||||
"""Test requesting review from users."""
|
||||
mock_httpx_client.request.return_value.json = MagicMock(return_value={})
|
||||
|
||||
result = await gitea_provider.request_review(
|
||||
"owner", "repo", 42, ["reviewer1", "reviewer2"]
|
||||
)
|
||||
|
||||
assert result == ["reviewer1", "reviewer2"]
|
||||
|
||||
|
||||
class TestGiteaErrorHandling:
|
||||
"""Tests for error handling in Gitea provider."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_authentication_error(self, gitea_provider, mock_httpx_client):
|
||||
"""Test handling authentication errors."""
|
||||
mock_httpx_client.request.return_value.status_code = 401
|
||||
|
||||
with pytest.raises(AuthenticationError):
|
||||
await gitea_provider._request("GET", "/user")
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_permission_denied(self, gitea_provider, mock_httpx_client):
|
||||
"""Test handling permission denied errors."""
|
||||
mock_httpx_client.request.return_value.status_code = 403
|
||||
|
||||
with pytest.raises(AuthenticationError, match="Insufficient permissions"):
|
||||
await gitea_provider._request("GET", "/protected")
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_api_error(self, gitea_provider, mock_httpx_client):
|
||||
"""Test handling general API errors."""
|
||||
mock_httpx_client.request.return_value.status_code = 500
|
||||
mock_httpx_client.request.return_value.text = "Internal Server Error"
|
||||
mock_httpx_client.request.return_value.json = MagicMock(
|
||||
return_value={"message": "Server error"}
|
||||
)
|
||||
|
||||
with pytest.raises(APIError):
|
||||
await gitea_provider._request("GET", "/error")
|
||||
|
||||
|
||||
class TestGiteaPRParsing:
|
||||
"""Tests for PR data parsing."""
|
||||
|
||||
def test_parse_pr_open(self, gitea_provider, sample_pr_data):
|
||||
"""Test parsing open PR."""
|
||||
pr_info = gitea_provider._parse_pr(sample_pr_data)
|
||||
|
||||
assert pr_info.number == 42
|
||||
assert pr_info.state == PRState.OPEN
|
||||
assert pr_info.title == "Test PR"
|
||||
assert pr_info.source_branch == "feature-branch"
|
||||
assert pr_info.target_branch == "main"
|
||||
|
||||
def test_parse_pr_merged(self, gitea_provider, sample_pr_data):
|
||||
"""Test parsing merged PR."""
|
||||
sample_pr_data["merged"] = True
|
||||
sample_pr_data["merged_at"] = "2024-01-16T10:00:00Z"
|
||||
|
||||
pr_info = gitea_provider._parse_pr(sample_pr_data)
|
||||
|
||||
assert pr_info.state == PRState.MERGED
|
||||
|
||||
def test_parse_pr_closed(self, gitea_provider, sample_pr_data):
|
||||
"""Test parsing closed PR."""
|
||||
sample_pr_data["state"] = "closed"
|
||||
sample_pr_data["closed_at"] = "2024-01-16T10:00:00Z"
|
||||
|
||||
pr_info = gitea_provider._parse_pr(sample_pr_data)
|
||||
|
||||
assert pr_info.state == PRState.CLOSED
|
||||
|
||||
def test_parse_datetime_iso(self, gitea_provider):
|
||||
"""Test parsing ISO datetime strings."""
|
||||
dt = gitea_provider._parse_datetime("2024-01-15T10:30:00Z")
|
||||
|
||||
assert dt.year == 2024
|
||||
assert dt.month == 1
|
||||
assert dt.day == 15
|
||||
|
||||
def test_parse_datetime_none(self, gitea_provider):
|
||||
"""Test parsing None datetime returns now."""
|
||||
dt = gitea_provider._parse_datetime(None)
|
||||
|
||||
assert dt is not None
|
||||
assert dt.tzinfo is not None
|
||||
514
mcp-servers/git-ops/tests/test_server.py
Normal file
514
mcp-servers/git-ops/tests/test_server.py
Normal file
@@ -0,0 +1,514 @@
|
||||
"""
|
||||
Tests for the MCP server and tools.
|
||||
"""
|
||||
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
from exceptions import ErrorCode
|
||||
|
||||
|
||||
class TestInputValidation:
|
||||
"""Tests for input validation functions."""
|
||||
|
||||
def test_validate_id_valid(self):
|
||||
"""Test valid IDs pass validation."""
|
||||
from server import _validate_id
|
||||
|
||||
assert _validate_id("test-123", "project_id") is None
|
||||
assert _validate_id("my_project", "project_id") is None
|
||||
assert _validate_id("Agent-001", "agent_id") is None
|
||||
|
||||
def test_validate_id_empty(self):
|
||||
"""Test empty ID fails validation."""
|
||||
from server import _validate_id
|
||||
|
||||
error = _validate_id("", "project_id")
|
||||
assert error is not None
|
||||
assert "required" in error.lower()
|
||||
|
||||
def test_validate_id_too_long(self):
|
||||
"""Test too-long ID fails validation."""
|
||||
from server import _validate_id
|
||||
|
||||
error = _validate_id("a" * 200, "project_id")
|
||||
assert error is not None
|
||||
assert "1-128" in error
|
||||
|
||||
def test_validate_id_invalid_chars(self):
|
||||
"""Test invalid characters fail validation."""
|
||||
from server import _validate_id
|
||||
|
||||
assert _validate_id("test@invalid", "project_id") is not None
|
||||
assert _validate_id("test!project", "project_id") is not None
|
||||
assert _validate_id("test project", "project_id") is not None
|
||||
|
||||
def test_validate_branch_valid(self):
|
||||
"""Test valid branch names."""
|
||||
from server import _validate_branch
|
||||
|
||||
assert _validate_branch("main") is None
|
||||
assert _validate_branch("feature/new-thing") is None
|
||||
assert _validate_branch("release-1.0.0") is None
|
||||
assert _validate_branch("hotfix.urgent") is None
|
||||
|
||||
def test_validate_branch_invalid(self):
|
||||
"""Test invalid branch names."""
|
||||
from server import _validate_branch
|
||||
|
||||
assert _validate_branch("") is not None
|
||||
assert _validate_branch("a" * 300) is not None
|
||||
|
||||
def test_validate_url_valid(self):
|
||||
"""Test valid repository URLs."""
|
||||
from server import _validate_url
|
||||
|
||||
assert _validate_url("https://github.com/owner/repo.git") is None
|
||||
assert _validate_url("https://gitea.example.com/owner/repo") is None
|
||||
assert _validate_url("git@github.com:owner/repo.git") is None
|
||||
|
||||
def test_validate_url_invalid(self):
|
||||
"""Test invalid repository URLs."""
|
||||
from server import _validate_url
|
||||
|
||||
assert _validate_url("") is not None
|
||||
assert _validate_url("not-a-url") is not None
|
||||
assert _validate_url("ftp://invalid.com/repo") is not None
|
||||
|
||||
|
||||
class TestHealthCheck:
|
||||
"""Tests for health check endpoint."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_health_check_structure(self):
|
||||
"""Test health check returns proper structure."""
|
||||
from server import health_check
|
||||
|
||||
with patch("server._gitea_provider", None), \
|
||||
patch("server._workspace_manager", None):
|
||||
result = await health_check()
|
||||
|
||||
assert "status" in result
|
||||
assert "service" in result
|
||||
assert "version" in result
|
||||
assert "timestamp" in result
|
||||
assert "dependencies" in result
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_health_check_no_providers(self):
|
||||
"""Test health check without providers configured."""
|
||||
from server import health_check
|
||||
|
||||
with patch("server._gitea_provider", None), \
|
||||
patch("server._workspace_manager", None):
|
||||
result = await health_check()
|
||||
|
||||
assert result["dependencies"]["gitea"] == "not configured"
|
||||
|
||||
|
||||
class TestToolRegistry:
|
||||
"""Tests for tool registration."""
|
||||
|
||||
def test_tool_registry_populated(self):
|
||||
"""Test that tools are registered."""
|
||||
from server import _tool_registry
|
||||
|
||||
assert len(_tool_registry) > 0
|
||||
assert "clone_repository" in _tool_registry
|
||||
assert "git_status" in _tool_registry
|
||||
assert "create_branch" in _tool_registry
|
||||
assert "commit" in _tool_registry
|
||||
|
||||
def test_tool_schema_structure(self):
|
||||
"""Test tool schemas have proper structure."""
|
||||
from server import _tool_registry
|
||||
|
||||
for name, info in _tool_registry.items():
|
||||
assert "func" in info
|
||||
assert "description" in info
|
||||
assert "schema" in info
|
||||
assert info["schema"]["type"] == "object"
|
||||
assert "properties" in info["schema"]
|
||||
|
||||
|
||||
class TestCloneRepository:
|
||||
"""Tests for clone_repository tool."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_clone_invalid_project_id(self):
|
||||
"""Test clone with invalid project ID."""
|
||||
from server import clone_repository
|
||||
|
||||
# Access the underlying function via .fn
|
||||
result = await clone_repository.fn(
|
||||
project_id="invalid@id",
|
||||
agent_id="agent-1",
|
||||
repo_url="https://github.com/owner/repo.git",
|
||||
)
|
||||
|
||||
assert result["success"] is False
|
||||
assert "project_id" in result["error"].lower()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_clone_invalid_repo_url(self):
|
||||
"""Test clone with invalid repo URL."""
|
||||
from server import clone_repository
|
||||
|
||||
result = await clone_repository.fn(
|
||||
project_id="valid-project",
|
||||
agent_id="agent-1",
|
||||
repo_url="not-a-valid-url",
|
||||
)
|
||||
|
||||
assert result["success"] is False
|
||||
assert "url" in result["error"].lower()
|
||||
|
||||
|
||||
class TestGitStatus:
|
||||
"""Tests for git_status tool."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_status_workspace_not_found(self):
|
||||
"""Test status when workspace doesn't exist."""
|
||||
from server import git_status
|
||||
|
||||
mock_manager = AsyncMock()
|
||||
mock_manager.get_workspace = AsyncMock(return_value=None)
|
||||
|
||||
with patch("server._workspace_manager", mock_manager):
|
||||
result = await git_status.fn(
|
||||
project_id="nonexistent",
|
||||
agent_id="agent-1",
|
||||
)
|
||||
|
||||
assert result["success"] is False
|
||||
assert result["code"] == ErrorCode.WORKSPACE_NOT_FOUND.value
|
||||
|
||||
|
||||
class TestBranchOperations:
|
||||
"""Tests for branch operation tools."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_create_branch_invalid_name(self):
|
||||
"""Test creating branch with invalid name."""
|
||||
from server import create_branch
|
||||
|
||||
result = await create_branch.fn(
|
||||
project_id="test-project",
|
||||
agent_id="agent-1",
|
||||
branch_name="", # Invalid
|
||||
)
|
||||
|
||||
assert result["success"] is False
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_list_branches_workspace_not_found(self):
|
||||
"""Test listing branches when workspace doesn't exist."""
|
||||
from server import list_branches
|
||||
|
||||
mock_manager = AsyncMock()
|
||||
mock_manager.get_workspace = AsyncMock(return_value=None)
|
||||
|
||||
with patch("server._workspace_manager", mock_manager):
|
||||
result = await list_branches.fn(
|
||||
project_id="nonexistent",
|
||||
agent_id="agent-1",
|
||||
)
|
||||
|
||||
assert result["success"] is False
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_checkout_invalid_project(self):
|
||||
"""Test checkout with invalid project ID."""
|
||||
from server import checkout
|
||||
|
||||
result = await checkout.fn(
|
||||
project_id="inv@lid",
|
||||
agent_id="agent-1",
|
||||
ref="main",
|
||||
)
|
||||
|
||||
assert result["success"] is False
|
||||
|
||||
|
||||
class TestCommitOperations:
|
||||
"""Tests for commit operation tools."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_commit_invalid_project(self):
|
||||
"""Test commit with invalid project ID."""
|
||||
from server import commit
|
||||
|
||||
result = await commit.fn(
|
||||
project_id="inv@lid",
|
||||
agent_id="agent-1",
|
||||
message="Test commit",
|
||||
)
|
||||
|
||||
assert result["success"] is False
|
||||
|
||||
|
||||
class TestPushPullOperations:
|
||||
"""Tests for push/pull operation tools."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_push_workspace_not_found(self):
|
||||
"""Test push when workspace doesn't exist."""
|
||||
from server import push
|
||||
|
||||
mock_manager = AsyncMock()
|
||||
mock_manager.get_workspace = AsyncMock(return_value=None)
|
||||
|
||||
with patch("server._workspace_manager", mock_manager):
|
||||
result = await push.fn(
|
||||
project_id="nonexistent",
|
||||
agent_id="agent-1",
|
||||
)
|
||||
|
||||
assert result["success"] is False
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_pull_workspace_not_found(self):
|
||||
"""Test pull when workspace doesn't exist."""
|
||||
from server import pull
|
||||
|
||||
mock_manager = AsyncMock()
|
||||
mock_manager.get_workspace = AsyncMock(return_value=None)
|
||||
|
||||
with patch("server._workspace_manager", mock_manager):
|
||||
result = await pull.fn(
|
||||
project_id="nonexistent",
|
||||
agent_id="agent-1",
|
||||
)
|
||||
|
||||
assert result["success"] is False
|
||||
|
||||
|
||||
class TestDiffLogOperations:
|
||||
"""Tests for diff and log operation tools."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_diff_workspace_not_found(self):
|
||||
"""Test diff when workspace doesn't exist."""
|
||||
from server import diff
|
||||
|
||||
mock_manager = AsyncMock()
|
||||
mock_manager.get_workspace = AsyncMock(return_value=None)
|
||||
|
||||
with patch("server._workspace_manager", mock_manager):
|
||||
result = await diff.fn(
|
||||
project_id="nonexistent",
|
||||
agent_id="agent-1",
|
||||
)
|
||||
|
||||
assert result["success"] is False
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_log_workspace_not_found(self):
|
||||
"""Test log when workspace doesn't exist."""
|
||||
from server import log
|
||||
|
||||
mock_manager = AsyncMock()
|
||||
mock_manager.get_workspace = AsyncMock(return_value=None)
|
||||
|
||||
with patch("server._workspace_manager", mock_manager):
|
||||
result = await log.fn(
|
||||
project_id="nonexistent",
|
||||
agent_id="agent-1",
|
||||
)
|
||||
|
||||
assert result["success"] is False
|
||||
|
||||
|
||||
class TestPROperations:
|
||||
"""Tests for pull request operation tools."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_create_pr_no_repo_url(self):
|
||||
"""Test create PR when workspace has no repo URL."""
|
||||
from models import WorkspaceInfo, WorkspaceState
|
||||
from server import create_pull_request
|
||||
|
||||
mock_workspace = WorkspaceInfo(
|
||||
project_id="test-project",
|
||||
path="/tmp/test",
|
||||
state=WorkspaceState.READY,
|
||||
repo_url=None, # No repo URL
|
||||
)
|
||||
|
||||
mock_manager = AsyncMock()
|
||||
mock_manager.get_workspace = AsyncMock(return_value=mock_workspace)
|
||||
|
||||
with patch("server._workspace_manager", mock_manager):
|
||||
result = await create_pull_request.fn(
|
||||
project_id="test-project",
|
||||
agent_id="agent-1",
|
||||
title="Test PR",
|
||||
source_branch="feature",
|
||||
target_branch="main",
|
||||
)
|
||||
|
||||
assert result["success"] is False
|
||||
assert "repository URL" in result["error"]
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_list_prs_invalid_state(self):
|
||||
"""Test list PRs with invalid state filter."""
|
||||
from models import WorkspaceInfo, WorkspaceState
|
||||
from server import list_pull_requests
|
||||
|
||||
mock_workspace = WorkspaceInfo(
|
||||
project_id="test-project",
|
||||
path="/tmp/test",
|
||||
state=WorkspaceState.READY,
|
||||
repo_url="https://gitea.test.com/owner/repo.git",
|
||||
)
|
||||
|
||||
mock_manager = AsyncMock()
|
||||
mock_manager.get_workspace = AsyncMock(return_value=mock_workspace)
|
||||
|
||||
mock_provider = AsyncMock()
|
||||
mock_provider.parse_repo_url = MagicMock(return_value=("owner", "repo"))
|
||||
|
||||
with patch("server._workspace_manager", mock_manager), \
|
||||
patch("server._get_provider_for_url", return_value=mock_provider):
|
||||
result = await list_pull_requests.fn(
|
||||
project_id="test-project",
|
||||
agent_id="agent-1",
|
||||
state="invalid-state",
|
||||
)
|
||||
|
||||
assert result["success"] is False
|
||||
assert "Invalid state" in result["error"]
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_merge_pr_invalid_strategy(self):
|
||||
"""Test merge PR with invalid strategy."""
|
||||
from models import WorkspaceInfo, WorkspaceState
|
||||
from server import merge_pull_request
|
||||
|
||||
mock_workspace = WorkspaceInfo(
|
||||
project_id="test-project",
|
||||
path="/tmp/test",
|
||||
state=WorkspaceState.READY,
|
||||
repo_url="https://gitea.test.com/owner/repo.git",
|
||||
)
|
||||
|
||||
mock_manager = AsyncMock()
|
||||
mock_manager.get_workspace = AsyncMock(return_value=mock_workspace)
|
||||
|
||||
mock_provider = AsyncMock()
|
||||
mock_provider.parse_repo_url = MagicMock(return_value=("owner", "repo"))
|
||||
|
||||
with patch("server._workspace_manager", mock_manager), \
|
||||
patch("server._get_provider_for_url", return_value=mock_provider):
|
||||
result = await merge_pull_request.fn(
|
||||
project_id="test-project",
|
||||
agent_id="agent-1",
|
||||
pr_number=42,
|
||||
merge_strategy="invalid-strategy",
|
||||
)
|
||||
|
||||
assert result["success"] is False
|
||||
assert "Invalid strategy" in result["error"]
|
||||
|
||||
|
||||
class TestWorkspaceOperations:
|
||||
"""Tests for workspace operation tools."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_workspace_not_found(self):
|
||||
"""Test get workspace when it doesn't exist."""
|
||||
from server import get_workspace
|
||||
|
||||
mock_manager = AsyncMock()
|
||||
mock_manager.get_workspace = AsyncMock(return_value=None)
|
||||
|
||||
with patch("server._workspace_manager", mock_manager):
|
||||
result = await get_workspace.fn(
|
||||
project_id="nonexistent",
|
||||
agent_id="agent-1",
|
||||
)
|
||||
|
||||
assert result["success"] is False
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_lock_workspace_success(self):
|
||||
"""Test successful workspace locking."""
|
||||
from datetime import UTC, datetime, timedelta
|
||||
|
||||
from models import WorkspaceInfo, WorkspaceState
|
||||
from server import lock_workspace
|
||||
|
||||
mock_workspace = WorkspaceInfo(
|
||||
project_id="test-project",
|
||||
path="/tmp/test",
|
||||
state=WorkspaceState.LOCKED,
|
||||
lock_holder="agent-1",
|
||||
lock_expires=datetime.now(UTC) + timedelta(seconds=300),
|
||||
)
|
||||
|
||||
mock_manager = AsyncMock()
|
||||
mock_manager.lock_workspace = AsyncMock(return_value=True)
|
||||
mock_manager.get_workspace = AsyncMock(return_value=mock_workspace)
|
||||
|
||||
with patch("server._workspace_manager", mock_manager):
|
||||
result = await lock_workspace.fn(
|
||||
project_id="test-project",
|
||||
agent_id="agent-1",
|
||||
timeout=300,
|
||||
)
|
||||
|
||||
assert result["success"] is True
|
||||
assert result["lock_holder"] == "agent-1"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_unlock_workspace_success(self):
|
||||
"""Test successful workspace unlocking."""
|
||||
from server import unlock_workspace
|
||||
|
||||
mock_manager = AsyncMock()
|
||||
mock_manager.unlock_workspace = AsyncMock(return_value=True)
|
||||
|
||||
with patch("server._workspace_manager", mock_manager):
|
||||
result = await unlock_workspace.fn(
|
||||
project_id="test-project",
|
||||
agent_id="agent-1",
|
||||
)
|
||||
|
||||
assert result["success"] is True
|
||||
|
||||
|
||||
class TestJSONRPCEndpoint:
|
||||
"""Tests for the JSON-RPC endpoint."""
|
||||
|
||||
def test_python_type_to_json_schema_str(self):
|
||||
"""Test string type conversion."""
|
||||
from server import _python_type_to_json_schema
|
||||
|
||||
result = _python_type_to_json_schema(str)
|
||||
assert result["type"] == "string"
|
||||
|
||||
def test_python_type_to_json_schema_int(self):
|
||||
"""Test int type conversion."""
|
||||
from server import _python_type_to_json_schema
|
||||
|
||||
result = _python_type_to_json_schema(int)
|
||||
assert result["type"] == "integer"
|
||||
|
||||
def test_python_type_to_json_schema_bool(self):
|
||||
"""Test bool type conversion."""
|
||||
from server import _python_type_to_json_schema
|
||||
|
||||
result = _python_type_to_json_schema(bool)
|
||||
assert result["type"] == "boolean"
|
||||
|
||||
def test_python_type_to_json_schema_list(self):
|
||||
"""Test list type conversion."""
|
||||
|
||||
from server import _python_type_to_json_schema
|
||||
|
||||
result = _python_type_to_json_schema(list[str])
|
||||
assert result["type"] == "array"
|
||||
assert result["items"]["type"] == "string"
|
||||
334
mcp-servers/git-ops/tests/test_workspace.py
Normal file
334
mcp-servers/git-ops/tests/test_workspace.py
Normal file
@@ -0,0 +1,334 @@
|
||||
"""
|
||||
Tests for the workspace management module.
|
||||
"""
|
||||
|
||||
from datetime import UTC, datetime, timedelta
|
||||
from pathlib import Path
|
||||
|
||||
import pytest
|
||||
|
||||
from exceptions import WorkspaceLockedError, WorkspaceNotFoundError
|
||||
from models import WorkspaceState
|
||||
from workspace import FileLockManager, WorkspaceLock
|
||||
|
||||
|
||||
class TestWorkspaceManager:
|
||||
"""Tests for WorkspaceManager."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_create_workspace(self, workspace_manager, valid_project_id):
|
||||
"""Test creating a new workspace."""
|
||||
workspace = await workspace_manager.create_workspace(valid_project_id)
|
||||
|
||||
assert workspace.project_id == valid_project_id
|
||||
assert workspace.state == WorkspaceState.INITIALIZING
|
||||
assert Path(workspace.path).exists()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_create_workspace_with_repo_url(self, workspace_manager, valid_project_id, sample_repo_url):
|
||||
"""Test creating workspace with repository URL."""
|
||||
workspace = await workspace_manager.create_workspace(
|
||||
valid_project_id, repo_url=sample_repo_url
|
||||
)
|
||||
|
||||
assert workspace.repo_url == sample_repo_url
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_workspace(self, workspace_manager, valid_project_id):
|
||||
"""Test getting an existing workspace."""
|
||||
# Create first
|
||||
await workspace_manager.create_workspace(valid_project_id)
|
||||
|
||||
# Get it
|
||||
workspace = await workspace_manager.get_workspace(valid_project_id)
|
||||
|
||||
assert workspace is not None
|
||||
assert workspace.project_id == valid_project_id
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_workspace_not_found(self, workspace_manager):
|
||||
"""Test getting non-existent workspace."""
|
||||
workspace = await workspace_manager.get_workspace("nonexistent")
|
||||
assert workspace is None
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_delete_workspace(self, workspace_manager, valid_project_id):
|
||||
"""Test deleting a workspace."""
|
||||
# Create first
|
||||
workspace = await workspace_manager.create_workspace(valid_project_id)
|
||||
workspace_path = Path(workspace.path)
|
||||
assert workspace_path.exists()
|
||||
|
||||
# Delete
|
||||
result = await workspace_manager.delete_workspace(valid_project_id)
|
||||
|
||||
assert result is True
|
||||
assert not workspace_path.exists()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_delete_nonexistent_workspace(self, workspace_manager):
|
||||
"""Test deleting non-existent workspace returns True."""
|
||||
result = await workspace_manager.delete_workspace("nonexistent")
|
||||
assert result is True
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_list_workspaces(self, workspace_manager):
|
||||
"""Test listing workspaces."""
|
||||
# Create multiple workspaces
|
||||
await workspace_manager.create_workspace("project-1")
|
||||
await workspace_manager.create_workspace("project-2")
|
||||
await workspace_manager.create_workspace("project-3")
|
||||
|
||||
workspaces = await workspace_manager.list_workspaces()
|
||||
|
||||
assert len(workspaces) >= 3
|
||||
project_ids = [w.project_id for w in workspaces]
|
||||
assert "project-1" in project_ids
|
||||
assert "project-2" in project_ids
|
||||
assert "project-3" in project_ids
|
||||
|
||||
|
||||
class TestWorkspaceLocking:
|
||||
"""Tests for workspace locking."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_lock_workspace(self, workspace_manager, valid_project_id, valid_agent_id):
|
||||
"""Test locking a workspace."""
|
||||
await workspace_manager.create_workspace(valid_project_id)
|
||||
|
||||
result = await workspace_manager.lock_workspace(
|
||||
valid_project_id, valid_agent_id, timeout=60
|
||||
)
|
||||
|
||||
assert result is True
|
||||
|
||||
workspace = await workspace_manager.get_workspace(valid_project_id)
|
||||
assert workspace.state == WorkspaceState.LOCKED
|
||||
assert workspace.lock_holder == valid_agent_id
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_lock_already_locked(self, workspace_manager, valid_project_id):
|
||||
"""Test locking already-locked workspace by different holder."""
|
||||
await workspace_manager.create_workspace(valid_project_id)
|
||||
await workspace_manager.lock_workspace(valid_project_id, "agent-1", timeout=60)
|
||||
|
||||
with pytest.raises(WorkspaceLockedError):
|
||||
await workspace_manager.lock_workspace(valid_project_id, "agent-2", timeout=60)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_lock_same_holder(self, workspace_manager, valid_project_id, valid_agent_id):
|
||||
"""Test re-locking by same holder extends lock."""
|
||||
await workspace_manager.create_workspace(valid_project_id)
|
||||
await workspace_manager.lock_workspace(valid_project_id, valid_agent_id, timeout=60)
|
||||
|
||||
# Same holder can re-lock
|
||||
result = await workspace_manager.lock_workspace(
|
||||
valid_project_id, valid_agent_id, timeout=120
|
||||
)
|
||||
|
||||
assert result is True
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_unlock_workspace(self, workspace_manager, valid_project_id, valid_agent_id):
|
||||
"""Test unlocking a workspace."""
|
||||
await workspace_manager.create_workspace(valid_project_id)
|
||||
await workspace_manager.lock_workspace(valid_project_id, valid_agent_id)
|
||||
|
||||
result = await workspace_manager.unlock_workspace(valid_project_id, valid_agent_id)
|
||||
|
||||
assert result is True
|
||||
workspace = await workspace_manager.get_workspace(valid_project_id)
|
||||
assert workspace.state == WorkspaceState.READY
|
||||
assert workspace.lock_holder is None
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_unlock_wrong_holder(self, workspace_manager, valid_project_id):
|
||||
"""Test unlock fails with wrong holder."""
|
||||
await workspace_manager.create_workspace(valid_project_id)
|
||||
await workspace_manager.lock_workspace(valid_project_id, "agent-1")
|
||||
|
||||
with pytest.raises(WorkspaceLockedError):
|
||||
await workspace_manager.unlock_workspace(valid_project_id, "agent-2")
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_force_unlock(self, workspace_manager, valid_project_id):
|
||||
"""Test force unlock works regardless of holder."""
|
||||
await workspace_manager.create_workspace(valid_project_id)
|
||||
await workspace_manager.lock_workspace(valid_project_id, "agent-1")
|
||||
|
||||
result = await workspace_manager.unlock_workspace(
|
||||
valid_project_id, "admin", force=True
|
||||
)
|
||||
|
||||
assert result is True
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_lock_nonexistent_workspace(self, workspace_manager, valid_agent_id):
|
||||
"""Test locking non-existent workspace raises error."""
|
||||
with pytest.raises(WorkspaceNotFoundError):
|
||||
await workspace_manager.lock_workspace("nonexistent", valid_agent_id)
|
||||
|
||||
|
||||
class TestWorkspaceLockContextManager:
|
||||
"""Tests for WorkspaceLock context manager."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_lock_context_manager(self, workspace_manager, valid_project_id, valid_agent_id):
|
||||
"""Test using WorkspaceLock as context manager."""
|
||||
await workspace_manager.create_workspace(valid_project_id)
|
||||
|
||||
async with WorkspaceLock(
|
||||
workspace_manager, valid_project_id, valid_agent_id
|
||||
) as lock:
|
||||
workspace = await workspace_manager.get_workspace(valid_project_id)
|
||||
assert workspace.state == WorkspaceState.LOCKED
|
||||
|
||||
# After exiting context, should be unlocked
|
||||
workspace = await workspace_manager.get_workspace(valid_project_id)
|
||||
assert workspace.lock_holder is None
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_lock_context_manager_error(self, workspace_manager, valid_project_id, valid_agent_id):
|
||||
"""Test WorkspaceLock releases on exception."""
|
||||
await workspace_manager.create_workspace(valid_project_id)
|
||||
|
||||
try:
|
||||
async with WorkspaceLock(
|
||||
workspace_manager, valid_project_id, valid_agent_id
|
||||
):
|
||||
raise ValueError("Test error")
|
||||
except ValueError:
|
||||
pass
|
||||
|
||||
workspace = await workspace_manager.get_workspace(valid_project_id)
|
||||
assert workspace.lock_holder is None
|
||||
|
||||
|
||||
class TestWorkspaceMetadata:
|
||||
"""Tests for workspace metadata operations."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_touch_workspace(self, workspace_manager, valid_project_id):
|
||||
"""Test updating workspace access time."""
|
||||
workspace = await workspace_manager.create_workspace(valid_project_id)
|
||||
original_time = workspace.last_accessed
|
||||
|
||||
await workspace_manager.touch_workspace(valid_project_id)
|
||||
|
||||
updated = await workspace_manager.get_workspace(valid_project_id)
|
||||
assert updated.last_accessed >= original_time
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_update_workspace_branch(self, workspace_manager, valid_project_id):
|
||||
"""Test updating workspace branch."""
|
||||
await workspace_manager.create_workspace(valid_project_id)
|
||||
|
||||
await workspace_manager.update_workspace_branch(valid_project_id, "feature-branch")
|
||||
|
||||
workspace = await workspace_manager.get_workspace(valid_project_id)
|
||||
assert workspace.current_branch == "feature-branch"
|
||||
|
||||
|
||||
class TestWorkspaceSize:
|
||||
"""Tests for workspace size management."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_check_size_within_limit(self, workspace_manager, valid_project_id):
|
||||
"""Test size check passes for small workspace."""
|
||||
await workspace_manager.create_workspace(valid_project_id)
|
||||
|
||||
# Should not raise
|
||||
result = await workspace_manager.check_size_limit(valid_project_id)
|
||||
assert result is True
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_total_size(self, workspace_manager, valid_project_id):
|
||||
"""Test getting total workspace size."""
|
||||
workspace = await workspace_manager.create_workspace(valid_project_id)
|
||||
|
||||
# Add some content
|
||||
content_file = Path(workspace.path) / "content.txt"
|
||||
content_file.write_text("x" * 1000)
|
||||
|
||||
total_size = await workspace_manager.get_total_size()
|
||||
assert total_size >= 1000
|
||||
|
||||
|
||||
class TestFileLockManager:
|
||||
"""Tests for file-based locking."""
|
||||
|
||||
def test_acquire_lock(self, temp_dir):
|
||||
"""Test acquiring a file lock."""
|
||||
manager = FileLockManager(temp_dir / "locks")
|
||||
|
||||
result = manager.acquire("test-key")
|
||||
assert result is True
|
||||
|
||||
# Cleanup
|
||||
manager.release("test-key")
|
||||
|
||||
def test_release_lock(self, temp_dir):
|
||||
"""Test releasing a file lock."""
|
||||
manager = FileLockManager(temp_dir / "locks")
|
||||
manager.acquire("test-key")
|
||||
|
||||
result = manager.release("test-key")
|
||||
assert result is True
|
||||
|
||||
def test_is_locked(self, temp_dir):
|
||||
"""Test checking if locked."""
|
||||
manager = FileLockManager(temp_dir / "locks")
|
||||
|
||||
assert manager.is_locked("test-key") is False
|
||||
|
||||
manager.acquire("test-key")
|
||||
assert manager.is_locked("test-key") is True
|
||||
|
||||
manager.release("test-key")
|
||||
|
||||
def test_release_nonexistent_lock(self, temp_dir):
|
||||
"""Test releasing a lock that doesn't exist."""
|
||||
manager = FileLockManager(temp_dir / "locks")
|
||||
|
||||
# Should not raise
|
||||
result = manager.release("nonexistent")
|
||||
assert result is False
|
||||
|
||||
|
||||
class TestWorkspaceCleanup:
|
||||
"""Tests for workspace cleanup operations."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_cleanup_stale_workspaces(self, workspace_manager, test_settings):
|
||||
"""Test cleaning up stale workspaces."""
|
||||
# Create workspace
|
||||
workspace = await workspace_manager.create_workspace("stale-project")
|
||||
|
||||
# Manually set it as stale by updating metadata
|
||||
await workspace_manager._update_metadata(
|
||||
"stale-project",
|
||||
last_accessed=(datetime.now(UTC) - timedelta(days=30)).isoformat(),
|
||||
)
|
||||
|
||||
# Run cleanup
|
||||
cleaned = await workspace_manager.cleanup_stale_workspaces()
|
||||
|
||||
assert cleaned >= 1
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_delete_locked_workspace_blocked(self, workspace_manager, valid_project_id, valid_agent_id):
|
||||
"""Test deleting locked workspace is blocked without force."""
|
||||
await workspace_manager.create_workspace(valid_project_id)
|
||||
await workspace_manager.lock_workspace(valid_project_id, valid_agent_id)
|
||||
|
||||
with pytest.raises(WorkspaceLockedError):
|
||||
await workspace_manager.delete_workspace(valid_project_id)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_delete_locked_workspace_force(self, workspace_manager, valid_project_id, valid_agent_id):
|
||||
"""Test force deleting locked workspace."""
|
||||
await workspace_manager.create_workspace(valid_project_id)
|
||||
await workspace_manager.lock_workspace(valid_project_id, valid_agent_id)
|
||||
|
||||
result = await workspace_manager.delete_workspace(valid_project_id, force=True)
|
||||
assert result is True
|
||||
1853
mcp-servers/git-ops/uv.lock
generated
Normal file
1853
mcp-servers/git-ops/uv.lock
generated
Normal file
File diff suppressed because it is too large
Load Diff
608
mcp-servers/git-ops/workspace.py
Normal file
608
mcp-servers/git-ops/workspace.py
Normal file
@@ -0,0 +1,608 @@
|
||||
"""
|
||||
Workspace management for Git Operations MCP Server.
|
||||
|
||||
Handles isolated workspaces for each project, including creation,
|
||||
locking, cleanup, and size management.
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import json
|
||||
import logging
|
||||
import shutil
|
||||
from datetime import UTC, datetime, timedelta
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
|
||||
import aiofiles
|
||||
from filelock import FileLock, Timeout
|
||||
|
||||
from config import Settings, get_settings
|
||||
from exceptions import (
|
||||
WorkspaceLockedError,
|
||||
WorkspaceNotFoundError,
|
||||
WorkspaceSizeExceededError,
|
||||
)
|
||||
from models import WorkspaceInfo, WorkspaceState
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# Metadata file name
|
||||
WORKSPACE_METADATA_FILE = ".syndarix-workspace.json"
|
||||
|
||||
|
||||
class WorkspaceManager:
|
||||
"""
|
||||
Manages git workspaces for projects.
|
||||
|
||||
Each project gets an isolated workspace directory for git operations.
|
||||
Supports distributed locking via Redis or local file locks.
|
||||
"""
|
||||
|
||||
def __init__(self, settings: Settings | None = None) -> None:
|
||||
"""
|
||||
Initialize WorkspaceManager.
|
||||
|
||||
Args:
|
||||
settings: Optional settings override
|
||||
"""
|
||||
self.settings = settings or get_settings()
|
||||
self.base_path = self.settings.workspace_base_path
|
||||
self._ensure_base_path()
|
||||
|
||||
def _ensure_base_path(self) -> None:
|
||||
"""Ensure the base workspace directory exists."""
|
||||
self.base_path.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
def _get_workspace_path(self, project_id: str) -> Path:
|
||||
"""Get the path for a project workspace."""
|
||||
# Sanitize project ID for filesystem
|
||||
safe_id = "".join(c if c.isalnum() or c in "-_" else "_" for c in project_id)
|
||||
return self.base_path / safe_id
|
||||
|
||||
def _get_lock_path(self, project_id: str) -> Path:
|
||||
"""Get the lock file path for a workspace."""
|
||||
return self._get_workspace_path(project_id) / ".lock"
|
||||
|
||||
def _get_metadata_path(self, project_id: str) -> Path:
|
||||
"""Get the metadata file path for a workspace."""
|
||||
return self._get_workspace_path(project_id) / WORKSPACE_METADATA_FILE
|
||||
|
||||
async def get_workspace(self, project_id: str) -> WorkspaceInfo | None:
|
||||
"""
|
||||
Get workspace info for a project.
|
||||
|
||||
Args:
|
||||
project_id: Project identifier
|
||||
|
||||
Returns:
|
||||
WorkspaceInfo or None if not found
|
||||
"""
|
||||
workspace_path = self._get_workspace_path(project_id)
|
||||
|
||||
if not workspace_path.exists():
|
||||
return None
|
||||
|
||||
# Load metadata
|
||||
metadata = await self._load_metadata(project_id)
|
||||
|
||||
# Calculate size
|
||||
size_bytes = await self._calculate_size(workspace_path)
|
||||
|
||||
# Check lock status
|
||||
lock_holder = None
|
||||
lock_expires = None
|
||||
if metadata:
|
||||
lock_holder = metadata.get("lock_holder")
|
||||
if metadata.get("lock_expires"):
|
||||
lock_expires = datetime.fromisoformat(metadata["lock_expires"])
|
||||
# Clear expired locks
|
||||
if lock_expires < datetime.now(UTC):
|
||||
lock_holder = None
|
||||
lock_expires = None
|
||||
|
||||
# Determine state
|
||||
state = WorkspaceState.READY
|
||||
if lock_holder:
|
||||
state = WorkspaceState.LOCKED
|
||||
|
||||
# Check if stale
|
||||
last_accessed = datetime.now(UTC)
|
||||
if metadata and metadata.get("last_accessed"):
|
||||
last_accessed = datetime.fromisoformat(metadata["last_accessed"])
|
||||
stale_threshold = datetime.now(UTC) - timedelta(
|
||||
days=self.settings.workspace_stale_days
|
||||
)
|
||||
if last_accessed < stale_threshold:
|
||||
state = WorkspaceState.STALE
|
||||
|
||||
return WorkspaceInfo(
|
||||
project_id=project_id,
|
||||
path=str(workspace_path),
|
||||
state=state,
|
||||
repo_url=metadata.get("repo_url") if metadata else None,
|
||||
current_branch=metadata.get("current_branch") if metadata else None,
|
||||
last_accessed=last_accessed,
|
||||
size_bytes=size_bytes,
|
||||
lock_holder=lock_holder,
|
||||
lock_expires=lock_expires,
|
||||
)
|
||||
|
||||
async def create_workspace(
|
||||
self,
|
||||
project_id: str,
|
||||
repo_url: str | None = None,
|
||||
) -> WorkspaceInfo:
|
||||
"""
|
||||
Create or get a workspace for a project.
|
||||
|
||||
Args:
|
||||
project_id: Project identifier
|
||||
repo_url: Optional repository URL
|
||||
|
||||
Returns:
|
||||
WorkspaceInfo for the workspace
|
||||
"""
|
||||
workspace_path = self._get_workspace_path(project_id)
|
||||
|
||||
if workspace_path.exists():
|
||||
# Workspace already exists, update metadata
|
||||
await self._update_metadata(project_id, repo_url=repo_url)
|
||||
workspace = await self.get_workspace(project_id)
|
||||
if workspace:
|
||||
return workspace
|
||||
|
||||
# Create workspace directory
|
||||
workspace_path.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
# Create initial metadata
|
||||
metadata = {
|
||||
"project_id": project_id,
|
||||
"repo_url": repo_url,
|
||||
"created_at": datetime.now(UTC).isoformat(),
|
||||
"last_accessed": datetime.now(UTC).isoformat(),
|
||||
}
|
||||
await self._save_metadata(project_id, metadata)
|
||||
|
||||
return WorkspaceInfo(
|
||||
project_id=project_id,
|
||||
path=str(workspace_path),
|
||||
state=WorkspaceState.INITIALIZING,
|
||||
repo_url=repo_url,
|
||||
last_accessed=datetime.now(UTC),
|
||||
size_bytes=0,
|
||||
)
|
||||
|
||||
async def delete_workspace(self, project_id: str, force: bool = False) -> bool:
|
||||
"""
|
||||
Delete a workspace.
|
||||
|
||||
Args:
|
||||
project_id: Project identifier
|
||||
force: Force delete even if locked
|
||||
|
||||
Returns:
|
||||
True if deleted
|
||||
"""
|
||||
workspace_path = self._get_workspace_path(project_id)
|
||||
|
||||
if not workspace_path.exists():
|
||||
return True
|
||||
|
||||
# Check lock
|
||||
if not force:
|
||||
workspace = await self.get_workspace(project_id)
|
||||
if workspace and workspace.state == WorkspaceState.LOCKED:
|
||||
raise WorkspaceLockedError(project_id, workspace.lock_holder)
|
||||
|
||||
try:
|
||||
# Use shutil.rmtree for robust deletion
|
||||
shutil.rmtree(workspace_path)
|
||||
logger.info(f"Deleted workspace for project: {project_id}")
|
||||
return True
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to delete workspace {project_id}: {e}")
|
||||
return False
|
||||
|
||||
async def lock_workspace(
|
||||
self,
|
||||
project_id: str,
|
||||
holder: str,
|
||||
timeout: int | None = None,
|
||||
) -> bool:
|
||||
"""
|
||||
Acquire a lock on a workspace.
|
||||
|
||||
Args:
|
||||
project_id: Project identifier
|
||||
holder: Lock holder identifier (agent_id)
|
||||
timeout: Lock timeout in seconds
|
||||
|
||||
Returns:
|
||||
True if lock acquired
|
||||
|
||||
Raises:
|
||||
WorkspaceNotFoundError: If workspace doesn't exist
|
||||
WorkspaceLockedError: If already locked by another
|
||||
"""
|
||||
workspace = await self.get_workspace(project_id)
|
||||
|
||||
if workspace is None:
|
||||
raise WorkspaceNotFoundError(project_id)
|
||||
|
||||
# Check if already locked by someone else
|
||||
if (
|
||||
workspace.state == WorkspaceState.LOCKED
|
||||
and workspace.lock_holder != holder
|
||||
):
|
||||
# Check if lock expired
|
||||
if workspace.lock_expires and workspace.lock_expires > datetime.now(UTC):
|
||||
raise WorkspaceLockedError(project_id, workspace.lock_holder)
|
||||
|
||||
# Calculate lock expiry
|
||||
lock_timeout = timeout or self.settings.workspace_lock_timeout
|
||||
lock_expires = datetime.now(UTC) + timedelta(seconds=lock_timeout)
|
||||
|
||||
# Update metadata with lock info
|
||||
await self._update_metadata(
|
||||
project_id,
|
||||
lock_holder=holder,
|
||||
lock_expires=lock_expires.isoformat(),
|
||||
)
|
||||
|
||||
logger.info(f"Workspace {project_id} locked by {holder}")
|
||||
return True
|
||||
|
||||
async def unlock_workspace(
|
||||
self,
|
||||
project_id: str,
|
||||
holder: str,
|
||||
force: bool = False,
|
||||
) -> bool:
|
||||
"""
|
||||
Release a lock on a workspace.
|
||||
|
||||
Args:
|
||||
project_id: Project identifier
|
||||
holder: Lock holder identifier
|
||||
force: Force unlock regardless of holder
|
||||
|
||||
Returns:
|
||||
True if unlocked
|
||||
"""
|
||||
workspace = await self.get_workspace(project_id)
|
||||
|
||||
if workspace is None:
|
||||
raise WorkspaceNotFoundError(project_id)
|
||||
|
||||
# Verify holder
|
||||
if (
|
||||
not force
|
||||
and workspace.lock_holder
|
||||
and workspace.lock_holder != holder
|
||||
):
|
||||
raise WorkspaceLockedError(project_id, workspace.lock_holder)
|
||||
|
||||
# Clear lock
|
||||
await self._update_metadata(
|
||||
project_id,
|
||||
lock_holder=None,
|
||||
lock_expires=None,
|
||||
)
|
||||
|
||||
logger.info(f"Workspace {project_id} unlocked by {holder}")
|
||||
return True
|
||||
|
||||
async def touch_workspace(self, project_id: str) -> None:
|
||||
"""
|
||||
Update last accessed time for a workspace.
|
||||
|
||||
Args:
|
||||
project_id: Project identifier
|
||||
"""
|
||||
await self._update_metadata(
|
||||
project_id,
|
||||
last_accessed=datetime.now(UTC).isoformat(),
|
||||
)
|
||||
|
||||
async def update_workspace_branch(
|
||||
self,
|
||||
project_id: str,
|
||||
branch: str,
|
||||
) -> None:
|
||||
"""
|
||||
Update the current branch in workspace metadata.
|
||||
|
||||
Args:
|
||||
project_id: Project identifier
|
||||
branch: Current branch name
|
||||
"""
|
||||
await self._update_metadata(
|
||||
project_id,
|
||||
current_branch=branch,
|
||||
last_accessed=datetime.now(UTC).isoformat(),
|
||||
)
|
||||
|
||||
async def check_size_limit(self, project_id: str) -> bool:
|
||||
"""
|
||||
Check if workspace exceeds size limit.
|
||||
|
||||
Args:
|
||||
project_id: Project identifier
|
||||
|
||||
Returns:
|
||||
True if within limits
|
||||
|
||||
Raises:
|
||||
WorkspaceSizeExceededError: If size exceeds limit
|
||||
"""
|
||||
workspace_path = self._get_workspace_path(project_id)
|
||||
|
||||
if not workspace_path.exists():
|
||||
return True
|
||||
|
||||
size_bytes = await self._calculate_size(workspace_path)
|
||||
size_gb = size_bytes / (1024 ** 3)
|
||||
max_size_gb = self.settings.workspace_max_size_gb
|
||||
|
||||
if size_gb > max_size_gb:
|
||||
raise WorkspaceSizeExceededError(project_id, size_gb, max_size_gb)
|
||||
|
||||
return True
|
||||
|
||||
async def list_workspaces(
|
||||
self,
|
||||
include_stale: bool = False,
|
||||
) -> list[WorkspaceInfo]:
|
||||
"""
|
||||
List all workspaces.
|
||||
|
||||
Args:
|
||||
include_stale: Include stale workspaces
|
||||
|
||||
Returns:
|
||||
List of WorkspaceInfo
|
||||
"""
|
||||
workspaces = []
|
||||
|
||||
if not self.base_path.exists():
|
||||
return workspaces
|
||||
|
||||
for entry in self.base_path.iterdir():
|
||||
if entry.is_dir() and not entry.name.startswith("."):
|
||||
# Extract project_id from directory name
|
||||
workspace = await self.get_workspace(entry.name)
|
||||
if workspace:
|
||||
if not include_stale and workspace.state == WorkspaceState.STALE:
|
||||
continue
|
||||
workspaces.append(workspace)
|
||||
|
||||
return workspaces
|
||||
|
||||
async def cleanup_stale_workspaces(self) -> int:
|
||||
"""
|
||||
Clean up stale workspaces.
|
||||
|
||||
Returns:
|
||||
Number of workspaces cleaned up
|
||||
"""
|
||||
cleaned = 0
|
||||
workspaces = await self.list_workspaces(include_stale=True)
|
||||
|
||||
for workspace in workspaces:
|
||||
if workspace.state == WorkspaceState.STALE:
|
||||
try:
|
||||
await self.delete_workspace(workspace.project_id, force=True)
|
||||
cleaned += 1
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
f"Failed to cleanup stale workspace {workspace.project_id}: {e}"
|
||||
)
|
||||
|
||||
if cleaned > 0:
|
||||
logger.info(f"Cleaned up {cleaned} stale workspaces")
|
||||
|
||||
return cleaned
|
||||
|
||||
async def get_total_size(self) -> int:
|
||||
"""
|
||||
Get total size of all workspaces.
|
||||
|
||||
Returns:
|
||||
Total size in bytes
|
||||
"""
|
||||
return await self._calculate_size(self.base_path)
|
||||
|
||||
# Private methods
|
||||
|
||||
async def _load_metadata(self, project_id: str) -> dict[str, Any] | None:
|
||||
"""Load workspace metadata from file."""
|
||||
metadata_path = self._get_metadata_path(project_id)
|
||||
|
||||
if not metadata_path.exists():
|
||||
return None
|
||||
|
||||
try:
|
||||
async with aiofiles.open(metadata_path) as f:
|
||||
content = await f.read()
|
||||
return json.loads(content)
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to load metadata for {project_id}: {e}")
|
||||
return None
|
||||
|
||||
async def _save_metadata(
|
||||
self,
|
||||
project_id: str,
|
||||
metadata: dict[str, Any],
|
||||
) -> None:
|
||||
"""Save workspace metadata to file."""
|
||||
metadata_path = self._get_metadata_path(project_id)
|
||||
|
||||
# Ensure parent directory exists
|
||||
metadata_path.parent.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
try:
|
||||
async with aiofiles.open(metadata_path, "w") as f:
|
||||
await f.write(json.dumps(metadata, indent=2))
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to save metadata for {project_id}: {e}")
|
||||
|
||||
async def _update_metadata(
|
||||
self,
|
||||
project_id: str,
|
||||
**updates: Any,
|
||||
) -> None:
|
||||
"""Update specific fields in workspace metadata."""
|
||||
metadata = await self._load_metadata(project_id) or {}
|
||||
|
||||
# Handle None values (to clear fields)
|
||||
for key, value in updates.items():
|
||||
if value is None:
|
||||
metadata.pop(key, None)
|
||||
else:
|
||||
metadata[key] = value
|
||||
|
||||
await self._save_metadata(project_id, metadata)
|
||||
|
||||
async def _calculate_size(self, path: Path) -> int:
|
||||
"""Calculate total size of a directory."""
|
||||
|
||||
def _calc_size() -> int:
|
||||
total = 0
|
||||
try:
|
||||
for entry in path.rglob("*"):
|
||||
if entry.is_file():
|
||||
try:
|
||||
total += entry.stat().st_size
|
||||
except OSError:
|
||||
pass
|
||||
except Exception:
|
||||
pass
|
||||
return total
|
||||
|
||||
# Run in executor for async compatibility
|
||||
loop = asyncio.get_event_loop()
|
||||
return await loop.run_in_executor(None, _calc_size)
|
||||
|
||||
|
||||
class WorkspaceLock:
|
||||
"""
|
||||
Context manager for workspace locking.
|
||||
|
||||
Provides automatic locking/unlocking with proper cleanup.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
manager: WorkspaceManager,
|
||||
project_id: str,
|
||||
holder: str,
|
||||
timeout: int | None = None,
|
||||
) -> None:
|
||||
"""
|
||||
Initialize workspace lock.
|
||||
|
||||
Args:
|
||||
manager: WorkspaceManager instance
|
||||
project_id: Project identifier
|
||||
holder: Lock holder identifier
|
||||
timeout: Lock timeout in seconds
|
||||
"""
|
||||
self.manager = manager
|
||||
self.project_id = project_id
|
||||
self.holder = holder
|
||||
self.timeout = timeout
|
||||
self._acquired = False
|
||||
|
||||
async def __aenter__(self) -> "WorkspaceLock":
|
||||
"""Acquire lock on enter."""
|
||||
await self.manager.lock_workspace(
|
||||
self.project_id,
|
||||
self.holder,
|
||||
self.timeout,
|
||||
)
|
||||
self._acquired = True
|
||||
return self
|
||||
|
||||
async def __aexit__(self, exc_type: Any, exc_val: Any, exc_tb: Any) -> None:
|
||||
"""Release lock on exit."""
|
||||
if self._acquired:
|
||||
try:
|
||||
await self.manager.unlock_workspace(
|
||||
self.project_id,
|
||||
self.holder,
|
||||
)
|
||||
except Exception as e:
|
||||
logger.warning(
|
||||
f"Failed to release lock for {self.project_id}: {e}"
|
||||
)
|
||||
|
||||
|
||||
class FileLockManager:
|
||||
"""
|
||||
File-based locking for single-instance deployments.
|
||||
|
||||
Uses filelock for local locking when Redis is not available.
|
||||
"""
|
||||
|
||||
def __init__(self, lock_dir: Path) -> None:
|
||||
"""
|
||||
Initialize file lock manager.
|
||||
|
||||
Args:
|
||||
lock_dir: Directory for lock files
|
||||
"""
|
||||
self.lock_dir = lock_dir
|
||||
self.lock_dir.mkdir(parents=True, exist_ok=True)
|
||||
self._locks: dict[str, FileLock] = {}
|
||||
|
||||
def _get_lock(self, key: str) -> FileLock:
|
||||
"""Get or create a file lock for a key."""
|
||||
if key not in self._locks:
|
||||
lock_path = self.lock_dir / f"{key}.lock"
|
||||
self._locks[key] = FileLock(lock_path)
|
||||
return self._locks[key]
|
||||
|
||||
def acquire(
|
||||
self,
|
||||
key: str,
|
||||
timeout: float = 10.0,
|
||||
) -> bool:
|
||||
"""
|
||||
Acquire a lock.
|
||||
|
||||
Args:
|
||||
key: Lock key
|
||||
timeout: Timeout in seconds
|
||||
|
||||
Returns:
|
||||
True if acquired
|
||||
"""
|
||||
lock = self._get_lock(key)
|
||||
try:
|
||||
lock.acquire(timeout=timeout)
|
||||
return True
|
||||
except Timeout:
|
||||
return False
|
||||
|
||||
def release(self, key: str) -> bool:
|
||||
"""
|
||||
Release a lock.
|
||||
|
||||
Args:
|
||||
key: Lock key
|
||||
|
||||
Returns:
|
||||
True if released
|
||||
"""
|
||||
if key in self._locks:
|
||||
try:
|
||||
self._locks[key].release()
|
||||
return True
|
||||
except Exception:
|
||||
pass
|
||||
return False
|
||||
|
||||
def is_locked(self, key: str) -> bool:
|
||||
"""Check if a key is locked."""
|
||||
lock = self._get_lock(key)
|
||||
return lock.is_locked
|
||||
Reference in New Issue
Block a user