forked from cardosofelipe/fast-next-template
feat(git-ops): add GitHub provider with auto-detection
Implements GitHub API provider following the same pattern as Gitea: - Full PR operations (create, get, list, merge, update, close) - Branch operations via API - Comment and label management - Reviewer request support - Rate limit error handling Server enhancements: - Auto-detect provider from repository URL (github.com vs custom Gitea) - Initialize GitHub provider when token is configured - Health check includes both provider statuses - Token selection based on repo URL for clone/push operations Refs: #110 🤖 Generated with [Claude Code](https://claude.com/claude-code) Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
This commit is contained in:
@@ -6,5 +6,6 @@ Provides adapters for different git hosting platforms (Gitea, GitHub, GitLab).
|
||||
|
||||
from .base import BaseProvider
|
||||
from .gitea import GiteaProvider
|
||||
from .github import GitHubProvider
|
||||
|
||||
__all__ = ["BaseProvider", "GiteaProvider"]
|
||||
__all__ = ["BaseProvider", "GiteaProvider", "GitHubProvider"]
|
||||
|
||||
677
mcp-servers/git-ops/providers/github.py
Normal file
677
mcp-servers/git-ops/providers/github.py
Normal file
@@ -0,0 +1,677 @@
|
||||
"""
|
||||
GitHub provider implementation.
|
||||
|
||||
Implements the BaseProvider interface for GitHub API operations.
|
||||
"""
|
||||
|
||||
import logging
|
||||
from datetime import UTC, datetime
|
||||
from typing import Any
|
||||
|
||||
import httpx
|
||||
|
||||
from config import Settings, get_settings
|
||||
from exceptions import (
|
||||
APIError,
|
||||
AuthenticationError,
|
||||
PRNotFoundError,
|
||||
)
|
||||
from models import (
|
||||
CreatePRResult,
|
||||
GetPRResult,
|
||||
ListPRsResult,
|
||||
MergePRResult,
|
||||
MergeStrategy,
|
||||
PRInfo,
|
||||
PRState,
|
||||
UpdatePRResult,
|
||||
)
|
||||
|
||||
from .base import BaseProvider
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class GitHubProvider(BaseProvider):
|
||||
"""
|
||||
GitHub API provider implementation.
|
||||
|
||||
Supports all PR operations, branch operations, and repository queries.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
token: str | None = None,
|
||||
settings: Settings | None = None,
|
||||
) -> None:
|
||||
"""
|
||||
Initialize GitHub provider.
|
||||
|
||||
Args:
|
||||
token: GitHub personal access token or fine-grained token
|
||||
settings: Optional settings override
|
||||
"""
|
||||
self.settings = settings or get_settings()
|
||||
self.token = token or self.settings.github_token
|
||||
self._client: httpx.AsyncClient | None = None
|
||||
self._user: str | None = None
|
||||
|
||||
@property
|
||||
def name(self) -> str:
|
||||
"""Return the provider name."""
|
||||
return "github"
|
||||
|
||||
async def _get_client(self) -> httpx.AsyncClient:
|
||||
"""Get or create HTTP client."""
|
||||
if self._client is None:
|
||||
headers = {
|
||||
"Accept": "application/vnd.github+json",
|
||||
"X-GitHub-Api-Version": "2022-11-28",
|
||||
}
|
||||
if self.token:
|
||||
headers["Authorization"] = f"Bearer {self.token}"
|
||||
|
||||
self._client = httpx.AsyncClient(
|
||||
base_url="https://api.github.com",
|
||||
headers=headers,
|
||||
timeout=30.0,
|
||||
)
|
||||
return self._client
|
||||
|
||||
async def close(self) -> None:
|
||||
"""Close the HTTP client."""
|
||||
if self._client:
|
||||
await self._client.aclose()
|
||||
self._client = None
|
||||
|
||||
async def _request(
|
||||
self,
|
||||
method: str,
|
||||
path: str,
|
||||
**kwargs: Any,
|
||||
) -> Any:
|
||||
"""
|
||||
Make an API request.
|
||||
|
||||
Args:
|
||||
method: HTTP method
|
||||
path: API path
|
||||
**kwargs: Additional request arguments
|
||||
|
||||
Returns:
|
||||
Parsed JSON response
|
||||
|
||||
Raises:
|
||||
APIError: On API errors
|
||||
AuthenticationError: On auth failures
|
||||
"""
|
||||
client = await self._get_client()
|
||||
|
||||
try:
|
||||
response = await client.request(method, path, **kwargs)
|
||||
|
||||
if response.status_code == 401:
|
||||
raise AuthenticationError("github", "Invalid or expired token")
|
||||
|
||||
if response.status_code == 403:
|
||||
# Check for rate limiting
|
||||
if "rate limit" in response.text.lower():
|
||||
raise APIError(
|
||||
"github", 403, "GitHub API rate limit exceeded"
|
||||
)
|
||||
raise AuthenticationError(
|
||||
"github", "Insufficient permissions for this operation"
|
||||
)
|
||||
|
||||
if response.status_code == 404:
|
||||
return None
|
||||
|
||||
if response.status_code >= 400:
|
||||
error_msg = response.text
|
||||
try:
|
||||
error_data = response.json()
|
||||
error_msg = error_data.get("message", error_msg)
|
||||
except Exception:
|
||||
pass
|
||||
raise APIError("github", response.status_code, error_msg)
|
||||
|
||||
if response.status_code == 204:
|
||||
return None
|
||||
|
||||
return response.json()
|
||||
|
||||
except httpx.RequestError as e:
|
||||
raise APIError("github", 0, f"Request failed: {e}")
|
||||
|
||||
async def is_connected(self) -> bool:
|
||||
"""Check if connected to GitHub."""
|
||||
if not self.token:
|
||||
return False
|
||||
|
||||
try:
|
||||
result = await self._request("GET", "/user")
|
||||
return result is not None
|
||||
except Exception:
|
||||
return False
|
||||
|
||||
async def get_authenticated_user(self) -> str | None:
|
||||
"""Get the authenticated user's username."""
|
||||
if self._user:
|
||||
return self._user
|
||||
|
||||
try:
|
||||
result = await self._request("GET", "/user")
|
||||
if result:
|
||||
self._user = result.get("login")
|
||||
return self._user
|
||||
except Exception:
|
||||
pass
|
||||
return None
|
||||
|
||||
# Repository operations
|
||||
|
||||
async def get_repo_info(self, owner: str, repo: str) -> dict[str, Any]:
|
||||
"""Get repository information."""
|
||||
result = await self._request("GET", f"/repos/{owner}/{repo}")
|
||||
if result is None:
|
||||
raise APIError("github", 404, f"Repository not found: {owner}/{repo}")
|
||||
return result
|
||||
|
||||
async def get_default_branch(self, owner: str, repo: str) -> str:
|
||||
"""Get the default branch for a repository."""
|
||||
repo_info = await self.get_repo_info(owner, repo)
|
||||
return repo_info.get("default_branch", "main")
|
||||
|
||||
# Pull Request operations
|
||||
|
||||
async def create_pr(
|
||||
self,
|
||||
owner: str,
|
||||
repo: str,
|
||||
title: str,
|
||||
body: str,
|
||||
source_branch: str,
|
||||
target_branch: str,
|
||||
draft: bool = False,
|
||||
labels: list[str] | None = None,
|
||||
assignees: list[str] | None = None,
|
||||
reviewers: list[str] | None = None,
|
||||
) -> CreatePRResult:
|
||||
"""Create a pull request."""
|
||||
try:
|
||||
data: dict[str, Any] = {
|
||||
"title": title,
|
||||
"body": body,
|
||||
"head": source_branch,
|
||||
"base": target_branch,
|
||||
"draft": draft,
|
||||
}
|
||||
|
||||
result = await self._request(
|
||||
"POST",
|
||||
f"/repos/{owner}/{repo}/pulls",
|
||||
json=data,
|
||||
)
|
||||
|
||||
if result is None:
|
||||
return CreatePRResult(
|
||||
success=False,
|
||||
error="Failed to create pull request",
|
||||
)
|
||||
|
||||
pr_number = result["number"]
|
||||
|
||||
# Add labels if specified
|
||||
if labels:
|
||||
await self.add_labels(owner, repo, pr_number, labels)
|
||||
|
||||
# Add assignees if specified
|
||||
if assignees:
|
||||
await self._request(
|
||||
"POST",
|
||||
f"/repos/{owner}/{repo}/issues/{pr_number}/assignees",
|
||||
json={"assignees": assignees},
|
||||
)
|
||||
|
||||
# Request reviewers if specified
|
||||
if reviewers:
|
||||
await self.request_review(owner, repo, pr_number, reviewers)
|
||||
|
||||
return CreatePRResult(
|
||||
success=True,
|
||||
pr_number=pr_number,
|
||||
pr_url=result.get("html_url"),
|
||||
)
|
||||
|
||||
except APIError as e:
|
||||
return CreatePRResult(
|
||||
success=False,
|
||||
error=str(e),
|
||||
)
|
||||
|
||||
async def get_pr(self, owner: str, repo: str, pr_number: int) -> GetPRResult:
|
||||
"""Get a pull request by number."""
|
||||
try:
|
||||
result = await self._request(
|
||||
"GET",
|
||||
f"/repos/{owner}/{repo}/pulls/{pr_number}",
|
||||
)
|
||||
|
||||
if result is None:
|
||||
raise PRNotFoundError(pr_number, f"{owner}/{repo}")
|
||||
|
||||
pr_info = self._parse_pr(result)
|
||||
|
||||
return GetPRResult(
|
||||
success=True,
|
||||
pr=pr_info.to_dict(),
|
||||
)
|
||||
|
||||
except PRNotFoundError:
|
||||
return GetPRResult(
|
||||
success=False,
|
||||
error=f"Pull request #{pr_number} not found",
|
||||
)
|
||||
except APIError as e:
|
||||
return GetPRResult(
|
||||
success=False,
|
||||
error=str(e),
|
||||
)
|
||||
|
||||
async def list_prs(
|
||||
self,
|
||||
owner: str,
|
||||
repo: str,
|
||||
state: PRState | None = None,
|
||||
author: str | None = None,
|
||||
limit: int = 20,
|
||||
) -> ListPRsResult:
|
||||
"""List pull requests."""
|
||||
try:
|
||||
params: dict[str, Any] = {
|
||||
"per_page": min(limit, 100), # GitHub max is 100
|
||||
}
|
||||
|
||||
if state:
|
||||
# GitHub uses 'state' for open/closed only
|
||||
# Merged PRs are closed PRs with merged_at set
|
||||
if state == PRState.OPEN:
|
||||
params["state"] = "open"
|
||||
elif state in (PRState.CLOSED, PRState.MERGED):
|
||||
params["state"] = "closed"
|
||||
else:
|
||||
params["state"] = "all"
|
||||
|
||||
result = await self._request(
|
||||
"GET",
|
||||
f"/repos/{owner}/{repo}/pulls",
|
||||
params=params,
|
||||
)
|
||||
|
||||
if result is None:
|
||||
return ListPRsResult(
|
||||
success=True,
|
||||
pull_requests=[],
|
||||
total_count=0,
|
||||
)
|
||||
|
||||
prs = []
|
||||
for pr_data in result:
|
||||
# Filter by author if specified
|
||||
if author:
|
||||
pr_author = pr_data.get("user", {}).get("login", "")
|
||||
if pr_author.lower() != author.lower():
|
||||
continue
|
||||
|
||||
# Filter merged PRs if looking specifically for merged
|
||||
if state == PRState.MERGED:
|
||||
if not pr_data.get("merged_at"):
|
||||
continue
|
||||
|
||||
pr_info = self._parse_pr(pr_data)
|
||||
prs.append(pr_info.to_dict())
|
||||
|
||||
return ListPRsResult(
|
||||
success=True,
|
||||
pull_requests=prs,
|
||||
total_count=len(prs),
|
||||
)
|
||||
|
||||
except APIError as e:
|
||||
return ListPRsResult(
|
||||
success=False,
|
||||
error=str(e),
|
||||
)
|
||||
|
||||
async def merge_pr(
|
||||
self,
|
||||
owner: str,
|
||||
repo: str,
|
||||
pr_number: int,
|
||||
merge_strategy: MergeStrategy = MergeStrategy.MERGE,
|
||||
commit_message: str | None = None,
|
||||
delete_branch: bool = True,
|
||||
) -> MergePRResult:
|
||||
"""Merge a pull request."""
|
||||
try:
|
||||
# Map merge strategy to GitHub's merge_method values
|
||||
method_map = {
|
||||
MergeStrategy.MERGE: "merge",
|
||||
MergeStrategy.SQUASH: "squash",
|
||||
MergeStrategy.REBASE: "rebase",
|
||||
}
|
||||
|
||||
data: dict[str, Any] = {
|
||||
"merge_method": method_map[merge_strategy],
|
||||
}
|
||||
|
||||
if commit_message:
|
||||
# For squash, commit_title and commit_message
|
||||
# For merge, commit_title and commit_message
|
||||
parts = commit_message.split("\n", 1)
|
||||
data["commit_title"] = parts[0]
|
||||
if len(parts) > 1:
|
||||
data["commit_message"] = parts[1]
|
||||
|
||||
result = await self._request(
|
||||
"PUT",
|
||||
f"/repos/{owner}/{repo}/pulls/{pr_number}/merge",
|
||||
json=data,
|
||||
)
|
||||
|
||||
if result is None:
|
||||
return MergePRResult(
|
||||
success=False,
|
||||
error="Failed to merge pull request",
|
||||
)
|
||||
|
||||
branch_deleted = False
|
||||
# Delete branch if requested
|
||||
if delete_branch and result.get("merged"):
|
||||
# Get PR to find the branch name
|
||||
pr_result = await self.get_pr(owner, repo, pr_number)
|
||||
if pr_result.success and pr_result.pr:
|
||||
source_branch = pr_result.pr.get("source_branch")
|
||||
if source_branch:
|
||||
branch_deleted = await self.delete_remote_branch(
|
||||
owner, repo, source_branch
|
||||
)
|
||||
|
||||
return MergePRResult(
|
||||
success=True,
|
||||
merge_commit_sha=result.get("sha"),
|
||||
branch_deleted=branch_deleted,
|
||||
)
|
||||
|
||||
except APIError as e:
|
||||
return MergePRResult(
|
||||
success=False,
|
||||
error=str(e),
|
||||
)
|
||||
|
||||
async def update_pr(
|
||||
self,
|
||||
owner: str,
|
||||
repo: str,
|
||||
pr_number: int,
|
||||
title: str | None = None,
|
||||
body: str | None = None,
|
||||
state: PRState | None = None,
|
||||
labels: list[str] | None = None,
|
||||
assignees: list[str] | None = None,
|
||||
) -> UpdatePRResult:
|
||||
"""Update a pull request."""
|
||||
try:
|
||||
data: dict[str, Any] = {}
|
||||
|
||||
if title is not None:
|
||||
data["title"] = title
|
||||
if body is not None:
|
||||
data["body"] = body
|
||||
if state is not None:
|
||||
if state == PRState.OPEN:
|
||||
data["state"] = "open"
|
||||
elif state == PRState.CLOSED:
|
||||
data["state"] = "closed"
|
||||
|
||||
# Update PR if there's data
|
||||
if data:
|
||||
await self._request(
|
||||
"PATCH",
|
||||
f"/repos/{owner}/{repo}/pulls/{pr_number}",
|
||||
json=data,
|
||||
)
|
||||
|
||||
# Update labels via issue endpoint
|
||||
if labels is not None:
|
||||
await self._request(
|
||||
"PUT",
|
||||
f"/repos/{owner}/{repo}/issues/{pr_number}/labels",
|
||||
json={"labels": labels},
|
||||
)
|
||||
|
||||
# Update assignees via issue endpoint
|
||||
if assignees is not None:
|
||||
# First remove all assignees
|
||||
await self._request(
|
||||
"DELETE",
|
||||
f"/repos/{owner}/{repo}/issues/{pr_number}/assignees",
|
||||
json={"assignees": []},
|
||||
)
|
||||
# Then add new ones
|
||||
if assignees:
|
||||
await self._request(
|
||||
"POST",
|
||||
f"/repos/{owner}/{repo}/issues/{pr_number}/assignees",
|
||||
json={"assignees": assignees},
|
||||
)
|
||||
|
||||
# Fetch updated PR
|
||||
result = await self.get_pr(owner, repo, pr_number)
|
||||
return UpdatePRResult(
|
||||
success=result.success,
|
||||
pr=result.pr,
|
||||
error=result.error,
|
||||
)
|
||||
|
||||
except APIError as e:
|
||||
return UpdatePRResult(
|
||||
success=False,
|
||||
error=str(e),
|
||||
)
|
||||
|
||||
async def close_pr(
|
||||
self,
|
||||
owner: str,
|
||||
repo: str,
|
||||
pr_number: int,
|
||||
) -> UpdatePRResult:
|
||||
"""Close a pull request without merging."""
|
||||
return await self.update_pr(
|
||||
owner,
|
||||
repo,
|
||||
pr_number,
|
||||
state=PRState.CLOSED,
|
||||
)
|
||||
|
||||
# Branch operations
|
||||
|
||||
async def delete_remote_branch(
|
||||
self,
|
||||
owner: str,
|
||||
repo: str,
|
||||
branch: str,
|
||||
) -> bool:
|
||||
"""Delete a remote branch."""
|
||||
try:
|
||||
await self._request(
|
||||
"DELETE",
|
||||
f"/repos/{owner}/{repo}/git/refs/heads/{branch}",
|
||||
)
|
||||
return True
|
||||
except APIError:
|
||||
return False
|
||||
|
||||
async def get_branch(
|
||||
self,
|
||||
owner: str,
|
||||
repo: str,
|
||||
branch: str,
|
||||
) -> dict[str, Any] | None:
|
||||
"""Get branch information."""
|
||||
return await self._request(
|
||||
"GET",
|
||||
f"/repos/{owner}/{repo}/branches/{branch}",
|
||||
)
|
||||
|
||||
# Comment operations
|
||||
|
||||
async def add_pr_comment(
|
||||
self,
|
||||
owner: str,
|
||||
repo: str,
|
||||
pr_number: int,
|
||||
body: str,
|
||||
) -> dict[str, Any]:
|
||||
"""Add a comment to a pull request."""
|
||||
result = await self._request(
|
||||
"POST",
|
||||
f"/repos/{owner}/{repo}/issues/{pr_number}/comments",
|
||||
json={"body": body},
|
||||
)
|
||||
return result or {}
|
||||
|
||||
async def list_pr_comments(
|
||||
self,
|
||||
owner: str,
|
||||
repo: str,
|
||||
pr_number: int,
|
||||
) -> list[dict[str, Any]]:
|
||||
"""List comments on a pull request."""
|
||||
result = await self._request(
|
||||
"GET",
|
||||
f"/repos/{owner}/{repo}/issues/{pr_number}/comments",
|
||||
)
|
||||
return result or []
|
||||
|
||||
# Label operations
|
||||
|
||||
async def add_labels(
|
||||
self,
|
||||
owner: str,
|
||||
repo: str,
|
||||
pr_number: int,
|
||||
labels: list[str],
|
||||
) -> list[str]:
|
||||
"""Add labels to a pull request."""
|
||||
# GitHub creates labels automatically if they don't exist (unlike Gitea)
|
||||
result = await self._request(
|
||||
"POST",
|
||||
f"/repos/{owner}/{repo}/issues/{pr_number}/labels",
|
||||
json={"labels": labels},
|
||||
)
|
||||
|
||||
if result:
|
||||
return [lbl["name"] for lbl in result]
|
||||
return labels
|
||||
|
||||
async def remove_label(
|
||||
self,
|
||||
owner: str,
|
||||
repo: str,
|
||||
pr_number: int,
|
||||
label: str,
|
||||
) -> list[str]:
|
||||
"""Remove a label from a pull request."""
|
||||
await self._request(
|
||||
"DELETE",
|
||||
f"/repos/{owner}/{repo}/issues/{pr_number}/labels/{label}",
|
||||
)
|
||||
|
||||
# Return remaining labels
|
||||
issue = await self._request(
|
||||
"GET",
|
||||
f"/repos/{owner}/{repo}/issues/{pr_number}",
|
||||
)
|
||||
if issue:
|
||||
return [lbl["name"] for lbl in issue.get("labels", [])]
|
||||
return []
|
||||
|
||||
# Reviewer operations
|
||||
|
||||
async def request_review(
|
||||
self,
|
||||
owner: str,
|
||||
repo: str,
|
||||
pr_number: int,
|
||||
reviewers: list[str],
|
||||
) -> list[str]:
|
||||
"""Request review from users."""
|
||||
await self._request(
|
||||
"POST",
|
||||
f"/repos/{owner}/{repo}/pulls/{pr_number}/requested_reviewers",
|
||||
json={"reviewers": reviewers},
|
||||
)
|
||||
return reviewers
|
||||
|
||||
# Helper methods
|
||||
|
||||
def _parse_pr(self, data: dict[str, Any]) -> PRInfo:
|
||||
"""Parse PR API response into PRInfo."""
|
||||
# Parse dates
|
||||
created_at = self._parse_datetime(data.get("created_at"))
|
||||
updated_at = self._parse_datetime(data.get("updated_at"))
|
||||
merged_at = self._parse_datetime(data.get("merged_at"))
|
||||
closed_at = self._parse_datetime(data.get("closed_at"))
|
||||
|
||||
# Determine state
|
||||
if data.get("merged_at"):
|
||||
state = PRState.MERGED
|
||||
elif data.get("state") == "closed":
|
||||
state = PRState.CLOSED
|
||||
else:
|
||||
state = PRState.OPEN
|
||||
|
||||
# Extract labels
|
||||
labels = [lbl["name"] for lbl in data.get("labels", [])]
|
||||
|
||||
# Extract assignees
|
||||
assignees = [a["login"] for a in data.get("assignees", [])]
|
||||
|
||||
# Extract reviewers
|
||||
reviewers = []
|
||||
if "requested_reviewers" in data:
|
||||
reviewers = [r["login"] for r in data["requested_reviewers"]]
|
||||
|
||||
return PRInfo(
|
||||
number=data["number"],
|
||||
title=data["title"],
|
||||
body=data.get("body", "") or "",
|
||||
state=state,
|
||||
source_branch=data.get("head", {}).get("ref", ""),
|
||||
target_branch=data.get("base", {}).get("ref", ""),
|
||||
author=data.get("user", {}).get("login", ""),
|
||||
created_at=created_at,
|
||||
updated_at=updated_at,
|
||||
merged_at=merged_at,
|
||||
closed_at=closed_at,
|
||||
url=data.get("html_url"),
|
||||
labels=labels,
|
||||
assignees=assignees,
|
||||
reviewers=reviewers,
|
||||
mergeable=data.get("mergeable"),
|
||||
draft=data.get("draft", False),
|
||||
)
|
||||
|
||||
def _parse_datetime(self, value: str | None) -> datetime:
|
||||
"""Parse datetime string from API."""
|
||||
if not value:
|
||||
return datetime.now(UTC)
|
||||
|
||||
try:
|
||||
# GitHub uses ISO 8601 format with Z suffix
|
||||
if value.endswith("Z"):
|
||||
value = value[:-1] + "+00:00"
|
||||
return datetime.fromisoformat(value)
|
||||
except ValueError:
|
||||
return datetime.now(UTC)
|
||||
@@ -22,7 +22,7 @@ from config import Settings, get_settings
|
||||
from exceptions import ErrorCode, GitOpsError
|
||||
from git_wrapper import GitWrapper
|
||||
from models import MergeStrategy, PRState
|
||||
from providers import GiteaProvider
|
||||
from providers import BaseProvider, GiteaProvider, GitHubProvider
|
||||
from workspace import WorkspaceManager
|
||||
|
||||
# Input validation patterns
|
||||
@@ -77,25 +77,57 @@ logger = logging.getLogger(__name__)
|
||||
_settings: Settings | None = None
|
||||
_workspace_manager: WorkspaceManager | None = None
|
||||
_gitea_provider: GiteaProvider | None = None
|
||||
_github_provider: GitHubProvider | None = None
|
||||
|
||||
|
||||
def _get_provider_for_url(repo_url: str) -> GiteaProvider | None:
|
||||
def _get_provider_for_url(repo_url: str) -> BaseProvider | None:
|
||||
"""Get the appropriate provider for a repository URL."""
|
||||
if not _settings:
|
||||
return None
|
||||
|
||||
# Normalize URL for matching
|
||||
url_lower = repo_url.lower()
|
||||
|
||||
# Check for GitHub URLs
|
||||
if "github.com" in url_lower or (
|
||||
_settings.github_api_url
|
||||
and _settings.github_api_url != "https://api.github.com"
|
||||
and _settings.github_api_url.replace("https://", "").replace("/api/v3", "") in url_lower
|
||||
):
|
||||
return _github_provider
|
||||
|
||||
# Check if it's a Gitea URL
|
||||
if _settings.gitea_base_url and _settings.gitea_base_url in repo_url:
|
||||
if _settings.gitea_base_url and _settings.gitea_base_url.lower() in url_lower:
|
||||
return _gitea_provider
|
||||
|
||||
# Default to Gitea for now
|
||||
# Default: try to detect from URL pattern
|
||||
# If URL contains 'github' anywhere, use GitHub
|
||||
if "github" in url_lower:
|
||||
return _github_provider
|
||||
|
||||
# Default to Gitea for self-hosted instances
|
||||
return _gitea_provider
|
||||
|
||||
|
||||
def _get_auth_token_for_url(repo_url: str) -> str | None:
|
||||
"""Get the appropriate auth token for a repository URL."""
|
||||
if not _settings:
|
||||
return None
|
||||
|
||||
url_lower = repo_url.lower()
|
||||
|
||||
# GitHub token
|
||||
if "github.com" in url_lower or "github" in url_lower:
|
||||
return _settings.github_token if _settings.github_token else None
|
||||
|
||||
# Gitea token (default)
|
||||
return _settings.gitea_token if _settings.gitea_token else None
|
||||
|
||||
|
||||
@asynccontextmanager
|
||||
async def lifespan(_app: FastAPI) -> AsyncIterator[None]:
|
||||
"""Application lifespan handler."""
|
||||
global _settings, _workspace_manager, _gitea_provider
|
||||
global _settings, _workspace_manager, _gitea_provider, _github_provider
|
||||
|
||||
logger.info("Starting Git Operations MCP Server...")
|
||||
|
||||
@@ -114,6 +146,13 @@ async def lifespan(_app: FastAPI) -> AsyncIterator[None]:
|
||||
)
|
||||
logger.info(f"Gitea provider initialized: {_settings.gitea_base_url}")
|
||||
|
||||
if _settings.github_token:
|
||||
_github_provider = GitHubProvider(
|
||||
token=_settings.github_token,
|
||||
settings=_settings,
|
||||
)
|
||||
logger.info("GitHub provider initialized")
|
||||
|
||||
logger.info("Git Operations MCP Server started successfully")
|
||||
|
||||
yield
|
||||
@@ -124,6 +163,9 @@ async def lifespan(_app: FastAPI) -> AsyncIterator[None]:
|
||||
if _gitea_provider:
|
||||
await _gitea_provider.close()
|
||||
|
||||
if _github_provider:
|
||||
await _github_provider.close()
|
||||
|
||||
logger.info("Git Operations MCP Server shut down")
|
||||
|
||||
|
||||
@@ -167,6 +209,21 @@ async def health_check() -> dict[str, Any]:
|
||||
else:
|
||||
status["dependencies"]["gitea"] = "not configured"
|
||||
|
||||
# Check GitHub connectivity
|
||||
if _github_provider:
|
||||
try:
|
||||
if await _github_provider.is_connected():
|
||||
user = await _github_provider.get_authenticated_user()
|
||||
status["dependencies"]["github"] = f"connected as {user}"
|
||||
else:
|
||||
status["dependencies"]["github"] = "not connected"
|
||||
is_degraded = True
|
||||
except Exception as e:
|
||||
status["dependencies"]["github"] = f"error: {e}"
|
||||
is_degraded = True
|
||||
else:
|
||||
status["dependencies"]["github"] = "not configured"
|
||||
|
||||
# Check workspace directory
|
||||
if _workspace_manager:
|
||||
try:
|
||||
@@ -442,10 +499,8 @@ async def clone_repository(
|
||||
# Create workspace
|
||||
workspace = await _workspace_manager.create_workspace(project_id, repo_url) # type: ignore[union-attr]
|
||||
|
||||
# Get auth token from provider
|
||||
auth_token = None
|
||||
if _settings and _settings.gitea_token:
|
||||
auth_token = _settings.gitea_token
|
||||
# Get auth token appropriate for the repository URL
|
||||
auth_token = _get_auth_token_for_url(repo_url)
|
||||
|
||||
# Clone repository
|
||||
git = GitWrapper(workspace.path, _settings)
|
||||
@@ -715,10 +770,8 @@ async def push(
|
||||
if not workspace:
|
||||
return {"success": False, "error": f"Workspace not found for project: {project_id}", "code": ErrorCode.WORKSPACE_NOT_FOUND.value}
|
||||
|
||||
# Get auth token
|
||||
auth_token = None
|
||||
if _settings and _settings.gitea_token:
|
||||
auth_token = _settings.gitea_token
|
||||
# Get auth token appropriate for the repository URL
|
||||
auth_token = _get_auth_token_for_url(workspace.repo_url or "")
|
||||
|
||||
git = GitWrapper(workspace.path, _settings)
|
||||
result = await git.push(
|
||||
|
||||
583
mcp-servers/git-ops/tests/test_github_provider.py
Normal file
583
mcp-servers/git-ops/tests/test_github_provider.py
Normal file
@@ -0,0 +1,583 @@
|
||||
"""
|
||||
Tests for GitHub provider implementation.
|
||||
"""
|
||||
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
import pytest
|
||||
|
||||
from exceptions import APIError, AuthenticationError
|
||||
from models import MergeStrategy, PRState
|
||||
from providers.github import GitHubProvider
|
||||
|
||||
|
||||
class TestGitHubProviderBasics:
|
||||
"""Tests for GitHubProvider basic operations."""
|
||||
|
||||
def test_provider_name(self):
|
||||
"""Test provider name is github."""
|
||||
provider = GitHubProvider(token="test-token")
|
||||
assert provider.name == "github"
|
||||
|
||||
def test_parse_repo_url_https(self):
|
||||
"""Test parsing HTTPS repo URL."""
|
||||
provider = GitHubProvider(token="test-token")
|
||||
|
||||
owner, repo = provider.parse_repo_url("https://github.com/owner/repo.git")
|
||||
|
||||
assert owner == "owner"
|
||||
assert repo == "repo"
|
||||
|
||||
def test_parse_repo_url_https_no_git(self):
|
||||
"""Test parsing HTTPS URL without .git suffix."""
|
||||
provider = GitHubProvider(token="test-token")
|
||||
|
||||
owner, repo = provider.parse_repo_url("https://github.com/owner/repo")
|
||||
|
||||
assert owner == "owner"
|
||||
assert repo == "repo"
|
||||
|
||||
def test_parse_repo_url_ssh(self):
|
||||
"""Test parsing SSH repo URL."""
|
||||
provider = GitHubProvider(token="test-token")
|
||||
|
||||
owner, repo = provider.parse_repo_url("git@github.com:owner/repo.git")
|
||||
|
||||
assert owner == "owner"
|
||||
assert repo == "repo"
|
||||
|
||||
def test_parse_repo_url_invalid(self):
|
||||
"""Test error on invalid URL."""
|
||||
provider = GitHubProvider(token="test-token")
|
||||
|
||||
with pytest.raises(ValueError, match="Unable to parse"):
|
||||
provider.parse_repo_url("invalid-url")
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_github_httpx_client():
|
||||
"""Create a mock httpx client for GitHub provider tests."""
|
||||
from unittest.mock import AsyncMock
|
||||
|
||||
mock_response = AsyncMock()
|
||||
mock_response.status_code = 200
|
||||
mock_response.json = MagicMock(return_value={})
|
||||
mock_response.text = ""
|
||||
|
||||
mock_client = AsyncMock()
|
||||
mock_client.request = AsyncMock(return_value=mock_response)
|
||||
mock_client.get = AsyncMock(return_value=mock_response)
|
||||
mock_client.post = AsyncMock(return_value=mock_response)
|
||||
mock_client.patch = AsyncMock(return_value=mock_response)
|
||||
mock_client.put = AsyncMock(return_value=mock_response)
|
||||
mock_client.delete = AsyncMock(return_value=mock_response)
|
||||
|
||||
return mock_client
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
async def github_provider(test_settings, mock_github_httpx_client):
|
||||
"""Create a GitHubProvider with mocked HTTP client."""
|
||||
provider = GitHubProvider(
|
||||
token=test_settings.github_token,
|
||||
settings=test_settings,
|
||||
)
|
||||
provider._client = mock_github_httpx_client
|
||||
|
||||
yield provider
|
||||
|
||||
await provider.close()
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def github_pr_data():
|
||||
"""Sample PR data from GitHub API."""
|
||||
return {
|
||||
"number": 42,
|
||||
"title": "Test PR",
|
||||
"body": "This is a test pull request",
|
||||
"state": "open",
|
||||
"head": {"ref": "feature-branch"},
|
||||
"base": {"ref": "main"},
|
||||
"user": {"login": "test-user"},
|
||||
"created_at": "2024-01-15T10:00:00Z",
|
||||
"updated_at": "2024-01-15T12:00:00Z",
|
||||
"merged_at": None,
|
||||
"closed_at": None,
|
||||
"html_url": "https://github.com/owner/repo/pull/42",
|
||||
"labels": [{"name": "enhancement"}],
|
||||
"assignees": [{"login": "assignee1"}],
|
||||
"requested_reviewers": [{"login": "reviewer1"}],
|
||||
"mergeable": True,
|
||||
"draft": False,
|
||||
}
|
||||
|
||||
|
||||
class TestGitHubProviderConnection:
|
||||
"""Tests for GitHub provider connection."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_is_connected(self, github_provider, mock_github_httpx_client):
|
||||
"""Test connection check."""
|
||||
mock_github_httpx_client.request.return_value.json = MagicMock(
|
||||
return_value={"login": "test-user"}
|
||||
)
|
||||
|
||||
result = await github_provider.is_connected()
|
||||
|
||||
assert result is True
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_is_connected_no_token(self, test_settings):
|
||||
"""Test connection fails without token."""
|
||||
provider = GitHubProvider(
|
||||
token="",
|
||||
settings=test_settings,
|
||||
)
|
||||
|
||||
result = await provider.is_connected()
|
||||
assert result is False
|
||||
|
||||
await provider.close()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_authenticated_user(self, github_provider, mock_github_httpx_client):
|
||||
"""Test getting authenticated user."""
|
||||
mock_github_httpx_client.request.return_value.json = MagicMock(
|
||||
return_value={"login": "test-user"}
|
||||
)
|
||||
|
||||
user = await github_provider.get_authenticated_user()
|
||||
|
||||
assert user == "test-user"
|
||||
|
||||
|
||||
class TestGitHubProviderRepoOperations:
|
||||
"""Tests for GitHub repository operations."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_repo_info(self, github_provider, mock_github_httpx_client):
|
||||
"""Test getting repository info."""
|
||||
mock_github_httpx_client.request.return_value.json = MagicMock(
|
||||
return_value={
|
||||
"name": "repo",
|
||||
"full_name": "owner/repo",
|
||||
"default_branch": "main",
|
||||
}
|
||||
)
|
||||
|
||||
result = await github_provider.get_repo_info("owner", "repo")
|
||||
|
||||
assert result["name"] == "repo"
|
||||
assert result["default_branch"] == "main"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_default_branch(self, github_provider, mock_github_httpx_client):
|
||||
"""Test getting default branch."""
|
||||
mock_github_httpx_client.request.return_value.json = MagicMock(
|
||||
return_value={"default_branch": "develop"}
|
||||
)
|
||||
|
||||
branch = await github_provider.get_default_branch("owner", "repo")
|
||||
|
||||
assert branch == "develop"
|
||||
|
||||
|
||||
class TestGitHubPROperations:
|
||||
"""Tests for GitHub PR operations."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_create_pr(self, github_provider, mock_github_httpx_client):
|
||||
"""Test creating a pull request."""
|
||||
mock_github_httpx_client.request.return_value.json = MagicMock(
|
||||
return_value={
|
||||
"number": 42,
|
||||
"html_url": "https://github.com/owner/repo/pull/42",
|
||||
}
|
||||
)
|
||||
|
||||
result = await github_provider.create_pr(
|
||||
owner="owner",
|
||||
repo="repo",
|
||||
title="Test PR",
|
||||
body="Test body",
|
||||
source_branch="feature",
|
||||
target_branch="main",
|
||||
)
|
||||
|
||||
assert result.success is True
|
||||
assert result.pr_number == 42
|
||||
assert result.pr_url == "https://github.com/owner/repo/pull/42"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_create_pr_with_draft(self, github_provider, mock_github_httpx_client):
|
||||
"""Test creating a draft PR."""
|
||||
mock_github_httpx_client.request.return_value.json = MagicMock(
|
||||
return_value={
|
||||
"number": 43,
|
||||
"html_url": "https://github.com/owner/repo/pull/43",
|
||||
}
|
||||
)
|
||||
|
||||
result = await github_provider.create_pr(
|
||||
owner="owner",
|
||||
repo="repo",
|
||||
title="Draft PR",
|
||||
body="Draft body",
|
||||
source_branch="feature",
|
||||
target_branch="main",
|
||||
draft=True,
|
||||
)
|
||||
|
||||
assert result.success is True
|
||||
assert result.pr_number == 43
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_create_pr_with_options(self, github_provider, mock_github_httpx_client):
|
||||
"""Test creating PR with labels, assignees, reviewers."""
|
||||
mock_responses = [
|
||||
{"number": 44, "html_url": "https://github.com/owner/repo/pull/44"}, # Create PR
|
||||
[{"name": "enhancement"}], # POST add labels
|
||||
{}, # POST add assignees
|
||||
{}, # POST request reviewers
|
||||
]
|
||||
mock_github_httpx_client.request.return_value.json = MagicMock(side_effect=mock_responses)
|
||||
|
||||
result = await github_provider.create_pr(
|
||||
owner="owner",
|
||||
repo="repo",
|
||||
title="Test PR",
|
||||
body="Test body",
|
||||
source_branch="feature",
|
||||
target_branch="main",
|
||||
labels=["enhancement"],
|
||||
assignees=["user1"],
|
||||
reviewers=["reviewer1"],
|
||||
)
|
||||
|
||||
assert result.success is True
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_pr(self, github_provider, mock_github_httpx_client, github_pr_data):
|
||||
"""Test getting a pull request."""
|
||||
mock_github_httpx_client.request.return_value.json = MagicMock(
|
||||
return_value=github_pr_data
|
||||
)
|
||||
|
||||
result = await github_provider.get_pr("owner", "repo", 42)
|
||||
|
||||
assert result.success is True
|
||||
assert result.pr["number"] == 42
|
||||
assert result.pr["title"] == "Test PR"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_pr_not_found(self, github_provider, mock_github_httpx_client):
|
||||
"""Test getting non-existent PR."""
|
||||
mock_github_httpx_client.request.return_value.status_code = 404
|
||||
mock_github_httpx_client.request.return_value.json = MagicMock(return_value=None)
|
||||
|
||||
result = await github_provider.get_pr("owner", "repo", 999)
|
||||
|
||||
assert result.success is False
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_list_prs(self, github_provider, mock_github_httpx_client, github_pr_data):
|
||||
"""Test listing pull requests."""
|
||||
mock_github_httpx_client.request.return_value.json = MagicMock(
|
||||
return_value=[github_pr_data, github_pr_data]
|
||||
)
|
||||
|
||||
result = await github_provider.list_prs("owner", "repo")
|
||||
|
||||
assert result.success is True
|
||||
assert len(result.pull_requests) == 2
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_list_prs_with_state_filter(self, github_provider, mock_github_httpx_client, github_pr_data):
|
||||
"""Test listing PRs with state filter."""
|
||||
mock_github_httpx_client.request.return_value.json = MagicMock(
|
||||
return_value=[github_pr_data]
|
||||
)
|
||||
|
||||
result = await github_provider.list_prs(
|
||||
"owner", "repo", state=PRState.OPEN
|
||||
)
|
||||
|
||||
assert result.success is True
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_merge_pr(self, github_provider, mock_github_httpx_client, github_pr_data):
|
||||
"""Test merging a pull request."""
|
||||
# Merge returns sha, then get_pr returns the PR data, then delete branch
|
||||
mock_responses = [
|
||||
{"sha": "merge-commit-sha", "merged": True}, # PUT merge
|
||||
github_pr_data, # GET PR for branch info
|
||||
None, # DELETE branch
|
||||
]
|
||||
mock_github_httpx_client.request.return_value.json = MagicMock(
|
||||
side_effect=mock_responses
|
||||
)
|
||||
|
||||
result = await github_provider.merge_pr(
|
||||
"owner", "repo", 42,
|
||||
merge_strategy=MergeStrategy.SQUASH,
|
||||
)
|
||||
|
||||
assert result.success is True
|
||||
assert result.merge_commit_sha == "merge-commit-sha"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_merge_pr_rebase(self, github_provider, mock_github_httpx_client, github_pr_data):
|
||||
"""Test merging with rebase strategy."""
|
||||
mock_responses = [
|
||||
{"sha": "rebase-commit-sha", "merged": True}, # PUT merge
|
||||
github_pr_data, # GET PR for branch info
|
||||
None, # DELETE branch
|
||||
]
|
||||
mock_github_httpx_client.request.return_value.json = MagicMock(
|
||||
side_effect=mock_responses
|
||||
)
|
||||
|
||||
result = await github_provider.merge_pr(
|
||||
"owner", "repo", 42,
|
||||
merge_strategy=MergeStrategy.REBASE,
|
||||
)
|
||||
|
||||
assert result.success is True
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_update_pr(self, github_provider, mock_github_httpx_client, github_pr_data):
|
||||
"""Test updating a pull request."""
|
||||
mock_github_httpx_client.request.return_value.json = MagicMock(
|
||||
return_value=github_pr_data
|
||||
)
|
||||
|
||||
result = await github_provider.update_pr(
|
||||
"owner", "repo", 42,
|
||||
title="Updated Title",
|
||||
body="Updated body",
|
||||
)
|
||||
|
||||
assert result.success is True
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_close_pr(self, github_provider, mock_github_httpx_client, github_pr_data):
|
||||
"""Test closing a pull request."""
|
||||
github_pr_data["state"] = "closed"
|
||||
mock_github_httpx_client.request.return_value.json = MagicMock(
|
||||
return_value=github_pr_data
|
||||
)
|
||||
|
||||
result = await github_provider.close_pr("owner", "repo", 42)
|
||||
|
||||
assert result.success is True
|
||||
|
||||
|
||||
class TestGitHubBranchOperations:
|
||||
"""Tests for GitHub branch operations."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_branch(self, github_provider, mock_github_httpx_client):
|
||||
"""Test getting branch info."""
|
||||
mock_github_httpx_client.request.return_value.json = MagicMock(
|
||||
return_value={
|
||||
"name": "main",
|
||||
"commit": {"sha": "abc123"},
|
||||
}
|
||||
)
|
||||
|
||||
result = await github_provider.get_branch("owner", "repo", "main")
|
||||
|
||||
assert result["name"] == "main"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_delete_remote_branch(self, github_provider, mock_github_httpx_client):
|
||||
"""Test deleting a remote branch."""
|
||||
mock_github_httpx_client.request.return_value.status_code = 204
|
||||
|
||||
result = await github_provider.delete_remote_branch("owner", "repo", "old-branch")
|
||||
|
||||
assert result is True
|
||||
|
||||
|
||||
class TestGitHubCommentOperations:
|
||||
"""Tests for GitHub comment operations."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_add_pr_comment(self, github_provider, mock_github_httpx_client):
|
||||
"""Test adding a comment to a PR."""
|
||||
mock_github_httpx_client.request.return_value.json = MagicMock(
|
||||
return_value={"id": 1, "body": "Test comment"}
|
||||
)
|
||||
|
||||
result = await github_provider.add_pr_comment(
|
||||
"owner", "repo", 42, "Test comment"
|
||||
)
|
||||
|
||||
assert result["body"] == "Test comment"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_list_pr_comments(self, github_provider, mock_github_httpx_client):
|
||||
"""Test listing PR comments."""
|
||||
mock_github_httpx_client.request.return_value.json = MagicMock(
|
||||
return_value=[
|
||||
{"id": 1, "body": "Comment 1"},
|
||||
{"id": 2, "body": "Comment 2"},
|
||||
]
|
||||
)
|
||||
|
||||
result = await github_provider.list_pr_comments("owner", "repo", 42)
|
||||
|
||||
assert len(result) == 2
|
||||
|
||||
|
||||
class TestGitHubLabelOperations:
|
||||
"""Tests for GitHub label operations."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_add_labels(self, github_provider, mock_github_httpx_client):
|
||||
"""Test adding labels to a PR."""
|
||||
mock_github_httpx_client.request.return_value.json = MagicMock(
|
||||
return_value=[{"name": "bug"}, {"name": "urgent"}]
|
||||
)
|
||||
|
||||
result = await github_provider.add_labels(
|
||||
"owner", "repo", 42, ["bug", "urgent"]
|
||||
)
|
||||
|
||||
assert "bug" in result
|
||||
assert "urgent" in result
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_remove_label(self, github_provider, mock_github_httpx_client):
|
||||
"""Test removing a label from a PR."""
|
||||
mock_responses = [
|
||||
None, # DELETE label
|
||||
{"labels": []}, # GET issue
|
||||
]
|
||||
mock_github_httpx_client.request.return_value.json = MagicMock(side_effect=mock_responses)
|
||||
|
||||
result = await github_provider.remove_label(
|
||||
"owner", "repo", 42, "bug"
|
||||
)
|
||||
|
||||
assert isinstance(result, list)
|
||||
|
||||
|
||||
class TestGitHubReviewerOperations:
|
||||
"""Tests for GitHub reviewer operations."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_request_review(self, github_provider, mock_github_httpx_client):
|
||||
"""Test requesting review from users."""
|
||||
mock_github_httpx_client.request.return_value.json = MagicMock(return_value={})
|
||||
|
||||
result = await github_provider.request_review(
|
||||
"owner", "repo", 42, ["reviewer1", "reviewer2"]
|
||||
)
|
||||
|
||||
assert result == ["reviewer1", "reviewer2"]
|
||||
|
||||
|
||||
class TestGitHubErrorHandling:
|
||||
"""Tests for error handling in GitHub provider."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_authentication_error(self, github_provider, mock_github_httpx_client):
|
||||
"""Test handling authentication errors."""
|
||||
mock_github_httpx_client.request.return_value.status_code = 401
|
||||
|
||||
with pytest.raises(AuthenticationError):
|
||||
await github_provider._request("GET", "/user")
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_permission_denied(self, github_provider, mock_github_httpx_client):
|
||||
"""Test handling permission denied errors."""
|
||||
mock_github_httpx_client.request.return_value.status_code = 403
|
||||
mock_github_httpx_client.request.return_value.text = "Permission denied"
|
||||
|
||||
with pytest.raises(AuthenticationError, match="Insufficient permissions"):
|
||||
await github_provider._request("GET", "/protected")
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_rate_limit_error(self, github_provider, mock_github_httpx_client):
|
||||
"""Test handling rate limit errors."""
|
||||
mock_github_httpx_client.request.return_value.status_code = 403
|
||||
mock_github_httpx_client.request.return_value.text = "API rate limit exceeded"
|
||||
|
||||
with pytest.raises(APIError, match="rate limit"):
|
||||
await github_provider._request("GET", "/user")
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_api_error(self, github_provider, mock_github_httpx_client):
|
||||
"""Test handling general API errors."""
|
||||
mock_github_httpx_client.request.return_value.status_code = 500
|
||||
mock_github_httpx_client.request.return_value.text = "Internal Server Error"
|
||||
mock_github_httpx_client.request.return_value.json = MagicMock(
|
||||
return_value={"message": "Server error"}
|
||||
)
|
||||
|
||||
with pytest.raises(APIError):
|
||||
await github_provider._request("GET", "/error")
|
||||
|
||||
|
||||
class TestGitHubPRParsing:
|
||||
"""Tests for PR data parsing."""
|
||||
|
||||
def test_parse_pr_open(self, github_provider, github_pr_data):
|
||||
"""Test parsing open PR."""
|
||||
pr_info = github_provider._parse_pr(github_pr_data)
|
||||
|
||||
assert pr_info.number == 42
|
||||
assert pr_info.state == PRState.OPEN
|
||||
assert pr_info.title == "Test PR"
|
||||
assert pr_info.source_branch == "feature-branch"
|
||||
assert pr_info.target_branch == "main"
|
||||
|
||||
def test_parse_pr_merged(self, github_provider, github_pr_data):
|
||||
"""Test parsing merged PR."""
|
||||
github_pr_data["merged_at"] = "2024-01-16T10:00:00Z"
|
||||
|
||||
pr_info = github_provider._parse_pr(github_pr_data)
|
||||
|
||||
assert pr_info.state == PRState.MERGED
|
||||
|
||||
def test_parse_pr_closed(self, github_provider, github_pr_data):
|
||||
"""Test parsing closed PR."""
|
||||
github_pr_data["state"] = "closed"
|
||||
github_pr_data["closed_at"] = "2024-01-16T10:00:00Z"
|
||||
|
||||
pr_info = github_provider._parse_pr(github_pr_data)
|
||||
|
||||
assert pr_info.state == PRState.CLOSED
|
||||
|
||||
def test_parse_pr_draft(self, github_provider, github_pr_data):
|
||||
"""Test parsing draft PR."""
|
||||
github_pr_data["draft"] = True
|
||||
|
||||
pr_info = github_provider._parse_pr(github_pr_data)
|
||||
|
||||
assert pr_info.draft is True
|
||||
|
||||
def test_parse_datetime_iso(self, github_provider):
|
||||
"""Test parsing ISO datetime strings."""
|
||||
dt = github_provider._parse_datetime("2024-01-15T10:30:00Z")
|
||||
|
||||
assert dt.year == 2024
|
||||
assert dt.month == 1
|
||||
assert dt.day == 15
|
||||
|
||||
def test_parse_datetime_none(self, github_provider):
|
||||
"""Test parsing None datetime returns now."""
|
||||
dt = github_provider._parse_datetime(None)
|
||||
|
||||
assert dt is not None
|
||||
assert dt.tzinfo is not None
|
||||
|
||||
def test_parse_pr_with_null_body(self, github_provider, github_pr_data):
|
||||
"""Test parsing PR with null body."""
|
||||
github_pr_data["body"] = None
|
||||
|
||||
pr_info = github_provider._parse_pr(github_pr_data)
|
||||
|
||||
assert pr_info.body == ""
|
||||
Reference in New Issue
Block a user