forked from cardosofelipe/fast-next-template
Implements GitHub API provider following the same pattern as Gitea: - Full PR operations (create, get, list, merge, update, close) - Branch operations via API - Comment and label management - Reviewer request support - Rate limit error handling Server enhancements: - Auto-detect provider from repository URL (github.com vs custom Gitea) - Initialize GitHub provider when token is configured - Health check includes both provider statuses - Token selection based on repo URL for clone/push operations Refs: #110 🤖 Generated with [Claude Code](https://claude.com/claude-code) Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
678 lines
20 KiB
Python
678 lines
20 KiB
Python
"""
|
|
GitHub provider implementation.
|
|
|
|
Implements the BaseProvider interface for GitHub API operations.
|
|
"""
|
|
|
|
import logging
from collections.abc import Awaitable
from datetime import UTC, datetime
from typing import Any
from urllib.parse import quote

import httpx

from config import Settings, get_settings
from exceptions import (
    APIError,
    AuthenticationError,
    PRNotFoundError,
)
from models import (
    CreatePRResult,
    GetPRResult,
    ListPRsResult,
    MergePRResult,
    MergeStrategy,
    PRInfo,
    PRState,
    UpdatePRResult,
)

from .base import BaseProvider
|
|
|
|
# Module-level logger, named after this module per the logging convention.
logger = logging.getLogger(__name__)
|
|
|
|
|
|
class GitHubProvider(BaseProvider):
    """
    GitHub API provider implementation.

    Supports all PR operations, branch operations, and repository queries
    against the GitHub REST API (https://api.github.com).
    """

    def __init__(
        self,
        token: str | None = None,
        settings: Settings | None = None,
    ) -> None:
        """
        Initialize GitHub provider.

        The HTTP client is created lazily on the first request.

        Args:
            token: GitHub personal access token or fine-grained token.
                Falls back to ``settings.github_token`` when not given.
            settings: Optional settings override; defaults to ``get_settings()``.
        """
        self.settings = settings or get_settings()
        self.token = token or self.settings.github_token
        self._client: httpx.AsyncClient | None = None
        # Cached login of the authenticated user (see get_authenticated_user).
        self._user: str | None = None

    @property
    def name(self) -> str:
        """Return the provider name."""
        return "github"

    async def _get_client(self) -> httpx.AsyncClient:
        """Get or lazily create the shared HTTP client."""
        if self._client is None:
            headers = {
                "Accept": "application/vnd.github+json",
                "X-GitHub-Api-Version": "2022-11-28",
            }
            if self.token:
                headers["Authorization"] = f"Bearer {self.token}"

            self._client = httpx.AsyncClient(
                base_url="https://api.github.com",
                headers=headers,
                timeout=30.0,
            )
        return self._client

    async def close(self) -> None:
        """Close the HTTP client; a new one is created on the next request."""
        if self._client:
            await self._client.aclose()
            self._client = None

    @staticmethod
    def _is_rate_limited(response: httpx.Response) -> bool:
        """Return True if a 403 response is caused by (primary or secondary) rate limiting."""
        if response.headers.get("x-ratelimit-remaining") == "0":
            return True
        return "rate limit" in response.text.lower()

    @staticmethod
    def _error_message(response: httpx.Response) -> str:
        """Extract a readable error message from an error response body."""
        try:
            data = response.json()
        except ValueError:
            # Body is not valid JSON (or not decodable); use the raw text.
            return response.text
        if not isinstance(data, dict):
            return response.text

        message = str(data.get("message", response.text))

        # Validation failures (422) put the useful detail in "errors";
        # the top-level message is just "Validation Failed".
        details: list[str] = []
        errors = data.get("errors")
        if isinstance(errors, list):
            for err in errors:
                detail = (
                    (err.get("message") or err.get("code"))
                    if isinstance(err, dict)
                    else err
                )
                if detail:
                    details.append(str(detail))
        if details:
            message = f"{message}: {'; '.join(details)}"
        return message

    async def _request(
        self,
        method: str,
        path: str,
        **kwargs: Any,
    ) -> Any:
        """
        Make an API request.

        Args:
            method: HTTP method
            path: API path, relative to the GitHub API base URL
            **kwargs: Additional arguments forwarded to ``httpx.AsyncClient.request``

        Returns:
            Parsed JSON response, or None for 404 responses and for
            responses without a body (e.g. 204 No Content)

        Raises:
            APIError: On transport failures (status 0), rate limiting,
                and any other status >= 400
            AuthenticationError: On 401, and on 403 responses that are
                not caused by rate limiting
        """
        client = await self._get_client()

        try:
            response = await client.request(method, path, **kwargs)
        except httpx.RequestError as e:
            raise APIError("github", 0, f"Request failed: {e}") from e

        status = response.status_code

        if status == 401:
            raise AuthenticationError("github", "Invalid or expired token")

        # GitHub reports primary and secondary rate limits as 403 or 429.
        if status == 429 or (status == 403 and self._is_rate_limited(response)):
            raise APIError("github", status, "GitHub API rate limit exceeded")

        if status == 403:
            raise AuthenticationError(
                "github", "Insufficient permissions for this operation"
            )

        if status == 404:
            return None

        if status >= 400:
            raise APIError("github", status, self._error_message(response))

        # 204 No Content (and any other empty success body) has no JSON.
        if status == 204 or not response.content:
            return None

        return response.json()

    async def _best_effort(self, action: str, operation: Awaitable[Any]) -> None:
        """Await a follow-up API call, logging instead of raising GitHub errors."""
        try:
            await operation
        except (APIError, AuthenticationError) as e:
            logger.warning("GitHub: failed to %s: %s", action, e)

    async def is_connected(self) -> bool:
        """Check if connected to GitHub (a token is set and /user answers)."""
        if not self.token:
            return False

        try:
            result = await self._request("GET", "/user")
            return result is not None
        except Exception:
            # Deliberate best-effort health check: any failure means "not connected".
            logger.debug("GitHub connectivity check failed", exc_info=True)
            return False

    async def get_authenticated_user(self) -> str | None:
        """Get the authenticated user's login, cached after the first success."""
        if self._user:
            return self._user

        try:
            result = await self._request("GET", "/user")
            if result:
                self._user = result.get("login")
                return self._user
        except Exception:
            # Best-effort lookup; callers treat None as "unknown user".
            logger.debug("Could not fetch authenticated GitHub user", exc_info=True)
        return None

    # Repository operations

    async def get_repo_info(self, owner: str, repo: str) -> dict[str, Any]:
        """
        Get repository information.

        Raises:
            APIError: If the repository does not exist (404) or the request fails.
        """
        result = await self._request("GET", f"/repos/{owner}/{repo}")
        if result is None:
            raise APIError("github", 404, f"Repository not found: {owner}/{repo}")
        return result

    async def get_default_branch(self, owner: str, repo: str) -> str:
        """Get the default branch for a repository (``"main"`` if unreported)."""
        repo_info = await self.get_repo_info(owner, repo)
        return repo_info.get("default_branch", "main")

    # Pull Request operations

    async def create_pr(
        self,
        owner: str,
        repo: str,
        title: str,
        body: str,
        source_branch: str,
        target_branch: str,
        draft: bool = False,
        labels: list[str] | None = None,
        assignees: list[str] | None = None,
        reviewers: list[str] | None = None,
    ) -> CreatePRResult:
        """
        Create a pull request.

        Labels, assignees and reviewers are applied after creation on a
        best-effort basis. A failure there is logged and does not flip the
        result to failure, because the PR already exists.

        Returns:
            CreatePRResult with the new PR number and URL on success, or
            ``success=False`` and an error message if creation failed.
        """
        data: dict[str, Any] = {
            "title": title,
            "body": body,
            "head": source_branch,
            "base": target_branch,
            "draft": draft,
        }

        try:
            result = await self._request(
                "POST",
                f"/repos/{owner}/{repo}/pulls",
                json=data,
            )
        except APIError as e:
            return CreatePRResult(
                success=False,
                error=str(e),
            )

        if result is None:
            return CreatePRResult(
                success=False,
                error="Failed to create pull request",
            )

        pr_number = result["number"]

        # The PR exists from here on. Reporting failure now would make callers
        # retry and hit "a pull request already exists", so extras are best-effort.
        if labels:
            await self._best_effort(
                f"add labels to PR #{pr_number}",
                self.add_labels(owner, repo, pr_number, labels),
            )

        if assignees:
            await self._best_effort(
                f"add assignees to PR #{pr_number}",
                self._request(
                    "POST",
                    f"/repos/{owner}/{repo}/issues/{pr_number}/assignees",
                    json={"assignees": assignees},
                ),
            )

        if reviewers:
            await self._best_effort(
                f"request reviewers for PR #{pr_number}",
                self.request_review(owner, repo, pr_number, reviewers),
            )

        return CreatePRResult(
            success=True,
            pr_number=pr_number,
            pr_url=result.get("html_url"),
        )

    async def get_pr(self, owner: str, repo: str, pr_number: int) -> GetPRResult:
        """Get a pull request by number."""
        try:
            result = await self._request(
                "GET",
                f"/repos/{owner}/{repo}/pulls/{pr_number}",
            )

            if result is None:
                raise PRNotFoundError(pr_number, f"{owner}/{repo}")

            pr_info = self._parse_pr(result)

            return GetPRResult(
                success=True,
                pr=pr_info.to_dict(),
            )

        except PRNotFoundError:
            return GetPRResult(
                success=False,
                error=f"Pull request #{pr_number} not found",
            )
        except APIError as e:
            return GetPRResult(
                success=False,
                error=str(e),
            )

    async def list_prs(
        self,
        owner: str,
        repo: str,
        state: PRState | None = None,
        author: str | None = None,
        limit: int = 20,
    ) -> ListPRsResult:
        """
        List pull requests (a single page of at most ``min(limit, 100)``).

        The author and merged-state filters are applied client-side after the
        page is fetched, so fewer than ``limit`` results may be returned.
        """
        try:
            params: dict[str, Any] = {
                "per_page": min(limit, 100),  # GitHub max is 100
            }

            if state:
                # GitHub's 'state' only knows open/closed; merged PRs are
                # closed PRs with merged_at set and are filtered below.
                if state == PRState.OPEN:
                    params["state"] = "open"
                elif state in (PRState.CLOSED, PRState.MERGED):
                    params["state"] = "closed"
                else:
                    params["state"] = "all"

            result = await self._request(
                "GET",
                f"/repos/{owner}/{repo}/pulls",
                params=params,
            )

            if result is None:
                return ListPRsResult(
                    success=True,
                    pull_requests=[],
                    total_count=0,
                )

            wanted_author = author.lower() if author else None
            prs = []
            for pr_data in result:
                if wanted_author is not None:
                    # "user" may be null for deleted accounts.
                    pr_author = (pr_data.get("user") or {}).get("login") or ""
                    if pr_author.lower() != wanted_author:
                        continue

                if state == PRState.MERGED and not pr_data.get("merged_at"):
                    continue

                prs.append(self._parse_pr(pr_data).to_dict())

            return ListPRsResult(
                success=True,
                pull_requests=prs,
                total_count=len(prs),
            )

        except APIError as e:
            return ListPRsResult(
                success=False,
                error=str(e),
            )

    async def merge_pr(
        self,
        owner: str,
        repo: str,
        pr_number: int,
        merge_strategy: MergeStrategy = MergeStrategy.MERGE,
        commit_message: str | None = None,
        delete_branch: bool = True,
    ) -> MergePRResult:
        """
        Merge a pull request.

        Args:
            merge_strategy: Merge, squash or rebase.
            commit_message: Optional message. The first line becomes the commit
                title and the rest (leading blank lines removed) the body.
            delete_branch: Delete the head branch after a successful merge.
                Only branches in ``owner/repo`` are deleted; branches of
                fork PRs are never touched.

        Returns:
            MergePRResult. ``success`` reflects the merge itself; a failed
            branch deletion only leaves ``branch_deleted`` False.
        """
        # Map merge strategy to GitHub's merge_method values
        method_map = {
            MergeStrategy.MERGE: "merge",
            MergeStrategy.SQUASH: "squash",
            MergeStrategy.REBASE: "rebase",
        }

        data: dict[str, Any] = {
            "merge_method": method_map[merge_strategy],
        }

        if commit_message:
            title, _, message = commit_message.partition("\n")
            data["commit_title"] = title.rstrip("\r")
            # GitHub joins title and body itself; drop the conventional blank
            # separator line so it is not doubled.
            message = message.lstrip("\r\n")
            if message:
                data["commit_message"] = message

        try:
            result = await self._request(
                "PUT",
                f"/repos/{owner}/{repo}/pulls/{pr_number}/merge",
                json=data,
            )
        except APIError as e:
            return MergePRResult(
                success=False,
                error=str(e),
            )

        if result is None:
            return MergePRResult(
                success=False,
                error="Failed to merge pull request",
            )

        branch_deleted = False
        if delete_branch and result.get("merged"):
            branch_deleted = await self._delete_merged_head_branch(
                owner, repo, pr_number
            )

        return MergePRResult(
            success=True,
            merge_commit_sha=result.get("sha"),
            branch_deleted=branch_deleted,
        )

    async def _delete_merged_head_branch(
        self,
        owner: str,
        repo: str,
        pr_number: int,
    ) -> bool:
        """Delete a merged PR's head branch if it lives in ``owner/repo``; never raises API errors."""
        try:
            pr_data = await self._request(
                "GET",
                f"/repos/{owner}/{repo}/pulls/{pr_number}",
            )
        except (APIError, AuthenticationError) as e:
            logger.warning(
                "GitHub: could not fetch PR #%s to delete its branch: %s", pr_number, e
            )
            return False
        if not pr_data:
            return False

        head = pr_data.get("head") or {}
        head_repo = head.get("repo") or {}
        branch = head.get("ref")
        # Fork PRs keep their head branch in another repository (head.repo is
        # null once the fork is deleted). Deleting a same-named branch here
        # would remove an unrelated branch, so skip those.
        same_repo = (head_repo.get("full_name") or "").lower() == f"{owner}/{repo}".lower()
        if not branch or not same_repo:
            return False

        try:
            return await self.delete_remote_branch(owner, repo, branch)
        except AuthenticationError as e:
            logger.warning("GitHub: could not delete branch %r: %s", branch, e)
            return False

    async def update_pr(
        self,
        owner: str,
        repo: str,
        pr_number: int,
        title: str | None = None,
        body: str | None = None,
        state: PRState | None = None,
        labels: list[str] | None = None,
        assignees: list[str] | None = None,
    ) -> UpdatePRResult:
        """
        Update a pull request.

        ``labels`` and ``assignees``, when given, replace the current sets;
        an empty list clears them. ``PRState.MERGED`` is ignored here; use
        ``merge_pr`` instead.

        Returns:
            UpdatePRResult carrying the freshly fetched PR.
        """
        pr_fields: dict[str, Any] = {}
        if title is not None:
            pr_fields["title"] = title
        if body is not None:
            pr_fields["body"] = body
        if state == PRState.OPEN:
            pr_fields["state"] = "open"
        elif state == PRState.CLOSED:
            pr_fields["state"] = "closed"

        try:
            if pr_fields:
                await self._request(
                    "PATCH",
                    f"/repos/{owner}/{repo}/pulls/{pr_number}",
                    json=pr_fields,
                )

            # Labels and assignees belong to the issue that backs every PR.
            if labels is not None:
                # PUT replaces the whole label set.
                await self._request(
                    "PUT",
                    f"/repos/{owner}/{repo}/issues/{pr_number}/labels",
                    json={"labels": labels},
                )

            if assignees is not None:
                # PATCHing the issue replaces the whole assignee set ([] clears it).
                # DELETE .../assignees only removes the logins it is given, so it
                # cannot clear the current assignees.
                await self._request(
                    "PATCH",
                    f"/repos/{owner}/{repo}/issues/{pr_number}",
                    json={"assignees": assignees},
                )

            result = await self.get_pr(owner, repo, pr_number)
            return UpdatePRResult(
                success=result.success,
                pr=result.pr,
                error=result.error,
            )

        except APIError as e:
            return UpdatePRResult(
                success=False,
                error=str(e),
            )

    async def close_pr(
        self,
        owner: str,
        repo: str,
        pr_number: int,
    ) -> UpdatePRResult:
        """Close a pull request without merging."""
        return await self.update_pr(
            owner,
            repo,
            pr_number,
            state=PRState.CLOSED,
        )

    # Branch operations

    async def delete_remote_branch(
        self,
        owner: str,
        repo: str,
        branch: str,
    ) -> bool:
        """
        Delete a remote branch.

        Returns:
            False if the API reported an error (e.g. the ref does not exist),
            True otherwise.
        """
        # Keep "/" (part of the ref path) but escape characters such as "#",
        # "?" and "%" that would otherwise change the URL.
        ref = quote(branch, safe="/")
        try:
            await self._request(
                "DELETE",
                f"/repos/{owner}/{repo}/git/refs/heads/{ref}",
            )
            return True
        except APIError:
            return False

    async def get_branch(
        self,
        owner: str,
        repo: str,
        branch: str,
    ) -> dict[str, Any] | None:
        """Get branch information, or None if the branch does not exist."""
        return await self._request(
            "GET",
            f"/repos/{owner}/{repo}/branches/{quote(branch, safe='/')}",
        )

    # Comment operations

    async def add_pr_comment(
        self,
        owner: str,
        repo: str,
        pr_number: int,
        body: str,
    ) -> dict[str, Any]:
        """Add a comment to a pull request; returns {} if the PR was not found."""
        result = await self._request(
            "POST",
            f"/repos/{owner}/{repo}/issues/{pr_number}/comments",
            json={"body": body},
        )
        return result or {}

    async def list_pr_comments(
        self,
        owner: str,
        repo: str,
        pr_number: int,
    ) -> list[dict[str, Any]]:
        """List conversation comments on a pull request (first 100 only, single page)."""
        result = await self._request(
            "GET",
            f"/repos/{owner}/{repo}/issues/{pr_number}/comments",
            params={"per_page": 100},
        )
        return result or []

    # Label operations

    async def add_labels(
        self,
        owner: str,
        repo: str,
        pr_number: int,
        labels: list[str],
    ) -> list[str]:
        """Add labels to a pull request and return the PR's resulting label names."""
        # GitHub creates labels automatically if they don't exist (unlike Gitea)
        result = await self._request(
            "POST",
            f"/repos/{owner}/{repo}/issues/{pr_number}/labels",
            json={"labels": labels},
        )

        if result:
            return [lbl["name"] for lbl in result]
        return labels

    async def remove_label(
        self,
        owner: str,
        repo: str,
        pr_number: int,
        label: str,
    ) -> list[str]:
        """Remove a label from a pull request and return the remaining label names."""
        # Label names may contain spaces, "/", "#", "?", so encode everything.
        encoded = quote(label, safe="")
        remaining = await self._request(
            "DELETE",
            f"/repos/{owner}/{repo}/issues/{pr_number}/labels/{encoded}",
        )
        # On success GitHub answers with the remaining labels.
        if isinstance(remaining, list):
            return [lbl["name"] for lbl in remaining]

        # 404 (label not on the PR) yields None; read the current labels instead.
        issue = await self._request(
            "GET",
            f"/repos/{owner}/{repo}/issues/{pr_number}",
        )
        if issue:
            return [lbl["name"] for lbl in issue.get("labels") or []]
        return []

    # Reviewer operations

    async def request_review(
        self,
        owner: str,
        repo: str,
        pr_number: int,
        reviewers: list[str],
    ) -> list[str]:
        """Request review from users; returns the requested logins."""
        await self._request(
            "POST",
            f"/repos/{owner}/{repo}/pulls/{pr_number}/requested_reviewers",
            json={"reviewers": reviewers},
        )
        return reviewers

    # Helper methods

    def _parse_pr(self, data: dict[str, Any]) -> PRInfo:
        """Parse a GitHub pull request JSON object into a PRInfo."""
        # NOTE(review): _parse_datetime maps a missing value to "now", so open
        # PRs report merged_at/closed_at as the current time. Passing None would
        # be more accurate; confirm PRInfo and to_dict() accept None first.
        created_at = self._parse_datetime(data.get("created_at"))
        updated_at = self._parse_datetime(data.get("updated_at"))
        merged_at = self._parse_datetime(data.get("merged_at"))
        closed_at = self._parse_datetime(data.get("closed_at"))

        # GitHub reports merged PRs as "closed"; merged_at tells them apart.
        if data.get("merged_at"):
            state = PRState.MERGED
        elif data.get("state") == "closed":
            state = PRState.CLOSED
        else:
            state = PRState.OPEN

        # Nested objects/lists can be null (e.g. deleted users), so coalesce.
        head = data.get("head") or {}
        base = data.get("base") or {}
        user = data.get("user") or {}
        labels = [lbl["name"] for lbl in data.get("labels") or []]
        assignees = [a["login"] for a in data.get("assignees") or []]
        reviewers = [r["login"] for r in data.get("requested_reviewers") or []]

        return PRInfo(
            number=data["number"],
            title=data["title"],
            body=data.get("body") or "",
            state=state,
            source_branch=head.get("ref", ""),
            target_branch=base.get("ref", ""),
            author=user.get("login", ""),
            created_at=created_at,
            updated_at=updated_at,
            merged_at=merged_at,
            closed_at=closed_at,
            url=data.get("html_url"),
            labels=labels,
            assignees=assignees,
            reviewers=reviewers,
            mergeable=data.get("mergeable"),
            draft=data.get("draft", False),
        )

    def _parse_datetime(self, value: str | None) -> datetime:
        """
        Parse an ISO 8601 timestamp from the API into an aware datetime.

        Missing or unparseable values fall back to the current UTC time.
        """
        if not value:
            return datetime.now(UTC)

        # GitHub uses ISO 8601 with a "Z" suffix; normalize it for fromisoformat.
        if value.endswith("Z"):
            value = value[:-1] + "+00:00"
        try:
            parsed = datetime.fromisoformat(value)
        except ValueError:
            return datetime.now(UTC)

        # Treat offset-less timestamps as UTC so naive and aware values never mix.
        if parsed.tzinfo is None:
            parsed = parsed.replace(tzinfo=UTC)
        return parsed
|