feat(git-ops): add GitHub provider with auto-detection

Implements GitHub API provider following the same pattern as Gitea:
- Full PR operations (create, get, list, merge, update, close)
- Branch operations via API
- Comment and label management
- Reviewer request support
- Rate limit error handling

Server enhancements:
- Auto-detect provider from repository URL (github.com vs custom Gitea)
- Initialize GitHub provider when token is configured
- Health check includes both provider statuses
- Token selection based on repo URL for clone/push operations

Refs: #110

🤖 Generated with [Claude Code](https://claude.com/claude-code)

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
This commit is contained in:
2026-01-06 20:55:22 +01:00
parent 9dfa76aa41
commit 1779239c07
4 changed files with 1328 additions and 14 deletions

View File

@@ -6,5 +6,6 @@ Provides adapters for different git hosting platforms (Gitea, GitHub, GitLab).
from .base import BaseProvider from .base import BaseProvider
from .gitea import GiteaProvider from .gitea import GiteaProvider
from .github import GitHubProvider
__all__ = ["BaseProvider", "GiteaProvider"] __all__ = ["BaseProvider", "GiteaProvider", "GitHubProvider"]

View File

@@ -0,0 +1,677 @@
"""
GitHub provider implementation.
Implements the BaseProvider interface for GitHub API operations.
"""
import logging
from datetime import UTC, datetime
from typing import Any
import httpx
from config import Settings, get_settings
from exceptions import (
APIError,
AuthenticationError,
PRNotFoundError,
)
from models import (
CreatePRResult,
GetPRResult,
ListPRsResult,
MergePRResult,
MergeStrategy,
PRInfo,
PRState,
UpdatePRResult,
)
from .base import BaseProvider
logger = logging.getLogger(__name__)
class GitHubProvider(BaseProvider):
"""
GitHub API provider implementation.
Supports all PR operations, branch operations, and repository queries.
"""
def __init__(
self,
token: str | None = None,
settings: Settings | None = None,
) -> None:
"""
Initialize GitHub provider.
Args:
token: GitHub personal access token or fine-grained token
settings: Optional settings override
"""
self.settings = settings or get_settings()
self.token = token or self.settings.github_token
self._client: httpx.AsyncClient | None = None
self._user: str | None = None
@property
def name(self) -> str:
"""Return the provider name."""
return "github"
async def _get_client(self) -> httpx.AsyncClient:
"""Get or create HTTP client."""
if self._client is None:
headers = {
"Accept": "application/vnd.github+json",
"X-GitHub-Api-Version": "2022-11-28",
}
if self.token:
headers["Authorization"] = f"Bearer {self.token}"
self._client = httpx.AsyncClient(
base_url="https://api.github.com",
headers=headers,
timeout=30.0,
)
return self._client
async def close(self) -> None:
"""Close the HTTP client."""
if self._client:
await self._client.aclose()
self._client = None
async def _request(
self,
method: str,
path: str,
**kwargs: Any,
) -> Any:
"""
Make an API request.
Args:
method: HTTP method
path: API path
**kwargs: Additional request arguments
Returns:
Parsed JSON response
Raises:
APIError: On API errors
AuthenticationError: On auth failures
"""
client = await self._get_client()
try:
response = await client.request(method, path, **kwargs)
if response.status_code == 401:
raise AuthenticationError("github", "Invalid or expired token")
if response.status_code == 403:
# Check for rate limiting
if "rate limit" in response.text.lower():
raise APIError(
"github", 403, "GitHub API rate limit exceeded"
)
raise AuthenticationError(
"github", "Insufficient permissions for this operation"
)
if response.status_code == 404:
return None
if response.status_code >= 400:
error_msg = response.text
try:
error_data = response.json()
error_msg = error_data.get("message", error_msg)
except Exception:
pass
raise APIError("github", response.status_code, error_msg)
if response.status_code == 204:
return None
return response.json()
except httpx.RequestError as e:
raise APIError("github", 0, f"Request failed: {e}")
async def is_connected(self) -> bool:
"""Check if connected to GitHub."""
if not self.token:
return False
try:
result = await self._request("GET", "/user")
return result is not None
except Exception:
return False
async def get_authenticated_user(self) -> str | None:
"""Get the authenticated user's username."""
if self._user:
return self._user
try:
result = await self._request("GET", "/user")
if result:
self._user = result.get("login")
return self._user
except Exception:
pass
return None
# Repository operations
async def get_repo_info(self, owner: str, repo: str) -> dict[str, Any]:
"""Get repository information."""
result = await self._request("GET", f"/repos/{owner}/{repo}")
if result is None:
raise APIError("github", 404, f"Repository not found: {owner}/{repo}")
return result
async def get_default_branch(self, owner: str, repo: str) -> str:
"""Get the default branch for a repository."""
repo_info = await self.get_repo_info(owner, repo)
return repo_info.get("default_branch", "main")
# Pull Request operations
async def create_pr(
self,
owner: str,
repo: str,
title: str,
body: str,
source_branch: str,
target_branch: str,
draft: bool = False,
labels: list[str] | None = None,
assignees: list[str] | None = None,
reviewers: list[str] | None = None,
) -> CreatePRResult:
"""Create a pull request."""
try:
data: dict[str, Any] = {
"title": title,
"body": body,
"head": source_branch,
"base": target_branch,
"draft": draft,
}
result = await self._request(
"POST",
f"/repos/{owner}/{repo}/pulls",
json=data,
)
if result is None:
return CreatePRResult(
success=False,
error="Failed to create pull request",
)
pr_number = result["number"]
# Add labels if specified
if labels:
await self.add_labels(owner, repo, pr_number, labels)
# Add assignees if specified
if assignees:
await self._request(
"POST",
f"/repos/{owner}/{repo}/issues/{pr_number}/assignees",
json={"assignees": assignees},
)
# Request reviewers if specified
if reviewers:
await self.request_review(owner, repo, pr_number, reviewers)
return CreatePRResult(
success=True,
pr_number=pr_number,
pr_url=result.get("html_url"),
)
except APIError as e:
return CreatePRResult(
success=False,
error=str(e),
)
async def get_pr(self, owner: str, repo: str, pr_number: int) -> GetPRResult:
"""Get a pull request by number."""
try:
result = await self._request(
"GET",
f"/repos/{owner}/{repo}/pulls/{pr_number}",
)
if result is None:
raise PRNotFoundError(pr_number, f"{owner}/{repo}")
pr_info = self._parse_pr(result)
return GetPRResult(
success=True,
pr=pr_info.to_dict(),
)
except PRNotFoundError:
return GetPRResult(
success=False,
error=f"Pull request #{pr_number} not found",
)
except APIError as e:
return GetPRResult(
success=False,
error=str(e),
)
async def list_prs(
self,
owner: str,
repo: str,
state: PRState | None = None,
author: str | None = None,
limit: int = 20,
) -> ListPRsResult:
"""List pull requests."""
try:
params: dict[str, Any] = {
"per_page": min(limit, 100), # GitHub max is 100
}
if state:
# GitHub uses 'state' for open/closed only
# Merged PRs are closed PRs with merged_at set
if state == PRState.OPEN:
params["state"] = "open"
elif state in (PRState.CLOSED, PRState.MERGED):
params["state"] = "closed"
else:
params["state"] = "all"
result = await self._request(
"GET",
f"/repos/{owner}/{repo}/pulls",
params=params,
)
if result is None:
return ListPRsResult(
success=True,
pull_requests=[],
total_count=0,
)
prs = []
for pr_data in result:
# Filter by author if specified
if author:
pr_author = pr_data.get("user", {}).get("login", "")
if pr_author.lower() != author.lower():
continue
# Filter merged PRs if looking specifically for merged
if state == PRState.MERGED:
if not pr_data.get("merged_at"):
continue
pr_info = self._parse_pr(pr_data)
prs.append(pr_info.to_dict())
return ListPRsResult(
success=True,
pull_requests=prs,
total_count=len(prs),
)
except APIError as e:
return ListPRsResult(
success=False,
error=str(e),
)
async def merge_pr(
self,
owner: str,
repo: str,
pr_number: int,
merge_strategy: MergeStrategy = MergeStrategy.MERGE,
commit_message: str | None = None,
delete_branch: bool = True,
) -> MergePRResult:
"""Merge a pull request."""
try:
# Map merge strategy to GitHub's merge_method values
method_map = {
MergeStrategy.MERGE: "merge",
MergeStrategy.SQUASH: "squash",
MergeStrategy.REBASE: "rebase",
}
data: dict[str, Any] = {
"merge_method": method_map[merge_strategy],
}
if commit_message:
# For squash, commit_title and commit_message
# For merge, commit_title and commit_message
parts = commit_message.split("\n", 1)
data["commit_title"] = parts[0]
if len(parts) > 1:
data["commit_message"] = parts[1]
result = await self._request(
"PUT",
f"/repos/{owner}/{repo}/pulls/{pr_number}/merge",
json=data,
)
if result is None:
return MergePRResult(
success=False,
error="Failed to merge pull request",
)
branch_deleted = False
# Delete branch if requested
if delete_branch and result.get("merged"):
# Get PR to find the branch name
pr_result = await self.get_pr(owner, repo, pr_number)
if pr_result.success and pr_result.pr:
source_branch = pr_result.pr.get("source_branch")
if source_branch:
branch_deleted = await self.delete_remote_branch(
owner, repo, source_branch
)
return MergePRResult(
success=True,
merge_commit_sha=result.get("sha"),
branch_deleted=branch_deleted,
)
except APIError as e:
return MergePRResult(
success=False,
error=str(e),
)
async def update_pr(
self,
owner: str,
repo: str,
pr_number: int,
title: str | None = None,
body: str | None = None,
state: PRState | None = None,
labels: list[str] | None = None,
assignees: list[str] | None = None,
) -> UpdatePRResult:
"""Update a pull request."""
try:
data: dict[str, Any] = {}
if title is not None:
data["title"] = title
if body is not None:
data["body"] = body
if state is not None:
if state == PRState.OPEN:
data["state"] = "open"
elif state == PRState.CLOSED:
data["state"] = "closed"
# Update PR if there's data
if data:
await self._request(
"PATCH",
f"/repos/{owner}/{repo}/pulls/{pr_number}",
json=data,
)
# Update labels via issue endpoint
if labels is not None:
await self._request(
"PUT",
f"/repos/{owner}/{repo}/issues/{pr_number}/labels",
json={"labels": labels},
)
# Update assignees via issue endpoint
if assignees is not None:
# First remove all assignees
await self._request(
"DELETE",
f"/repos/{owner}/{repo}/issues/{pr_number}/assignees",
json={"assignees": []},
)
# Then add new ones
if assignees:
await self._request(
"POST",
f"/repos/{owner}/{repo}/issues/{pr_number}/assignees",
json={"assignees": assignees},
)
# Fetch updated PR
result = await self.get_pr(owner, repo, pr_number)
return UpdatePRResult(
success=result.success,
pr=result.pr,
error=result.error,
)
except APIError as e:
return UpdatePRResult(
success=False,
error=str(e),
)
async def close_pr(
self,
owner: str,
repo: str,
pr_number: int,
) -> UpdatePRResult:
"""Close a pull request without merging."""
return await self.update_pr(
owner,
repo,
pr_number,
state=PRState.CLOSED,
)
# Branch operations
async def delete_remote_branch(
self,
owner: str,
repo: str,
branch: str,
) -> bool:
"""Delete a remote branch."""
try:
await self._request(
"DELETE",
f"/repos/{owner}/{repo}/git/refs/heads/{branch}",
)
return True
except APIError:
return False
async def get_branch(
self,
owner: str,
repo: str,
branch: str,
) -> dict[str, Any] | None:
"""Get branch information."""
return await self._request(
"GET",
f"/repos/{owner}/{repo}/branches/{branch}",
)
# Comment operations
async def add_pr_comment(
self,
owner: str,
repo: str,
pr_number: int,
body: str,
) -> dict[str, Any]:
"""Add a comment to a pull request."""
result = await self._request(
"POST",
f"/repos/{owner}/{repo}/issues/{pr_number}/comments",
json={"body": body},
)
return result or {}
async def list_pr_comments(
self,
owner: str,
repo: str,
pr_number: int,
) -> list[dict[str, Any]]:
"""List comments on a pull request."""
result = await self._request(
"GET",
f"/repos/{owner}/{repo}/issues/{pr_number}/comments",
)
return result or []
# Label operations
async def add_labels(
self,
owner: str,
repo: str,
pr_number: int,
labels: list[str],
) -> list[str]:
"""Add labels to a pull request."""
# GitHub creates labels automatically if they don't exist (unlike Gitea)
result = await self._request(
"POST",
f"/repos/{owner}/{repo}/issues/{pr_number}/labels",
json={"labels": labels},
)
if result:
return [lbl["name"] for lbl in result]
return labels
async def remove_label(
self,
owner: str,
repo: str,
pr_number: int,
label: str,
) -> list[str]:
"""Remove a label from a pull request."""
await self._request(
"DELETE",
f"/repos/{owner}/{repo}/issues/{pr_number}/labels/{label}",
)
# Return remaining labels
issue = await self._request(
"GET",
f"/repos/{owner}/{repo}/issues/{pr_number}",
)
if issue:
return [lbl["name"] for lbl in issue.get("labels", [])]
return []
# Reviewer operations
async def request_review(
self,
owner: str,
repo: str,
pr_number: int,
reviewers: list[str],
) -> list[str]:
"""Request review from users."""
await self._request(
"POST",
f"/repos/{owner}/{repo}/pulls/{pr_number}/requested_reviewers",
json={"reviewers": reviewers},
)
return reviewers
# Helper methods
def _parse_pr(self, data: dict[str, Any]) -> PRInfo:
"""Parse PR API response into PRInfo."""
# Parse dates
created_at = self._parse_datetime(data.get("created_at"))
updated_at = self._parse_datetime(data.get("updated_at"))
merged_at = self._parse_datetime(data.get("merged_at"))
closed_at = self._parse_datetime(data.get("closed_at"))
# Determine state
if data.get("merged_at"):
state = PRState.MERGED
elif data.get("state") == "closed":
state = PRState.CLOSED
else:
state = PRState.OPEN
# Extract labels
labels = [lbl["name"] for lbl in data.get("labels", [])]
# Extract assignees
assignees = [a["login"] for a in data.get("assignees", [])]
# Extract reviewers
reviewers = []
if "requested_reviewers" in data:
reviewers = [r["login"] for r in data["requested_reviewers"]]
return PRInfo(
number=data["number"],
title=data["title"],
body=data.get("body", "") or "",
state=state,
source_branch=data.get("head", {}).get("ref", ""),
target_branch=data.get("base", {}).get("ref", ""),
author=data.get("user", {}).get("login", ""),
created_at=created_at,
updated_at=updated_at,
merged_at=merged_at,
closed_at=closed_at,
url=data.get("html_url"),
labels=labels,
assignees=assignees,
reviewers=reviewers,
mergeable=data.get("mergeable"),
draft=data.get("draft", False),
)
def _parse_datetime(self, value: str | None) -> datetime:
"""Parse datetime string from API."""
if not value:
return datetime.now(UTC)
try:
# GitHub uses ISO 8601 format with Z suffix
if value.endswith("Z"):
value = value[:-1] + "+00:00"
return datetime.fromisoformat(value)
except ValueError:
return datetime.now(UTC)

View File

@@ -22,7 +22,7 @@ from config import Settings, get_settings
from exceptions import ErrorCode, GitOpsError from exceptions import ErrorCode, GitOpsError
from git_wrapper import GitWrapper from git_wrapper import GitWrapper
from models import MergeStrategy, PRState from models import MergeStrategy, PRState
from providers import GiteaProvider from providers import BaseProvider, GiteaProvider, GitHubProvider
from workspace import WorkspaceManager from workspace import WorkspaceManager
# Input validation patterns # Input validation patterns
@@ -77,25 +77,57 @@ logger = logging.getLogger(__name__)
_settings: Settings | None = None _settings: Settings | None = None
_workspace_manager: WorkspaceManager | None = None _workspace_manager: WorkspaceManager | None = None
_gitea_provider: GiteaProvider | None = None _gitea_provider: GiteaProvider | None = None
_github_provider: GitHubProvider | None = None
def _get_provider_for_url(repo_url: str) -> GiteaProvider | None: def _get_provider_for_url(repo_url: str) -> BaseProvider | None:
"""Get the appropriate provider for a repository URL.""" """Get the appropriate provider for a repository URL."""
if not _settings: if not _settings:
return None return None
# Normalize URL for matching
url_lower = repo_url.lower()
# Check for GitHub URLs
if "github.com" in url_lower or (
_settings.github_api_url
and _settings.github_api_url != "https://api.github.com"
and _settings.github_api_url.replace("https://", "").replace("/api/v3", "") in url_lower
):
return _github_provider
# Check if it's a Gitea URL # Check if it's a Gitea URL
if _settings.gitea_base_url and _settings.gitea_base_url in repo_url: if _settings.gitea_base_url and _settings.gitea_base_url.lower() in url_lower:
return _gitea_provider return _gitea_provider
# Default to Gitea for now # Default: try to detect from URL pattern
# If URL contains 'github' anywhere, use GitHub
if "github" in url_lower:
return _github_provider
# Default to Gitea for self-hosted instances
return _gitea_provider return _gitea_provider
def _get_auth_token_for_url(repo_url: str) -> str | None:
"""Get the appropriate auth token for a repository URL."""
if not _settings:
return None
url_lower = repo_url.lower()
# GitHub token
if "github.com" in url_lower or "github" in url_lower:
return _settings.github_token if _settings.github_token else None
# Gitea token (default)
return _settings.gitea_token if _settings.gitea_token else None
@asynccontextmanager @asynccontextmanager
async def lifespan(_app: FastAPI) -> AsyncIterator[None]: async def lifespan(_app: FastAPI) -> AsyncIterator[None]:
"""Application lifespan handler.""" """Application lifespan handler."""
global _settings, _workspace_manager, _gitea_provider global _settings, _workspace_manager, _gitea_provider, _github_provider
logger.info("Starting Git Operations MCP Server...") logger.info("Starting Git Operations MCP Server...")
@@ -114,6 +146,13 @@ async def lifespan(_app: FastAPI) -> AsyncIterator[None]:
) )
logger.info(f"Gitea provider initialized: {_settings.gitea_base_url}") logger.info(f"Gitea provider initialized: {_settings.gitea_base_url}")
if _settings.github_token:
_github_provider = GitHubProvider(
token=_settings.github_token,
settings=_settings,
)
logger.info("GitHub provider initialized")
logger.info("Git Operations MCP Server started successfully") logger.info("Git Operations MCP Server started successfully")
yield yield
@@ -124,6 +163,9 @@ async def lifespan(_app: FastAPI) -> AsyncIterator[None]:
if _gitea_provider: if _gitea_provider:
await _gitea_provider.close() await _gitea_provider.close()
if _github_provider:
await _github_provider.close()
logger.info("Git Operations MCP Server shut down") logger.info("Git Operations MCP Server shut down")
@@ -167,6 +209,21 @@ async def health_check() -> dict[str, Any]:
else: else:
status["dependencies"]["gitea"] = "not configured" status["dependencies"]["gitea"] = "not configured"
# Check GitHub connectivity
if _github_provider:
try:
if await _github_provider.is_connected():
user = await _github_provider.get_authenticated_user()
status["dependencies"]["github"] = f"connected as {user}"
else:
status["dependencies"]["github"] = "not connected"
is_degraded = True
except Exception as e:
status["dependencies"]["github"] = f"error: {e}"
is_degraded = True
else:
status["dependencies"]["github"] = "not configured"
# Check workspace directory # Check workspace directory
if _workspace_manager: if _workspace_manager:
try: try:
@@ -442,10 +499,8 @@ async def clone_repository(
# Create workspace # Create workspace
workspace = await _workspace_manager.create_workspace(project_id, repo_url) # type: ignore[union-attr] workspace = await _workspace_manager.create_workspace(project_id, repo_url) # type: ignore[union-attr]
# Get auth token from provider # Get auth token appropriate for the repository URL
auth_token = None auth_token = _get_auth_token_for_url(repo_url)
if _settings and _settings.gitea_token:
auth_token = _settings.gitea_token
# Clone repository # Clone repository
git = GitWrapper(workspace.path, _settings) git = GitWrapper(workspace.path, _settings)
@@ -715,10 +770,8 @@ async def push(
if not workspace: if not workspace:
return {"success": False, "error": f"Workspace not found for project: {project_id}", "code": ErrorCode.WORKSPACE_NOT_FOUND.value} return {"success": False, "error": f"Workspace not found for project: {project_id}", "code": ErrorCode.WORKSPACE_NOT_FOUND.value}
# Get auth token # Get auth token appropriate for the repository URL
auth_token = None auth_token = _get_auth_token_for_url(workspace.repo_url or "")
if _settings and _settings.gitea_token:
auth_token = _settings.gitea_token
git = GitWrapper(workspace.path, _settings) git = GitWrapper(workspace.path, _settings)
result = await git.push( result = await git.push(

View File

@@ -0,0 +1,583 @@
"""
Tests for GitHub provider implementation.
"""
from unittest.mock import MagicMock
import pytest
from exceptions import APIError, AuthenticationError
from models import MergeStrategy, PRState
from providers.github import GitHubProvider
class TestGitHubProviderBasics:
"""Tests for GitHubProvider basic operations."""
def test_provider_name(self):
"""Test provider name is github."""
provider = GitHubProvider(token="test-token")
assert provider.name == "github"
def test_parse_repo_url_https(self):
"""Test parsing HTTPS repo URL."""
provider = GitHubProvider(token="test-token")
owner, repo = provider.parse_repo_url("https://github.com/owner/repo.git")
assert owner == "owner"
assert repo == "repo"
def test_parse_repo_url_https_no_git(self):
"""Test parsing HTTPS URL without .git suffix."""
provider = GitHubProvider(token="test-token")
owner, repo = provider.parse_repo_url("https://github.com/owner/repo")
assert owner == "owner"
assert repo == "repo"
def test_parse_repo_url_ssh(self):
"""Test parsing SSH repo URL."""
provider = GitHubProvider(token="test-token")
owner, repo = provider.parse_repo_url("git@github.com:owner/repo.git")
assert owner == "owner"
assert repo == "repo"
def test_parse_repo_url_invalid(self):
"""Test error on invalid URL."""
provider = GitHubProvider(token="test-token")
with pytest.raises(ValueError, match="Unable to parse"):
provider.parse_repo_url("invalid-url")
@pytest.fixture
def mock_github_httpx_client():
"""Create a mock httpx client for GitHub provider tests."""
from unittest.mock import AsyncMock
mock_response = AsyncMock()
mock_response.status_code = 200
mock_response.json = MagicMock(return_value={})
mock_response.text = ""
mock_client = AsyncMock()
mock_client.request = AsyncMock(return_value=mock_response)
mock_client.get = AsyncMock(return_value=mock_response)
mock_client.post = AsyncMock(return_value=mock_response)
mock_client.patch = AsyncMock(return_value=mock_response)
mock_client.put = AsyncMock(return_value=mock_response)
mock_client.delete = AsyncMock(return_value=mock_response)
return mock_client
@pytest.fixture
async def github_provider(test_settings, mock_github_httpx_client):
"""Create a GitHubProvider with mocked HTTP client."""
provider = GitHubProvider(
token=test_settings.github_token,
settings=test_settings,
)
provider._client = mock_github_httpx_client
yield provider
await provider.close()
@pytest.fixture
def github_pr_data():
"""Sample PR data from GitHub API."""
return {
"number": 42,
"title": "Test PR",
"body": "This is a test pull request",
"state": "open",
"head": {"ref": "feature-branch"},
"base": {"ref": "main"},
"user": {"login": "test-user"},
"created_at": "2024-01-15T10:00:00Z",
"updated_at": "2024-01-15T12:00:00Z",
"merged_at": None,
"closed_at": None,
"html_url": "https://github.com/owner/repo/pull/42",
"labels": [{"name": "enhancement"}],
"assignees": [{"login": "assignee1"}],
"requested_reviewers": [{"login": "reviewer1"}],
"mergeable": True,
"draft": False,
}
class TestGitHubProviderConnection:
"""Tests for GitHub provider connection."""
@pytest.mark.asyncio
async def test_is_connected(self, github_provider, mock_github_httpx_client):
"""Test connection check."""
mock_github_httpx_client.request.return_value.json = MagicMock(
return_value={"login": "test-user"}
)
result = await github_provider.is_connected()
assert result is True
@pytest.mark.asyncio
async def test_is_connected_no_token(self, test_settings):
"""Test connection fails without token."""
provider = GitHubProvider(
token="",
settings=test_settings,
)
result = await provider.is_connected()
assert result is False
await provider.close()
@pytest.mark.asyncio
async def test_get_authenticated_user(self, github_provider, mock_github_httpx_client):
"""Test getting authenticated user."""
mock_github_httpx_client.request.return_value.json = MagicMock(
return_value={"login": "test-user"}
)
user = await github_provider.get_authenticated_user()
assert user == "test-user"
class TestGitHubProviderRepoOperations:
"""Tests for GitHub repository operations."""
@pytest.mark.asyncio
async def test_get_repo_info(self, github_provider, mock_github_httpx_client):
"""Test getting repository info."""
mock_github_httpx_client.request.return_value.json = MagicMock(
return_value={
"name": "repo",
"full_name": "owner/repo",
"default_branch": "main",
}
)
result = await github_provider.get_repo_info("owner", "repo")
assert result["name"] == "repo"
assert result["default_branch"] == "main"
@pytest.mark.asyncio
async def test_get_default_branch(self, github_provider, mock_github_httpx_client):
"""Test getting default branch."""
mock_github_httpx_client.request.return_value.json = MagicMock(
return_value={"default_branch": "develop"}
)
branch = await github_provider.get_default_branch("owner", "repo")
assert branch == "develop"
class TestGitHubPROperations:
"""Tests for GitHub PR operations."""
@pytest.mark.asyncio
async def test_create_pr(self, github_provider, mock_github_httpx_client):
"""Test creating a pull request."""
mock_github_httpx_client.request.return_value.json = MagicMock(
return_value={
"number": 42,
"html_url": "https://github.com/owner/repo/pull/42",
}
)
result = await github_provider.create_pr(
owner="owner",
repo="repo",
title="Test PR",
body="Test body",
source_branch="feature",
target_branch="main",
)
assert result.success is True
assert result.pr_number == 42
assert result.pr_url == "https://github.com/owner/repo/pull/42"
@pytest.mark.asyncio
async def test_create_pr_with_draft(self, github_provider, mock_github_httpx_client):
"""Test creating a draft PR."""
mock_github_httpx_client.request.return_value.json = MagicMock(
return_value={
"number": 43,
"html_url": "https://github.com/owner/repo/pull/43",
}
)
result = await github_provider.create_pr(
owner="owner",
repo="repo",
title="Draft PR",
body="Draft body",
source_branch="feature",
target_branch="main",
draft=True,
)
assert result.success is True
assert result.pr_number == 43
@pytest.mark.asyncio
async def test_create_pr_with_options(self, github_provider, mock_github_httpx_client):
"""Test creating PR with labels, assignees, reviewers."""
mock_responses = [
{"number": 44, "html_url": "https://github.com/owner/repo/pull/44"}, # Create PR
[{"name": "enhancement"}], # POST add labels
{}, # POST add assignees
{}, # POST request reviewers
]
mock_github_httpx_client.request.return_value.json = MagicMock(side_effect=mock_responses)
result = await github_provider.create_pr(
owner="owner",
repo="repo",
title="Test PR",
body="Test body",
source_branch="feature",
target_branch="main",
labels=["enhancement"],
assignees=["user1"],
reviewers=["reviewer1"],
)
assert result.success is True
@pytest.mark.asyncio
async def test_get_pr(self, github_provider, mock_github_httpx_client, github_pr_data):
"""Test getting a pull request."""
mock_github_httpx_client.request.return_value.json = MagicMock(
return_value=github_pr_data
)
result = await github_provider.get_pr("owner", "repo", 42)
assert result.success is True
assert result.pr["number"] == 42
assert result.pr["title"] == "Test PR"
@pytest.mark.asyncio
async def test_get_pr_not_found(self, github_provider, mock_github_httpx_client):
"""Test getting non-existent PR."""
mock_github_httpx_client.request.return_value.status_code = 404
mock_github_httpx_client.request.return_value.json = MagicMock(return_value=None)
result = await github_provider.get_pr("owner", "repo", 999)
assert result.success is False
@pytest.mark.asyncio
async def test_list_prs(self, github_provider, mock_github_httpx_client, github_pr_data):
"""Test listing pull requests."""
mock_github_httpx_client.request.return_value.json = MagicMock(
return_value=[github_pr_data, github_pr_data]
)
result = await github_provider.list_prs("owner", "repo")
assert result.success is True
assert len(result.pull_requests) == 2
@pytest.mark.asyncio
async def test_list_prs_with_state_filter(self, github_provider, mock_github_httpx_client, github_pr_data):
"""Test listing PRs with state filter."""
mock_github_httpx_client.request.return_value.json = MagicMock(
return_value=[github_pr_data]
)
result = await github_provider.list_prs(
"owner", "repo", state=PRState.OPEN
)
assert result.success is True
@pytest.mark.asyncio
async def test_merge_pr(self, github_provider, mock_github_httpx_client, github_pr_data):
"""Test merging a pull request."""
# Merge returns sha, then get_pr returns the PR data, then delete branch
mock_responses = [
{"sha": "merge-commit-sha", "merged": True}, # PUT merge
github_pr_data, # GET PR for branch info
None, # DELETE branch
]
mock_github_httpx_client.request.return_value.json = MagicMock(
side_effect=mock_responses
)
result = await github_provider.merge_pr(
"owner", "repo", 42,
merge_strategy=MergeStrategy.SQUASH,
)
assert result.success is True
assert result.merge_commit_sha == "merge-commit-sha"
@pytest.mark.asyncio
async def test_merge_pr_rebase(self, github_provider, mock_github_httpx_client, github_pr_data):
"""Test merging with rebase strategy."""
mock_responses = [
{"sha": "rebase-commit-sha", "merged": True}, # PUT merge
github_pr_data, # GET PR for branch info
None, # DELETE branch
]
mock_github_httpx_client.request.return_value.json = MagicMock(
side_effect=mock_responses
)
result = await github_provider.merge_pr(
"owner", "repo", 42,
merge_strategy=MergeStrategy.REBASE,
)
assert result.success is True
@pytest.mark.asyncio
async def test_update_pr(self, github_provider, mock_github_httpx_client, github_pr_data):
"""Test updating a pull request."""
mock_github_httpx_client.request.return_value.json = MagicMock(
return_value=github_pr_data
)
result = await github_provider.update_pr(
"owner", "repo", 42,
title="Updated Title",
body="Updated body",
)
assert result.success is True
@pytest.mark.asyncio
async def test_close_pr(self, github_provider, mock_github_httpx_client, github_pr_data):
"""Test closing a pull request."""
github_pr_data["state"] = "closed"
mock_github_httpx_client.request.return_value.json = MagicMock(
return_value=github_pr_data
)
result = await github_provider.close_pr("owner", "repo", 42)
assert result.success is True
class TestGitHubBranchOperations:
"""Tests for GitHub branch operations."""
@pytest.mark.asyncio
async def test_get_branch(self, github_provider, mock_github_httpx_client):
"""Test getting branch info."""
mock_github_httpx_client.request.return_value.json = MagicMock(
return_value={
"name": "main",
"commit": {"sha": "abc123"},
}
)
result = await github_provider.get_branch("owner", "repo", "main")
assert result["name"] == "main"
@pytest.mark.asyncio
async def test_delete_remote_branch(self, github_provider, mock_github_httpx_client):
"""Test deleting a remote branch."""
mock_github_httpx_client.request.return_value.status_code = 204
result = await github_provider.delete_remote_branch("owner", "repo", "old-branch")
assert result is True
class TestGitHubCommentOperations:
"""Tests for GitHub comment operations."""
@pytest.mark.asyncio
async def test_add_pr_comment(self, github_provider, mock_github_httpx_client):
"""Test adding a comment to a PR."""
mock_github_httpx_client.request.return_value.json = MagicMock(
return_value={"id": 1, "body": "Test comment"}
)
result = await github_provider.add_pr_comment(
"owner", "repo", 42, "Test comment"
)
assert result["body"] == "Test comment"
@pytest.mark.asyncio
async def test_list_pr_comments(self, github_provider, mock_github_httpx_client):
"""Test listing PR comments."""
mock_github_httpx_client.request.return_value.json = MagicMock(
return_value=[
{"id": 1, "body": "Comment 1"},
{"id": 2, "body": "Comment 2"},
]
)
result = await github_provider.list_pr_comments("owner", "repo", 42)
assert len(result) == 2
class TestGitHubLabelOperations:
"""Tests for GitHub label operations."""
@pytest.mark.asyncio
async def test_add_labels(self, github_provider, mock_github_httpx_client):
"""Test adding labels to a PR."""
mock_github_httpx_client.request.return_value.json = MagicMock(
return_value=[{"name": "bug"}, {"name": "urgent"}]
)
result = await github_provider.add_labels(
"owner", "repo", 42, ["bug", "urgent"]
)
assert "bug" in result
assert "urgent" in result
@pytest.mark.asyncio
async def test_remove_label(self, github_provider, mock_github_httpx_client):
"""Test removing a label from a PR."""
mock_responses = [
None, # DELETE label
{"labels": []}, # GET issue
]
mock_github_httpx_client.request.return_value.json = MagicMock(side_effect=mock_responses)
result = await github_provider.remove_label(
"owner", "repo", 42, "bug"
)
assert isinstance(result, list)
class TestGitHubReviewerOperations:
"""Tests for GitHub reviewer operations."""
@pytest.mark.asyncio
async def test_request_review(self, github_provider, mock_github_httpx_client):
"""Test requesting review from users."""
mock_github_httpx_client.request.return_value.json = MagicMock(return_value={})
result = await github_provider.request_review(
"owner", "repo", 42, ["reviewer1", "reviewer2"]
)
assert result == ["reviewer1", "reviewer2"]
class TestGitHubErrorHandling:
"""Tests for error handling in GitHub provider."""
@pytest.mark.asyncio
async def test_authentication_error(self, github_provider, mock_github_httpx_client):
"""Test handling authentication errors."""
mock_github_httpx_client.request.return_value.status_code = 401
with pytest.raises(AuthenticationError):
await github_provider._request("GET", "/user")
@pytest.mark.asyncio
async def test_permission_denied(self, github_provider, mock_github_httpx_client):
"""Test handling permission denied errors."""
mock_github_httpx_client.request.return_value.status_code = 403
mock_github_httpx_client.request.return_value.text = "Permission denied"
with pytest.raises(AuthenticationError, match="Insufficient permissions"):
await github_provider._request("GET", "/protected")
@pytest.mark.asyncio
async def test_rate_limit_error(self, github_provider, mock_github_httpx_client):
"""Test handling rate limit errors."""
mock_github_httpx_client.request.return_value.status_code = 403
mock_github_httpx_client.request.return_value.text = "API rate limit exceeded"
with pytest.raises(APIError, match="rate limit"):
await github_provider._request("GET", "/user")
@pytest.mark.asyncio
async def test_api_error(self, github_provider, mock_github_httpx_client):
"""Test handling general API errors."""
mock_github_httpx_client.request.return_value.status_code = 500
mock_github_httpx_client.request.return_value.text = "Internal Server Error"
mock_github_httpx_client.request.return_value.json = MagicMock(
return_value={"message": "Server error"}
)
with pytest.raises(APIError):
await github_provider._request("GET", "/error")
class TestGitHubPRParsing:
"""Tests for PR data parsing."""
def test_parse_pr_open(self, github_provider, github_pr_data):
"""Test parsing open PR."""
pr_info = github_provider._parse_pr(github_pr_data)
assert pr_info.number == 42
assert pr_info.state == PRState.OPEN
assert pr_info.title == "Test PR"
assert pr_info.source_branch == "feature-branch"
assert pr_info.target_branch == "main"
def test_parse_pr_merged(self, github_provider, github_pr_data):
"""Test parsing merged PR."""
github_pr_data["merged_at"] = "2024-01-16T10:00:00Z"
pr_info = github_provider._parse_pr(github_pr_data)
assert pr_info.state == PRState.MERGED
def test_parse_pr_closed(self, github_provider, github_pr_data):
"""Test parsing closed PR."""
github_pr_data["state"] = "closed"
github_pr_data["closed_at"] = "2024-01-16T10:00:00Z"
pr_info = github_provider._parse_pr(github_pr_data)
assert pr_info.state == PRState.CLOSED
def test_parse_pr_draft(self, github_provider, github_pr_data):
"""Test parsing draft PR."""
github_pr_data["draft"] = True
pr_info = github_provider._parse_pr(github_pr_data)
assert pr_info.draft is True
def test_parse_datetime_iso(self, github_provider):
"""Test parsing ISO datetime strings."""
dt = github_provider._parse_datetime("2024-01-15T10:30:00Z")
assert dt.year == 2024
assert dt.month == 1
assert dt.day == 15
def test_parse_datetime_none(self, github_provider):
"""Test parsing None datetime returns now."""
dt = github_provider._parse_datetime(None)
assert dt is not None
assert dt.tzinfo is not None
def test_parse_pr_with_null_body(self, github_provider, github_pr_data):
"""Test parsing PR with null body."""
github_pr_data["body"] = None
pr_info = github_provider._parse_pr(github_pr_data)
assert pr_info.body == ""