"""
GitHub provider implementation.

Implements the BaseProvider interface for GitHub API operations.
"""

import logging
from datetime import UTC, datetime
from typing import Any
from urllib.parse import quote

import httpx

from config import Settings, get_settings
from exceptions import (
    APIError,
    AuthenticationError,
    PRNotFoundError,
)
from models import (
    CreatePRResult,
    GetPRResult,
    ListPRsResult,
    MergePRResult,
    MergeStrategy,
    PRInfo,
    PRState,
    UpdatePRResult,
)

from .base import BaseProvider

logger = logging.getLogger(__name__)


class GitHubProvider(BaseProvider):
    """
    GitHub API provider implementation.

    Supports all PR operations, branch operations, and repository queries
    against the GitHub REST API (``https://api.github.com`` by default, or
    another API root passed as ``api_url``).
    """

    # Upper bound on pages read by list_prs. Author/merged filtering happens
    # client-side, so without a cap a rare filter could walk every PR of a
    # large repository.
    _MAX_LIST_PAGES = 10

    def __init__(
        self,
        token: str | None = None,
        settings: Settings | None = None,
        api_url: str = "https://api.github.com",
    ) -> None:
        """
        Initialize GitHub provider.

        Args:
            token: GitHub personal access token or fine-grained token.
                Falls back to ``settings.github_token`` when omitted.
            settings: Optional settings override (defaults to get_settings()).
            api_url: REST API root used as the HTTP client's base URL, e.g.
                a GitHub Enterprise Server root such as
                ``https://github.example.com/api/v3``.
        """
        self.settings = settings or get_settings()
        self.token = token or self.settings.github_token
        self.api_url = api_url
        self._client: httpx.AsyncClient | None = None
        self._user: str | None = None

    @property
    def name(self) -> str:
        """Return the provider name."""
        return "github"

    async def _get_client(self) -> httpx.AsyncClient:
        """Get or lazily create the shared HTTP client."""
        if self._client is None:
            headers = {
                "Accept": "application/vnd.github+json",
                "X-GitHub-Api-Version": "2022-11-28",
            }
            if self.token:
                headers["Authorization"] = f"Bearer {self.token}"
            self._client = httpx.AsyncClient(
                base_url=self.api_url,
                headers=headers,
                timeout=30.0,
            )
        return self._client

    async def close(self) -> None:
        """Close the HTTP client, if one was created."""
        if self._client:
            await self._client.aclose()
            self._client = None

    async def _request(
        self,
        method: str,
        path: str,
        **kwargs: Any,
    ) -> Any:
        """
        Make an API request.

        Args:
            method: HTTP method
            path: API path (relative to ``api_url``)
            **kwargs: Additional arguments forwarded to ``httpx.AsyncClient.request``

        Returns:
            Parsed JSON response, or None for 404, 204, or an empty body.

        Raises:
            APIError: On transport failures (status 0), rate limiting
                (403/429), and any other status >= 400.
            AuthenticationError: On 401, and on 403 that is not rate limiting.
        """
        client = await self._get_client()
        try:
            response = await client.request(method, path, **kwargs)
        except httpx.RequestError as e:
            raise APIError("github", 0, f"Request failed: {e}") from e

        status = response.status_code
        if status == 401:
            raise AuthenticationError("github", "Invalid or expired token")
        if status == 429 or (status == 403 and self._is_rate_limited(response)):
            raise APIError("github", status, "GitHub API rate limit exceeded")
        if status == 403:
            raise AuthenticationError(
                "github", "Insufficient permissions for this operation"
            )
        if status == 404:
            return None
        if status >= 400:
            raise APIError("github", status, self._error_message(response))
        if status == 204 or not response.content:
            return None
        return response.json()

    @staticmethod
    def _is_rate_limited(response: httpx.Response) -> bool:
        """Return True if a 403 response signals rate limiting, not a permission error."""
        if response.headers.get("x-ratelimit-remaining") == "0":
            return True
        # Primary and secondary rate-limit messages both mention "rate limit".
        return "rate limit" in response.text.lower()

    @staticmethod
    def _error_message(response: httpx.Response) -> str:
        """
        Extract a human-readable message from an error response.

        Uses the JSON ``message`` field when present, appending any entries
        from the ``errors`` list (GitHub's per-field validation details).
        Falls back to the raw body text when the body is not a JSON object.
        """
        try:
            payload = response.json()
        except ValueError:  # includes json.JSONDecodeError
            return response.text
        if not isinstance(payload, dict) or not payload.get("message"):
            return response.text

        details: list[str] = []
        for err in payload.get("errors") or []:
            if isinstance(err, dict):
                detail = err.get("message") or err.get("code")
                if detail:
                    details.append(str(detail))
            elif isinstance(err, str):
                details.append(err)

        message = str(payload["message"])
        if details:
            return f"{message}: {'; '.join(details)}"
        return message

    async def is_connected(self) -> bool:
        """Return True if a token is configured and ``GET /user`` succeeds."""
        if not self.token:
            return False
        try:
            result = await self._request("GET", "/user")
        except Exception:
            # Best-effort probe: any failure simply means "not connected".
            logger.debug("GitHub connectivity check failed", exc_info=True)
            return False
        return result is not None

    async def get_authenticated_user(self) -> str | None:
        """Return the authenticated user's login (cached), or None if unavailable."""
        if self._user:
            return self._user
        try:
            result = await self._request("GET", "/user")
        except Exception:
            # Best-effort lookup; callers treat None as "unknown user".
            logger.debug("Could not resolve authenticated GitHub user", exc_info=True)
            return None
        if result:
            self._user = result.get("login")
        return self._user

    # Repository operations

    async def get_repo_info(self, owner: str, repo: str) -> dict[str, Any]:
        """
        Get repository information.

        Raises:
            APIError: With status 404 if the repository does not exist.
        """
        result = await self._request("GET", f"/repos/{owner}/{repo}")
        if result is None:
            raise APIError("github", 404, f"Repository not found: {owner}/{repo}")
        return result

    async def get_default_branch(self, owner: str, repo: str) -> str:
        """Get the default branch for a repository ("main" if not reported)."""
        repo_info = await self.get_repo_info(owner, repo)
        return repo_info.get("default_branch", "main")

    # Pull Request operations

    async def create_pr(
        self,
        owner: str,
        repo: str,
        title: str,
        body: str,
        source_branch: str,
        target_branch: str,
        draft: bool = False,
        labels: list[str] | None = None,
        assignees: list[str] | None = None,
        reviewers: list[str] | None = None,
    ) -> CreatePRResult:
        """
        Create a pull request, then apply labels, assignees and reviewers.

        Once GitHub accepts the PR, the result is ``success=True``. If a later
        labels, assignees or reviewers step fails, the remaining steps still
        run. The failures are logged and summarised in ``error`` instead of
        reporting the creation as failed. A failed result would invite a
        retry that opens a duplicate PR.
        """
        data: dict[str, Any] = {
            "title": title,
            "body": body,
            "head": source_branch,
            "base": target_branch,
            "draft": draft,
        }
        try:
            result = await self._request(
                "POST",
                f"/repos/{owner}/{repo}/pulls",
                json=data,
            )
        except APIError as e:
            return CreatePRResult(
                success=False,
                error=str(e),
            )

        if result is None:
            return CreatePRResult(
                success=False,
                error="Failed to create pull request",
            )

        pr_number = result["number"]
        pr_url = result.get("html_url")
        problems = await self._apply_new_pr_metadata(
            owner, repo, pr_number, labels, assignees, reviewers
        )
        if problems:
            summary = "Pull request created, but " + "; ".join(problems)
            logger.warning("GitHub PR %s/%s#%s: %s", owner, repo, pr_number, summary)
            return CreatePRResult(
                success=True,
                pr_number=pr_number,
                pr_url=pr_url,
                error=summary,
            )
        return CreatePRResult(
            success=True,
            pr_number=pr_number,
            pr_url=pr_url,
        )

    async def _apply_new_pr_metadata(
        self,
        owner: str,
        repo: str,
        pr_number: int,
        labels: list[str] | None,
        assignees: list[str] | None,
        reviewers: list[str] | None,
    ) -> list[str]:
        """Apply optional metadata to a freshly created PR; return descriptions of failed steps."""
        problems: list[str] = []
        if labels:
            try:
                await self.add_labels(owner, repo, pr_number, labels)
            except APIError as e:
                problems.append(f"adding labels failed: {e}")
        if assignees:
            try:
                await self._request(
                    "POST",
                    f"/repos/{owner}/{repo}/issues/{pr_number}/assignees",
                    json={"assignees": assignees},
                )
            except APIError as e:
                problems.append(f"adding assignees failed: {e}")
        if reviewers:
            try:
                await self.request_review(owner, repo, pr_number, reviewers)
            except APIError as e:
                problems.append(f"requesting reviewers failed: {e}")
        return problems

    async def get_pr(self, owner: str, repo: str, pr_number: int) -> GetPRResult:
        """Get a pull request by number (404 becomes an unsuccessful result)."""
        try:
            result = await self._request(
                "GET",
                f"/repos/{owner}/{repo}/pulls/{pr_number}",
            )
            if result is None:
                raise PRNotFoundError(pr_number, f"{owner}/{repo}")

            pr_info = self._parse_pr(result)
            return GetPRResult(
                success=True,
                pr=pr_info.to_dict(),
            )
        except PRNotFoundError:
            return GetPRResult(
                success=False,
                error=f"Pull request #{pr_number} not found",
            )
        except APIError as e:
            return GetPRResult(
                success=False,
                error=str(e),
            )

    async def list_prs(
        self,
        owner: str,
        repo: str,
        state: PRState | None = None,
        author: str | None = None,
        limit: int = 20,
    ) -> ListPRsResult:
        """
        List up to ``limit`` pull requests.

        GitHub's list endpoint only filters by open/closed, so author and
        "merged" filtering is done client-side. Pages are read until ``limit``
        matching PRs are collected, GitHub runs out of results, or
        ``_MAX_LIST_PAGES`` pages have been read. When ``state`` is None no
        state parameter is sent, so GitHub applies its own default.
        """
        if limit <= 0:
            return ListPRsResult(
                success=True,
                pull_requests=[],
                total_count=0,
            )

        base_params: dict[str, Any] = {}
        if state:
            # Merged PRs are closed PRs with merged_at set.
            if state == PRState.OPEN:
                base_params["state"] = "open"
            elif state in (PRState.CLOSED, PRState.MERGED):
                base_params["state"] = "closed"
            else:
                base_params["state"] = "all"

        # Client-side filters discard items, so request full pages when they apply.
        filtering = bool(author) or state == PRState.MERGED
        per_page = 100 if filtering else min(limit, 100)  # GitHub max is 100

        prs: list[dict[str, Any]] = []
        try:
            for page in range(1, self._MAX_LIST_PAGES + 1):
                result = await self._request(
                    "GET",
                    f"/repos/{owner}/{repo}/pulls",
                    params={**base_params, "per_page": per_page, "page": page},
                )
                if not result:
                    break

                for pr_data in result:
                    if not self._matches_list_filters(pr_data, state, author):
                        continue
                    prs.append(self._parse_pr(pr_data).to_dict())
                    if len(prs) >= limit:
                        break

                # A short page means there are no further results.
                if len(prs) >= limit or len(result) < per_page:
                    break
        except APIError as e:
            return ListPRsResult(
                success=False,
                error=str(e),
            )

        return ListPRsResult(
            success=True,
            pull_requests=prs,
            total_count=len(prs),
        )

    @staticmethod
    def _matches_list_filters(
        pr_data: dict[str, Any],
        state: PRState | None,
        author: str | None,
    ) -> bool:
        """Return True if a raw PR item passes list_prs' client-side filters."""
        if author:
            pr_author = (pr_data.get("user") or {}).get("login") or ""
            if pr_author.lower() != author.lower():
                return False
        if state == PRState.MERGED and not pr_data.get("merged_at"):
            return False
        return True

    async def merge_pr(
        self,
        owner: str,
        repo: str,
        pr_number: int,
        merge_strategy: MergeStrategy = MergeStrategy.MERGE,
        commit_message: str | None = None,
        delete_branch: bool = True,
    ) -> MergePRResult:
        """
        Merge a pull request.

        ``commit_message`` is split at the first newline: the first line is
        sent as ``commit_title`` and the rest, if any, as ``commit_message``.
        When ``delete_branch`` is set, the head branch is deleted only if it
        lives in this same repository. Failure to delete it never makes the
        merge result unsuccessful.
        """
        # Map merge strategy to GitHub's merge_method values
        method_map = {
            MergeStrategy.MERGE: "merge",
            MergeStrategy.SQUASH: "squash",
            MergeStrategy.REBASE: "rebase",
        }
        data: dict[str, Any] = {
            "merge_method": method_map[merge_strategy],
        }
        if commit_message:
            title, separator, message = commit_message.partition("\n")
            data["commit_title"] = title
            if separator:
                data["commit_message"] = message

        try:
            result = await self._request(
                "PUT",
                f"/repos/{owner}/{repo}/pulls/{pr_number}/merge",
                json=data,
            )
        except APIError as e:
            return MergePRResult(
                success=False,
                error=str(e),
            )

        if result is None:
            return MergePRResult(
                success=False,
                error="Failed to merge pull request",
            )

        branch_deleted = False
        if delete_branch and result.get("merged"):
            branch_deleted = await self._delete_merged_head_branch(
                owner, repo, pr_number
            )

        return MergePRResult(
            success=True,
            merge_commit_sha=result.get("sha"),
            branch_deleted=branch_deleted,
        )

    async def _delete_merged_head_branch(
        self,
        owner: str,
        repo: str,
        pr_number: int,
    ) -> bool:
        """
        Delete a merged PR's head branch if it belongs to ``owner/repo``.

        Returns True only when the branch was deleted. For PRs opened from a
        fork, ``head.ref`` names a branch in the fork. Deleting the same-named
        ref here could remove an unrelated branch of the base repository, so
        such PRs are skipped. Errors are logged rather than raised because the
        merge has already happened.
        """
        try:
            pr_data = await self._request(
                "GET",
                f"/repos/{owner}/{repo}/pulls/{pr_number}",
            )
        except (APIError, AuthenticationError):
            logger.warning(
                "Could not look up head branch of merged PR %s/%s#%s",
                owner,
                repo,
                pr_number,
                exc_info=True,
            )
            return False
        if not pr_data:
            return False

        head = pr_data.get("head") or {}
        branch = head.get("ref")
        head_repo = (head.get("repo") or {}).get("full_name") or ""
        if not branch or head_repo.lower() != f"{owner}/{repo}".lower():
            return False

        try:
            return await self.delete_remote_branch(owner, repo, branch)
        except AuthenticationError:
            logger.warning(
                "Not permitted to delete branch %s in %s/%s", branch, owner, repo
            )
            return False

    async def update_pr(
        self,
        owner: str,
        repo: str,
        pr_number: int,
        title: str | None = None,
        body: str | None = None,
        state: PRState | None = None,
        labels: list[str] | None = None,
        assignees: list[str] | None = None,
    ) -> UpdatePRResult:
        """
        Update a pull request and return its refreshed state.

        ``labels`` and ``assignees``, when given (not None), replace the
        current sets; pass an empty list to clear them. Only OPEN and CLOSED
        are applied for ``state``; other values are ignored.
        """
        data: dict[str, Any] = {}
        if title is not None:
            data["title"] = title
        if body is not None:
            data["body"] = body
        if state is not None:
            if state == PRState.OPEN:
                data["state"] = "open"
            elif state == PRState.CLOSED:
                data["state"] = "closed"

        try:
            if data:
                await self._request(
                    "PATCH",
                    f"/repos/{owner}/{repo}/pulls/{pr_number}",
                    json=data,
                )

            # Labels live on the issue backing the PR; PUT replaces the set.
            if labels is not None:
                await self._request(
                    "PUT",
                    f"/repos/{owner}/{repo}/issues/{pr_number}/labels",
                    json={"labels": labels},
                )

            # PATCHing the issue's "assignees" replaces the whole set (an
            # empty list clears it). DELETE .../assignees only removes the
            # logins it is given, so it cannot be used to clear.
            if assignees is not None:
                await self._request(
                    "PATCH",
                    f"/repos/{owner}/{repo}/issues/{pr_number}",
                    json={"assignees": assignees},
                )

            result = await self.get_pr(owner, repo, pr_number)
        except APIError as e:
            return UpdatePRResult(
                success=False,
                error=str(e),
            )

        return UpdatePRResult(
            success=result.success,
            pr=result.pr,
            error=result.error,
        )

    async def close_pr(
        self,
        owner: str,
        repo: str,
        pr_number: int,
    ) -> UpdatePRResult:
        """Close a pull request without merging."""
        return await self.update_pr(
            owner,
            repo,
            pr_number,
            state=PRState.CLOSED,
        )

    # Branch operations

    async def delete_remote_branch(
        self,
        owner: str,
        repo: str,
        branch: str,
    ) -> bool:
        """
        Delete a remote branch.

        Returns False if the API reports an error (APIError). Note that a 404
        yields None from _request and is therefore reported as True here.
        """
        try:
            # quote() keeps "/" (valid in ref names) but escapes "#", "?", "%"...
            await self._request(
                "DELETE",
                f"/repos/{owner}/{repo}/git/refs/heads/{quote(branch)}",
            )
        except APIError:
            return False
        return True

    async def get_branch(
        self,
        owner: str,
        repo: str,
        branch: str,
    ) -> dict[str, Any] | None:
        """Get branch information, or None if the branch does not exist."""
        return await self._request(
            "GET",
            f"/repos/{owner}/{repo}/branches/{quote(branch)}",
        )

    # Comment operations

    async def add_pr_comment(
        self,
        owner: str,
        repo: str,
        pr_number: int,
        body: str,
    ) -> dict[str, Any]:
        """Add a comment to a pull request (via its issue); {} if nothing returned."""
        result = await self._request(
            "POST",
            f"/repos/{owner}/{repo}/issues/{pr_number}/comments",
            json={"body": body},
        )
        return result or {}

    async def list_pr_comments(
        self,
        owner: str,
        repo: str,
        pr_number: int,
    ) -> list[dict[str, Any]]:
        """List issue comments on a pull request (first page only)."""
        result = await self._request(
            "GET",
            f"/repos/{owner}/{repo}/issues/{pr_number}/comments",
        )
        return result or []

    # Label operations

    async def add_labels(
        self,
        owner: str,
        repo: str,
        pr_number: int,
        labels: list[str],
    ) -> list[str]:
        """Add labels to a pull request; return the labels reported by GitHub."""
        # GitHub creates labels automatically if they don't exist (unlike Gitea)
        result = await self._request(
            "POST",
            f"/repos/{owner}/{repo}/issues/{pr_number}/labels",
            json={"labels": labels},
        )
        if result:
            return [lbl["name"] for lbl in result]
        return labels

    async def remove_label(
        self,
        owner: str,
        repo: str,
        pr_number: int,
        label: str,
    ) -> list[str]:
        """Remove a label from a pull request; return the remaining labels."""
        # Label names may contain spaces, "/" or "#", so encode every reserved char.
        await self._request(
            "DELETE",
            f"/repos/{owner}/{repo}/issues/{pr_number}/labels/{quote(label, safe='')}",
        )
        issue = await self._request(
            "GET",
            f"/repos/{owner}/{repo}/issues/{pr_number}",
        )
        if issue:
            return [lbl["name"] for lbl in issue.get("labels", [])]
        return []

    # Reviewer operations

    async def request_review(
        self,
        owner: str,
        repo: str,
        pr_number: int,
        reviewers: list[str],
    ) -> list[str]:
        """Request review from users; returns the requested logins as given."""
        await self._request(
            "POST",
            f"/repos/{owner}/{repo}/pulls/{pr_number}/requested_reviewers",
            json={"reviewers": reviewers},
        )
        return reviewers

    # Helper methods

    def _parse_pr(self, data: dict[str, Any]) -> PRInfo:
        """Parse a PR API response into PRInfo (tolerates null user/head/base)."""
        # NOTE(review): _parse_datetime falls back to "now" when a value is
        # missing, so unmerged/open PRs get merged_at/closed_at = now. Confirm
        # whether PRInfo accepts None for these before changing it.
        created_at = self._parse_datetime(data.get("created_at"))
        updated_at = self._parse_datetime(data.get("updated_at"))
        merged_at = self._parse_datetime(data.get("merged_at"))
        closed_at = self._parse_datetime(data.get("closed_at"))

        # GitHub reports merged PRs as state "closed" with merged_at set.
        if data.get("merged_at"):
            state = PRState.MERGED
        elif data.get("state") == "closed":
            state = PRState.CLOSED
        else:
            state = PRState.OPEN

        labels = [lbl["name"] for lbl in data.get("labels") or []]
        assignees = [a["login"] for a in data.get("assignees") or []]
        reviewers = [r["login"] for r in data.get("requested_reviewers") or []]

        return PRInfo(
            number=data["number"],
            title=data["title"],
            body=data.get("body", "") or "",
            state=state,
            source_branch=(data.get("head") or {}).get("ref", ""),
            target_branch=(data.get("base") or {}).get("ref", ""),
            author=(data.get("user") or {}).get("login", ""),
            created_at=created_at,
            updated_at=updated_at,
            merged_at=merged_at,
            closed_at=closed_at,
            url=data.get("html_url"),
            labels=labels,
            assignees=assignees,
            reviewers=reviewers,
            mergeable=data.get("mergeable"),
            draft=data.get("draft", False),
        )

    def _parse_datetime(self, value: str | None) -> datetime:
        """Parse an ISO 8601 API timestamp; return the current UTC time if missing/invalid."""
        if not value:
            return datetime.now(UTC)
        # GitHub uses a trailing "Z" for UTC; normalise it to an explicit offset.
        if value.endswith("Z"):
            value = value[:-1] + "+00:00"
        try:
            return datetime.fromisoformat(value)
        except ValueError:
            return datetime.now(UTC)