From 1779239c074bb3b0e14b879af5ce062a7072e7fa Mon Sep 17 00:00:00 2001 From: Felipe Cardoso Date: Tue, 6 Jan 2026 20:55:22 +0100 Subject: [PATCH] feat(git-ops): add GitHub provider with auto-detection MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Implements GitHub API provider following the same pattern as Gitea: - Full PR operations (create, get, list, merge, update, close) - Branch operations via API - Comment and label management - Reviewer request support - Rate limit error handling Server enhancements: - Auto-detect provider from repository URL (github.com vs custom Gitea) - Initialize GitHub provider when token is configured - Health check includes both provider statuses - Token selection based on repo URL for clone/push operations Refs: #110 🤖 Generated with [Claude Code](https://claude.com/claude-code) Co-Authored-By: Claude Opus 4.5 --- mcp-servers/git-ops/providers/__init__.py | 3 +- mcp-servers/git-ops/providers/github.py | 677 ++++++++++++++++++ mcp-servers/git-ops/server.py | 79 +- .../git-ops/tests/test_github_provider.py | 583 +++++++++++++++ 4 files changed, 1328 insertions(+), 14 deletions(-) create mode 100644 mcp-servers/git-ops/providers/github.py create mode 100644 mcp-servers/git-ops/tests/test_github_provider.py diff --git a/mcp-servers/git-ops/providers/__init__.py b/mcp-servers/git-ops/providers/__init__.py index c03b187..3f6cc06 100644 --- a/mcp-servers/git-ops/providers/__init__.py +++ b/mcp-servers/git-ops/providers/__init__.py @@ -6,5 +6,6 @@ Provides adapters for different git hosting platforms (Gitea, GitHub, GitLab). from .base import BaseProvider from .gitea import GiteaProvider +from .github import GitHubProvider -__all__ = ["BaseProvider", "GiteaProvider"] +__all__ = ["BaseProvider", "GiteaProvider", "GitHubProvider"] diff --git a/mcp-servers/git-ops/providers/github.py b/mcp-servers/git-ops/providers/github.py new file mode 100644 index 0000000..d1d45a7 --- /dev/null +++ b/mcp-servers/git-ops/providers/github.py @@ -0,0 +1,677 @@ +""" +GitHub provider implementation. + +Implements the BaseProvider interface for GitHub API operations. +""" + +import logging +from datetime import UTC, datetime +from typing import Any + +import httpx + +from config import Settings, get_settings +from exceptions import ( + APIError, + AuthenticationError, + PRNotFoundError, +) +from models import ( + CreatePRResult, + GetPRResult, + ListPRsResult, + MergePRResult, + MergeStrategy, + PRInfo, + PRState, + UpdatePRResult, +) + +from .base import BaseProvider + +logger = logging.getLogger(__name__) + + +class GitHubProvider(BaseProvider): + """ + GitHub API provider implementation. + + Supports all PR operations, branch operations, and repository queries. + """ + + def __init__( + self, + token: str | None = None, + settings: Settings | None = None, + ) -> None: + """ + Initialize GitHub provider. + + Args: + token: GitHub personal access token or fine-grained token + settings: Optional settings override + """ + self.settings = settings or get_settings() + self.token = token or self.settings.github_token + self._client: httpx.AsyncClient | None = None + self._user: str | None = None + + @property + def name(self) -> str: + """Return the provider name.""" + return "github" + + async def _get_client(self) -> httpx.AsyncClient: + """Get or create HTTP client.""" + if self._client is None: + headers = { + "Accept": "application/vnd.github+json", + "X-GitHub-Api-Version": "2022-11-28", + } + if self.token: + headers["Authorization"] = f"Bearer {self.token}" + + self._client = httpx.AsyncClient( + base_url="https://api.github.com", + headers=headers, + timeout=30.0, + ) + return self._client + + async def close(self) -> None: + """Close the HTTP client.""" + if self._client: + await self._client.aclose() + self._client = None + + async def _request( + self, + method: str, + path: str, + **kwargs: Any, + ) -> Any: + """ + Make an API request. + + Args: + method: HTTP method + path: API path + **kwargs: Additional request arguments + + Returns: + Parsed JSON response + + Raises: + APIError: On API errors + AuthenticationError: On auth failures + """ + client = await self._get_client() + + try: + response = await client.request(method, path, **kwargs) + + if response.status_code == 401: + raise AuthenticationError("github", "Invalid or expired token") + + if response.status_code == 403: + # Check for rate limiting + if "rate limit" in response.text.lower(): + raise APIError( + "github", 403, "GitHub API rate limit exceeded" + ) + raise AuthenticationError( + "github", "Insufficient permissions for this operation" + ) + + if response.status_code == 404: + return None + + if response.status_code >= 400: + error_msg = response.text + try: + error_data = response.json() + error_msg = error_data.get("message", error_msg) + except Exception: + pass + raise APIError("github", response.status_code, error_msg) + + if response.status_code == 204: + return None + + return response.json() + + except httpx.RequestError as e: + raise APIError("github", 0, f"Request failed: {e}") + + async def is_connected(self) -> bool: + """Check if connected to GitHub.""" + if not self.token: + return False + + try: + result = await self._request("GET", "/user") + return result is not None + except Exception: + return False + + async def get_authenticated_user(self) -> str | None: + """Get the authenticated user's username.""" + if self._user: + return self._user + + try: + result = await self._request("GET", "/user") + if result: + self._user = result.get("login") + return self._user + except Exception: + pass + return None + + # Repository operations + + async def get_repo_info(self, owner: str, repo: str) -> dict[str, Any]: + """Get repository information.""" + result = await self._request("GET", f"/repos/{owner}/{repo}") + if result is None: + raise APIError("github", 404, f"Repository not found: {owner}/{repo}") + return result + + async def get_default_branch(self, owner: str, repo: str) -> str: + """Get the default branch for a repository.""" + repo_info = await self.get_repo_info(owner, repo) + return repo_info.get("default_branch", "main") + + # Pull Request operations + + async def create_pr( + self, + owner: str, + repo: str, + title: str, + body: str, + source_branch: str, + target_branch: str, + draft: bool = False, + labels: list[str] | None = None, + assignees: list[str] | None = None, + reviewers: list[str] | None = None, + ) -> CreatePRResult: + """Create a pull request.""" + try: + data: dict[str, Any] = { + "title": title, + "body": body, + "head": source_branch, + "base": target_branch, + "draft": draft, + } + + result = await self._request( + "POST", + f"/repos/{owner}/{repo}/pulls", + json=data, + ) + + if result is None: + return CreatePRResult( + success=False, + error="Failed to create pull request", + ) + + pr_number = result["number"] + + # Add labels if specified + if labels: + await self.add_labels(owner, repo, pr_number, labels) + + # Add assignees if specified + if assignees: + await self._request( + "POST", + f"/repos/{owner}/{repo}/issues/{pr_number}/assignees", + json={"assignees": assignees}, + ) + + # Request reviewers if specified + if reviewers: + await self.request_review(owner, repo, pr_number, reviewers) + + return CreatePRResult( + success=True, + pr_number=pr_number, + pr_url=result.get("html_url"), + ) + + except APIError as e: + return CreatePRResult( + success=False, + error=str(e), + ) + + async def get_pr(self, owner: str, repo: str, pr_number: int) -> GetPRResult: + """Get a pull request by number.""" + try: + result = await self._request( + "GET", + f"/repos/{owner}/{repo}/pulls/{pr_number}", + ) + + if result is None: + raise PRNotFoundError(pr_number, f"{owner}/{repo}") + + pr_info = self._parse_pr(result) + + return GetPRResult( + success=True, + pr=pr_info.to_dict(), + ) + + except PRNotFoundError: + return GetPRResult( + success=False, + error=f"Pull request #{pr_number} not found", + ) + except APIError as e: + return GetPRResult( + success=False, + error=str(e), + ) + + async def list_prs( + self, + owner: str, + repo: str, + state: PRState | None = None, + author: str | None = None, + limit: int = 20, + ) -> ListPRsResult: + """List pull requests.""" + try: + params: dict[str, Any] = { + "per_page": min(limit, 100), # GitHub max is 100 + } + + if state: + # GitHub uses 'state' for open/closed only + # Merged PRs are closed PRs with merged_at set + if state == PRState.OPEN: + params["state"] = "open" + elif state in (PRState.CLOSED, PRState.MERGED): + params["state"] = "closed" + else: + params["state"] = "all" + + result = await self._request( + "GET", + f"/repos/{owner}/{repo}/pulls", + params=params, + ) + + if result is None: + return ListPRsResult( + success=True, + pull_requests=[], + total_count=0, + ) + + prs = [] + for pr_data in result: + # Filter by author if specified + if author: + pr_author = pr_data.get("user", {}).get("login", "") + if pr_author.lower() != author.lower(): + continue + + # Filter merged PRs if looking specifically for merged + if state == PRState.MERGED: + if not pr_data.get("merged_at"): + continue + + pr_info = self._parse_pr(pr_data) + prs.append(pr_info.to_dict()) + + return ListPRsResult( + success=True, + pull_requests=prs, + total_count=len(prs), + ) + + except APIError as e: + return ListPRsResult( + success=False, + error=str(e), + ) + + async def merge_pr( + self, + owner: str, + repo: str, + pr_number: int, + merge_strategy: MergeStrategy = MergeStrategy.MERGE, + commit_message: str | None = None, + delete_branch: bool = True, + ) -> MergePRResult: + """Merge a pull request.""" + try: + # Map merge strategy to GitHub's merge_method values + method_map = { + MergeStrategy.MERGE: "merge", + MergeStrategy.SQUASH: "squash", + MergeStrategy.REBASE: "rebase", + } + + data: dict[str, Any] = { + "merge_method": method_map[merge_strategy], + } + + if commit_message: + # For squash, commit_title and commit_message + # For merge, commit_title and commit_message + parts = commit_message.split("\n", 1) + data["commit_title"] = parts[0] + if len(parts) > 1: + data["commit_message"] = parts[1] + + result = await self._request( + "PUT", + f"/repos/{owner}/{repo}/pulls/{pr_number}/merge", + json=data, + ) + + if result is None: + return MergePRResult( + success=False, + error="Failed to merge pull request", + ) + + branch_deleted = False + # Delete branch if requested + if delete_branch and result.get("merged"): + # Get PR to find the branch name + pr_result = await self.get_pr(owner, repo, pr_number) + if pr_result.success and pr_result.pr: + source_branch = pr_result.pr.get("source_branch") + if source_branch: + branch_deleted = await self.delete_remote_branch( + owner, repo, source_branch + ) + + return MergePRResult( + success=True, + merge_commit_sha=result.get("sha"), + branch_deleted=branch_deleted, + ) + + except APIError as e: + return MergePRResult( + success=False, + error=str(e), + ) + + async def update_pr( + self, + owner: str, + repo: str, + pr_number: int, + title: str | None = None, + body: str | None = None, + state: PRState | None = None, + labels: list[str] | None = None, + assignees: list[str] | None = None, + ) -> UpdatePRResult: + """Update a pull request.""" + try: + data: dict[str, Any] = {} + + if title is not None: + data["title"] = title + if body is not None: + data["body"] = body + if state is not None: + if state == PRState.OPEN: + data["state"] = "open" + elif state == PRState.CLOSED: + data["state"] = "closed" + + # Update PR if there's data + if data: + await self._request( + "PATCH", + f"/repos/{owner}/{repo}/pulls/{pr_number}", + json=data, + ) + + # Update labels via issue endpoint + if labels is not None: + await self._request( + "PUT", + f"/repos/{owner}/{repo}/issues/{pr_number}/labels", + json={"labels": labels}, + ) + + # Update assignees via issue endpoint + if assignees is not None: + # First remove all assignees + await self._request( + "DELETE", + f"/repos/{owner}/{repo}/issues/{pr_number}/assignees", + json={"assignees": []}, + ) + # Then add new ones + if assignees: + await self._request( + "POST", + f"/repos/{owner}/{repo}/issues/{pr_number}/assignees", + json={"assignees": assignees}, + ) + + # Fetch updated PR + result = await self.get_pr(owner, repo, pr_number) + return UpdatePRResult( + success=result.success, + pr=result.pr, + error=result.error, + ) + + except APIError as e: + return UpdatePRResult( + success=False, + error=str(e), + ) + + async def close_pr( + self, + owner: str, + repo: str, + pr_number: int, + ) -> UpdatePRResult: + """Close a pull request without merging.""" + return await self.update_pr( + owner, + repo, + pr_number, + state=PRState.CLOSED, + ) + + # Branch operations + + async def delete_remote_branch( + self, + owner: str, + repo: str, + branch: str, + ) -> bool: + """Delete a remote branch.""" + try: + await self._request( + "DELETE", + f"/repos/{owner}/{repo}/git/refs/heads/{branch}", + ) + return True + except APIError: + return False + + async def get_branch( + self, + owner: str, + repo: str, + branch: str, + ) -> dict[str, Any] | None: + """Get branch information.""" + return await self._request( + "GET", + f"/repos/{owner}/{repo}/branches/{branch}", + ) + + # Comment operations + + async def add_pr_comment( + self, + owner: str, + repo: str, + pr_number: int, + body: str, + ) -> dict[str, Any]: + """Add a comment to a pull request.""" + result = await self._request( + "POST", + f"/repos/{owner}/{repo}/issues/{pr_number}/comments", + json={"body": body}, + ) + return result or {} + + async def list_pr_comments( + self, + owner: str, + repo: str, + pr_number: int, + ) -> list[dict[str, Any]]: + """List comments on a pull request.""" + result = await self._request( + "GET", + f"/repos/{owner}/{repo}/issues/{pr_number}/comments", + ) + return result or [] + + # Label operations + + async def add_labels( + self, + owner: str, + repo: str, + pr_number: int, + labels: list[str], + ) -> list[str]: + """Add labels to a pull request.""" + # GitHub creates labels automatically if they don't exist (unlike Gitea) + result = await self._request( + "POST", + f"/repos/{owner}/{repo}/issues/{pr_number}/labels", + json={"labels": labels}, + ) + + if result: + return [lbl["name"] for lbl in result] + return labels + + async def remove_label( + self, + owner: str, + repo: str, + pr_number: int, + label: str, + ) -> list[str]: + """Remove a label from a pull request.""" + await self._request( + "DELETE", + f"/repos/{owner}/{repo}/issues/{pr_number}/labels/{label}", + ) + + # Return remaining labels + issue = await self._request( + "GET", + f"/repos/{owner}/{repo}/issues/{pr_number}", + ) + if issue: + return [lbl["name"] for lbl in issue.get("labels", [])] + return [] + + # Reviewer operations + + async def request_review( + self, + owner: str, + repo: str, + pr_number: int, + reviewers: list[str], + ) -> list[str]: + """Request review from users.""" + await self._request( + "POST", + f"/repos/{owner}/{repo}/pulls/{pr_number}/requested_reviewers", + json={"reviewers": reviewers}, + ) + return reviewers + + # Helper methods + + def _parse_pr(self, data: dict[str, Any]) -> PRInfo: + """Parse PR API response into PRInfo.""" + # Parse dates + created_at = self._parse_datetime(data.get("created_at")) + updated_at = self._parse_datetime(data.get("updated_at")) + merged_at = self._parse_datetime(data.get("merged_at")) + closed_at = self._parse_datetime(data.get("closed_at")) + + # Determine state + if data.get("merged_at"): + state = PRState.MERGED + elif data.get("state") == "closed": + state = PRState.CLOSED + else: + state = PRState.OPEN + + # Extract labels + labels = [lbl["name"] for lbl in data.get("labels", [])] + + # Extract assignees + assignees = [a["login"] for a in data.get("assignees", [])] + + # Extract reviewers + reviewers = [] + if "requested_reviewers" in data: + reviewers = [r["login"] for r in data["requested_reviewers"]] + + return PRInfo( + number=data["number"], + title=data["title"], + body=data.get("body", "") or "", + state=state, + source_branch=data.get("head", {}).get("ref", ""), + target_branch=data.get("base", {}).get("ref", ""), + author=data.get("user", {}).get("login", ""), + created_at=created_at, + updated_at=updated_at, + merged_at=merged_at, + closed_at=closed_at, + url=data.get("html_url"), + labels=labels, + assignees=assignees, + reviewers=reviewers, + mergeable=data.get("mergeable"), + draft=data.get("draft", False), + ) + + def _parse_datetime(self, value: str | None) -> datetime: + """Parse datetime string from API.""" + if not value: + return datetime.now(UTC) + + try: + # GitHub uses ISO 8601 format with Z suffix + if value.endswith("Z"): + value = value[:-1] + "+00:00" + return datetime.fromisoformat(value) + except ValueError: + return datetime.now(UTC) diff --git a/mcp-servers/git-ops/server.py b/mcp-servers/git-ops/server.py index 69b7326..fbc29ca 100644 --- a/mcp-servers/git-ops/server.py +++ b/mcp-servers/git-ops/server.py @@ -22,7 +22,7 @@ from config import Settings, get_settings from exceptions import ErrorCode, GitOpsError from git_wrapper import GitWrapper from models import MergeStrategy, PRState -from providers import GiteaProvider +from providers import BaseProvider, GiteaProvider, GitHubProvider from workspace import WorkspaceManager # Input validation patterns @@ -77,25 +77,57 @@ logger = logging.getLogger(__name__) _settings: Settings | None = None _workspace_manager: WorkspaceManager | None = None _gitea_provider: GiteaProvider | None = None +_github_provider: GitHubProvider | None = None -def _get_provider_for_url(repo_url: str) -> GiteaProvider | None: +def _get_provider_for_url(repo_url: str) -> BaseProvider | None: """Get the appropriate provider for a repository URL.""" if not _settings: return None + # Normalize URL for matching + url_lower = repo_url.lower() + + # Check for GitHub URLs + if "github.com" in url_lower or ( + _settings.github_api_url + and _settings.github_api_url != "https://api.github.com" + and _settings.github_api_url.replace("https://", "").replace("/api/v3", "") in url_lower + ): + return _github_provider + # Check if it's a Gitea URL - if _settings.gitea_base_url and _settings.gitea_base_url in repo_url: + if _settings.gitea_base_url and _settings.gitea_base_url.lower() in url_lower: return _gitea_provider - # Default to Gitea for now + # Default: try to detect from URL pattern + # If URL contains 'github' anywhere, use GitHub + if "github" in url_lower: + return _github_provider + + # Default to Gitea for self-hosted instances return _gitea_provider +def _get_auth_token_for_url(repo_url: str) -> str | None: + """Get the appropriate auth token for a repository URL.""" + if not _settings: + return None + + url_lower = repo_url.lower() + + # GitHub token + if "github.com" in url_lower or "github" in url_lower: + return _settings.github_token if _settings.github_token else None + + # Gitea token (default) + return _settings.gitea_token if _settings.gitea_token else None + + @asynccontextmanager async def lifespan(_app: FastAPI) -> AsyncIterator[None]: """Application lifespan handler.""" - global _settings, _workspace_manager, _gitea_provider + global _settings, _workspace_manager, _gitea_provider, _github_provider logger.info("Starting Git Operations MCP Server...") @@ -114,6 +146,13 @@ async def lifespan(_app: FastAPI) -> AsyncIterator[None]: ) logger.info(f"Gitea provider initialized: {_settings.gitea_base_url}") + if _settings.github_token: + _github_provider = GitHubProvider( + token=_settings.github_token, + settings=_settings, + ) + logger.info("GitHub provider initialized") + logger.info("Git Operations MCP Server started successfully") yield @@ -124,6 +163,9 @@ async def lifespan(_app: FastAPI) -> AsyncIterator[None]: if _gitea_provider: await _gitea_provider.close() + if _github_provider: + await _github_provider.close() + logger.info("Git Operations MCP Server shut down") @@ -167,6 +209,21 @@ async def health_check() -> dict[str, Any]: else: status["dependencies"]["gitea"] = "not configured" + # Check GitHub connectivity + if _github_provider: + try: + if await _github_provider.is_connected(): + user = await _github_provider.get_authenticated_user() + status["dependencies"]["github"] = f"connected as {user}" + else: + status["dependencies"]["github"] = "not connected" + is_degraded = True + except Exception as e: + status["dependencies"]["github"] = f"error: {e}" + is_degraded = True + else: + status["dependencies"]["github"] = "not configured" + # Check workspace directory if _workspace_manager: try: @@ -442,10 +499,8 @@ async def clone_repository( # Create workspace workspace = await _workspace_manager.create_workspace(project_id, repo_url) # type: ignore[union-attr] - # Get auth token from provider - auth_token = None - if _settings and _settings.gitea_token: - auth_token = _settings.gitea_token + # Get auth token appropriate for the repository URL + auth_token = _get_auth_token_for_url(repo_url) # Clone repository git = GitWrapper(workspace.path, _settings) @@ -715,10 +770,8 @@ async def push( if not workspace: return {"success": False, "error": f"Workspace not found for project: {project_id}", "code": ErrorCode.WORKSPACE_NOT_FOUND.value} - # Get auth token - auth_token = None - if _settings and _settings.gitea_token: - auth_token = _settings.gitea_token + # Get auth token appropriate for the repository URL + auth_token = _get_auth_token_for_url(workspace.repo_url or "") git = GitWrapper(workspace.path, _settings) result = await git.push( diff --git a/mcp-servers/git-ops/tests/test_github_provider.py b/mcp-servers/git-ops/tests/test_github_provider.py new file mode 100644 index 0000000..40cb623 --- /dev/null +++ b/mcp-servers/git-ops/tests/test_github_provider.py @@ -0,0 +1,583 @@ +""" +Tests for GitHub provider implementation. +""" + +from unittest.mock import MagicMock + +import pytest + +from exceptions import APIError, AuthenticationError +from models import MergeStrategy, PRState +from providers.github import GitHubProvider + + +class TestGitHubProviderBasics: + """Tests for GitHubProvider basic operations.""" + + def test_provider_name(self): + """Test provider name is github.""" + provider = GitHubProvider(token="test-token") + assert provider.name == "github" + + def test_parse_repo_url_https(self): + """Test parsing HTTPS repo URL.""" + provider = GitHubProvider(token="test-token") + + owner, repo = provider.parse_repo_url("https://github.com/owner/repo.git") + + assert owner == "owner" + assert repo == "repo" + + def test_parse_repo_url_https_no_git(self): + """Test parsing HTTPS URL without .git suffix.""" + provider = GitHubProvider(token="test-token") + + owner, repo = provider.parse_repo_url("https://github.com/owner/repo") + + assert owner == "owner" + assert repo == "repo" + + def test_parse_repo_url_ssh(self): + """Test parsing SSH repo URL.""" + provider = GitHubProvider(token="test-token") + + owner, repo = provider.parse_repo_url("git@github.com:owner/repo.git") + + assert owner == "owner" + assert repo == "repo" + + def test_parse_repo_url_invalid(self): + """Test error on invalid URL.""" + provider = GitHubProvider(token="test-token") + + with pytest.raises(ValueError, match="Unable to parse"): + provider.parse_repo_url("invalid-url") + + +@pytest.fixture +def mock_github_httpx_client(): + """Create a mock httpx client for GitHub provider tests.""" + from unittest.mock import AsyncMock + + mock_response = AsyncMock() + mock_response.status_code = 200 + mock_response.json = MagicMock(return_value={}) + mock_response.text = "" + + mock_client = AsyncMock() + mock_client.request = AsyncMock(return_value=mock_response) + mock_client.get = AsyncMock(return_value=mock_response) + mock_client.post = AsyncMock(return_value=mock_response) + mock_client.patch = AsyncMock(return_value=mock_response) + mock_client.put = AsyncMock(return_value=mock_response) + mock_client.delete = AsyncMock(return_value=mock_response) + + return mock_client + + +@pytest.fixture +async def github_provider(test_settings, mock_github_httpx_client): + """Create a GitHubProvider with mocked HTTP client.""" + provider = GitHubProvider( + token=test_settings.github_token, + settings=test_settings, + ) + provider._client = mock_github_httpx_client + + yield provider + + await provider.close() + + +@pytest.fixture +def github_pr_data(): + """Sample PR data from GitHub API.""" + return { + "number": 42, + "title": "Test PR", + "body": "This is a test pull request", + "state": "open", + "head": {"ref": "feature-branch"}, + "base": {"ref": "main"}, + "user": {"login": "test-user"}, + "created_at": "2024-01-15T10:00:00Z", + "updated_at": "2024-01-15T12:00:00Z", + "merged_at": None, + "closed_at": None, + "html_url": "https://github.com/owner/repo/pull/42", + "labels": [{"name": "enhancement"}], + "assignees": [{"login": "assignee1"}], + "requested_reviewers": [{"login": "reviewer1"}], + "mergeable": True, + "draft": False, + } + + +class TestGitHubProviderConnection: + """Tests for GitHub provider connection.""" + + @pytest.mark.asyncio + async def test_is_connected(self, github_provider, mock_github_httpx_client): + """Test connection check.""" + mock_github_httpx_client.request.return_value.json = MagicMock( + return_value={"login": "test-user"} + ) + + result = await github_provider.is_connected() + + assert result is True + + @pytest.mark.asyncio + async def test_is_connected_no_token(self, test_settings): + """Test connection fails without token.""" + provider = GitHubProvider( + token="", + settings=test_settings, + ) + + result = await provider.is_connected() + assert result is False + + await provider.close() + + @pytest.mark.asyncio + async def test_get_authenticated_user(self, github_provider, mock_github_httpx_client): + """Test getting authenticated user.""" + mock_github_httpx_client.request.return_value.json = MagicMock( + return_value={"login": "test-user"} + ) + + user = await github_provider.get_authenticated_user() + + assert user == "test-user" + + +class TestGitHubProviderRepoOperations: + """Tests for GitHub repository operations.""" + + @pytest.mark.asyncio + async def test_get_repo_info(self, github_provider, mock_github_httpx_client): + """Test getting repository info.""" + mock_github_httpx_client.request.return_value.json = MagicMock( + return_value={ + "name": "repo", + "full_name": "owner/repo", + "default_branch": "main", + } + ) + + result = await github_provider.get_repo_info("owner", "repo") + + assert result["name"] == "repo" + assert result["default_branch"] == "main" + + @pytest.mark.asyncio + async def test_get_default_branch(self, github_provider, mock_github_httpx_client): + """Test getting default branch.""" + mock_github_httpx_client.request.return_value.json = MagicMock( + return_value={"default_branch": "develop"} + ) + + branch = await github_provider.get_default_branch("owner", "repo") + + assert branch == "develop" + + +class TestGitHubPROperations: + """Tests for GitHub PR operations.""" + + @pytest.mark.asyncio + async def test_create_pr(self, github_provider, mock_github_httpx_client): + """Test creating a pull request.""" + mock_github_httpx_client.request.return_value.json = MagicMock( + return_value={ + "number": 42, + "html_url": "https://github.com/owner/repo/pull/42", + } + ) + + result = await github_provider.create_pr( + owner="owner", + repo="repo", + title="Test PR", + body="Test body", + source_branch="feature", + target_branch="main", + ) + + assert result.success is True + assert result.pr_number == 42 + assert result.pr_url == "https://github.com/owner/repo/pull/42" + + @pytest.mark.asyncio + async def test_create_pr_with_draft(self, github_provider, mock_github_httpx_client): + """Test creating a draft PR.""" + mock_github_httpx_client.request.return_value.json = MagicMock( + return_value={ + "number": 43, + "html_url": "https://github.com/owner/repo/pull/43", + } + ) + + result = await github_provider.create_pr( + owner="owner", + repo="repo", + title="Draft PR", + body="Draft body", + source_branch="feature", + target_branch="main", + draft=True, + ) + + assert result.success is True + assert result.pr_number == 43 + + @pytest.mark.asyncio + async def test_create_pr_with_options(self, github_provider, mock_github_httpx_client): + """Test creating PR with labels, assignees, reviewers.""" + mock_responses = [ + {"number": 44, "html_url": "https://github.com/owner/repo/pull/44"}, # Create PR + [{"name": "enhancement"}], # POST add labels + {}, # POST add assignees + {}, # POST request reviewers + ] + mock_github_httpx_client.request.return_value.json = MagicMock(side_effect=mock_responses) + + result = await github_provider.create_pr( + owner="owner", + repo="repo", + title="Test PR", + body="Test body", + source_branch="feature", + target_branch="main", + labels=["enhancement"], + assignees=["user1"], + reviewers=["reviewer1"], + ) + + assert result.success is True + + @pytest.mark.asyncio + async def test_get_pr(self, github_provider, mock_github_httpx_client, github_pr_data): + """Test getting a pull request.""" + mock_github_httpx_client.request.return_value.json = MagicMock( + return_value=github_pr_data + ) + + result = await github_provider.get_pr("owner", "repo", 42) + + assert result.success is True + assert result.pr["number"] == 42 + assert result.pr["title"] == "Test PR" + + @pytest.mark.asyncio + async def test_get_pr_not_found(self, github_provider, mock_github_httpx_client): + """Test getting non-existent PR.""" + mock_github_httpx_client.request.return_value.status_code = 404 + mock_github_httpx_client.request.return_value.json = MagicMock(return_value=None) + + result = await github_provider.get_pr("owner", "repo", 999) + + assert result.success is False + + @pytest.mark.asyncio + async def test_list_prs(self, github_provider, mock_github_httpx_client, github_pr_data): + """Test listing pull requests.""" + mock_github_httpx_client.request.return_value.json = MagicMock( + return_value=[github_pr_data, github_pr_data] + ) + + result = await github_provider.list_prs("owner", "repo") + + assert result.success is True + assert len(result.pull_requests) == 2 + + @pytest.mark.asyncio + async def test_list_prs_with_state_filter(self, github_provider, mock_github_httpx_client, github_pr_data): + """Test listing PRs with state filter.""" + mock_github_httpx_client.request.return_value.json = MagicMock( + return_value=[github_pr_data] + ) + + result = await github_provider.list_prs( + "owner", "repo", state=PRState.OPEN + ) + + assert result.success is True + + @pytest.mark.asyncio + async def test_merge_pr(self, github_provider, mock_github_httpx_client, github_pr_data): + """Test merging a pull request.""" + # Merge returns sha, then get_pr returns the PR data, then delete branch + mock_responses = [ + {"sha": "merge-commit-sha", "merged": True}, # PUT merge + github_pr_data, # GET PR for branch info + None, # DELETE branch + ] + mock_github_httpx_client.request.return_value.json = MagicMock( + side_effect=mock_responses + ) + + result = await github_provider.merge_pr( + "owner", "repo", 42, + merge_strategy=MergeStrategy.SQUASH, + ) + + assert result.success is True + assert result.merge_commit_sha == "merge-commit-sha" + + @pytest.mark.asyncio + async def test_merge_pr_rebase(self, github_provider, mock_github_httpx_client, github_pr_data): + """Test merging with rebase strategy.""" + mock_responses = [ + {"sha": "rebase-commit-sha", "merged": True}, # PUT merge + github_pr_data, # GET PR for branch info + None, # DELETE branch + ] + mock_github_httpx_client.request.return_value.json = MagicMock( + side_effect=mock_responses + ) + + result = await github_provider.merge_pr( + "owner", "repo", 42, + merge_strategy=MergeStrategy.REBASE, + ) + + assert result.success is True + + @pytest.mark.asyncio + async def test_update_pr(self, github_provider, mock_github_httpx_client, github_pr_data): + """Test updating a pull request.""" + mock_github_httpx_client.request.return_value.json = MagicMock( + return_value=github_pr_data + ) + + result = await github_provider.update_pr( + "owner", "repo", 42, + title="Updated Title", + body="Updated body", + ) + + assert result.success is True + + @pytest.mark.asyncio + async def test_close_pr(self, github_provider, mock_github_httpx_client, github_pr_data): + """Test closing a pull request.""" + github_pr_data["state"] = "closed" + mock_github_httpx_client.request.return_value.json = MagicMock( + return_value=github_pr_data + ) + + result = await github_provider.close_pr("owner", "repo", 42) + + assert result.success is True + + +class TestGitHubBranchOperations: + """Tests for GitHub branch operations.""" + + @pytest.mark.asyncio + async def test_get_branch(self, github_provider, mock_github_httpx_client): + """Test getting branch info.""" + mock_github_httpx_client.request.return_value.json = MagicMock( + return_value={ + "name": "main", + "commit": {"sha": "abc123"}, + } + ) + + result = await github_provider.get_branch("owner", "repo", "main") + + assert result["name"] == "main" + + @pytest.mark.asyncio + async def test_delete_remote_branch(self, github_provider, mock_github_httpx_client): + """Test deleting a remote branch.""" + mock_github_httpx_client.request.return_value.status_code = 204 + + result = await github_provider.delete_remote_branch("owner", "repo", "old-branch") + + assert result is True + + +class TestGitHubCommentOperations: + """Tests for GitHub comment operations.""" + + @pytest.mark.asyncio + async def test_add_pr_comment(self, github_provider, mock_github_httpx_client): + """Test adding a comment to a PR.""" + mock_github_httpx_client.request.return_value.json = MagicMock( + return_value={"id": 1, "body": "Test comment"} + ) + + result = await github_provider.add_pr_comment( + "owner", "repo", 42, "Test comment" + ) + + assert result["body"] == "Test comment" + + @pytest.mark.asyncio + async def test_list_pr_comments(self, github_provider, mock_github_httpx_client): + """Test listing PR comments.""" + mock_github_httpx_client.request.return_value.json = MagicMock( + return_value=[ + {"id": 1, "body": "Comment 1"}, + {"id": 2, "body": "Comment 2"}, + ] + ) + + result = await github_provider.list_pr_comments("owner", "repo", 42) + + assert len(result) == 2 + + +class TestGitHubLabelOperations: + """Tests for GitHub label operations.""" + + @pytest.mark.asyncio + async def test_add_labels(self, github_provider, mock_github_httpx_client): + """Test adding labels to a PR.""" + mock_github_httpx_client.request.return_value.json = MagicMock( + return_value=[{"name": "bug"}, {"name": "urgent"}] + ) + + result = await github_provider.add_labels( + "owner", "repo", 42, ["bug", "urgent"] + ) + + assert "bug" in result + assert "urgent" in result + + @pytest.mark.asyncio + async def test_remove_label(self, github_provider, mock_github_httpx_client): + """Test removing a label from a PR.""" + mock_responses = [ + None, # DELETE label + {"labels": []}, # GET issue + ] + mock_github_httpx_client.request.return_value.json = MagicMock(side_effect=mock_responses) + + result = await github_provider.remove_label( + "owner", "repo", 42, "bug" + ) + + assert isinstance(result, list) + + +class TestGitHubReviewerOperations: + """Tests for GitHub reviewer operations.""" + + @pytest.mark.asyncio + async def test_request_review(self, github_provider, mock_github_httpx_client): + """Test requesting review from users.""" + mock_github_httpx_client.request.return_value.json = MagicMock(return_value={}) + + result = await github_provider.request_review( + "owner", "repo", 42, ["reviewer1", "reviewer2"] + ) + + assert result == ["reviewer1", "reviewer2"] + + +class TestGitHubErrorHandling: + """Tests for error handling in GitHub provider.""" + + @pytest.mark.asyncio + async def test_authentication_error(self, github_provider, mock_github_httpx_client): + """Test handling authentication errors.""" + mock_github_httpx_client.request.return_value.status_code = 401 + + with pytest.raises(AuthenticationError): + await github_provider._request("GET", "/user") + + @pytest.mark.asyncio + async def test_permission_denied(self, github_provider, mock_github_httpx_client): + """Test handling permission denied errors.""" + mock_github_httpx_client.request.return_value.status_code = 403 + mock_github_httpx_client.request.return_value.text = "Permission denied" + + with pytest.raises(AuthenticationError, match="Insufficient permissions"): + await github_provider._request("GET", "/protected") + + @pytest.mark.asyncio + async def test_rate_limit_error(self, github_provider, mock_github_httpx_client): + """Test handling rate limit errors.""" + mock_github_httpx_client.request.return_value.status_code = 403 + mock_github_httpx_client.request.return_value.text = "API rate limit exceeded" + + with pytest.raises(APIError, match="rate limit"): + await github_provider._request("GET", "/user") + + @pytest.mark.asyncio + async def test_api_error(self, github_provider, mock_github_httpx_client): + """Test handling general API errors.""" + mock_github_httpx_client.request.return_value.status_code = 500 + mock_github_httpx_client.request.return_value.text = "Internal Server Error" + mock_github_httpx_client.request.return_value.json = MagicMock( + return_value={"message": "Server error"} + ) + + with pytest.raises(APIError): + await github_provider._request("GET", "/error") + + +class TestGitHubPRParsing: + """Tests for PR data parsing.""" + + def test_parse_pr_open(self, github_provider, github_pr_data): + """Test parsing open PR.""" + pr_info = github_provider._parse_pr(github_pr_data) + + assert pr_info.number == 42 + assert pr_info.state == PRState.OPEN + assert pr_info.title == "Test PR" + assert pr_info.source_branch == "feature-branch" + assert pr_info.target_branch == "main" + + def test_parse_pr_merged(self, github_provider, github_pr_data): + """Test parsing merged PR.""" + github_pr_data["merged_at"] = "2024-01-16T10:00:00Z" + + pr_info = github_provider._parse_pr(github_pr_data) + + assert pr_info.state == PRState.MERGED + + def test_parse_pr_closed(self, github_provider, github_pr_data): + """Test parsing closed PR.""" + github_pr_data["state"] = "closed" + github_pr_data["closed_at"] = "2024-01-16T10:00:00Z" + + pr_info = github_provider._parse_pr(github_pr_data) + + assert pr_info.state == PRState.CLOSED + + def test_parse_pr_draft(self, github_provider, github_pr_data): + """Test parsing draft PR.""" + github_pr_data["draft"] = True + + pr_info = github_provider._parse_pr(github_pr_data) + + assert pr_info.draft is True + + def test_parse_datetime_iso(self, github_provider): + """Test parsing ISO datetime strings.""" + dt = github_provider._parse_datetime("2024-01-15T10:30:00Z") + + assert dt.year == 2024 + assert dt.month == 1 + assert dt.day == 15 + + def test_parse_datetime_none(self, github_provider): + """Test parsing None datetime returns now.""" + dt = github_provider._parse_datetime(None) + + assert dt is not None + assert dt.tzinfo is not None + + def test_parse_pr_with_null_body(self, github_provider, github_pr_data): + """Test parsing PR with null body.""" + github_pr_data["body"] = None + + pr_info = github_provider._parse_pr(github_pr_data) + + assert pr_info.body == ""