""" Git operations wrapper using GitPython. Provides high-level git operations with proper error handling, async compatibility, and structured results. """ import asyncio import logging import os import re from concurrent.futures import ThreadPoolExecutor from datetime import UTC, datetime from functools import partial from pathlib import Path from typing import Any from git import GitCommandError, InvalidGitRepositoryError, NoSuchPathError from git import Repo as GitRepo from config import Settings, get_settings from exceptions import ( BranchExistsError, BranchNotFoundError, CheckoutError, CloneError, CommitError, DirtyWorkspaceError, GitError, MergeConflictError, PullError, PushError, ) from models import ( BranchInfo, BranchResult, CheckoutResult, CloneResult, CommitInfo, CommitResult, DiffHunk, DiffResult, FileChange, FileChangeType, FileDiff, ListBranchesResult, LogResult, PullResult, PushResult, StatusResult, ) logger = logging.getLogger(__name__) def sanitize_url_for_logging(url: str) -> str: """ Remove any credentials from a URL before logging. Handles URLs like: - https://token@github.com/owner/repo.git - https://user:password@github.com/owner/repo.git - git@github.com:owner/repo.git (unchanged, no credentials) """ # Pattern to match https://[credentials@]host/path sanitized = re.sub(r"(https?://)([^@]+@)", r"\1***@", url) return sanitized # Thread pool for blocking git operations _executor: ThreadPoolExecutor | None = None def get_executor() -> ThreadPoolExecutor: """Get the shared thread pool executor.""" global _executor if _executor is None: _executor = ThreadPoolExecutor(max_workers=4, thread_name_prefix="git-ops-") return _executor async def run_in_executor(func: Any, *args: Any, **kwargs: Any) -> Any: """Run a blocking function in the thread pool.""" loop = asyncio.get_event_loop() executor = get_executor() partial_func = partial(func, *args, **kwargs) return await loop.run_in_executor(executor, partial_func) class GitWrapper: """ Wrapper for git operations using GitPython. Provides async-compatible git operations with proper error handling. """ def __init__( self, workspace_path: Path | str, settings: Settings | None = None, ) -> None: """ Initialize GitWrapper. Args: workspace_path: Path to the git workspace settings: Optional settings override """ self.workspace_path = ( Path(workspace_path) if isinstance(workspace_path, str) else workspace_path ) self.settings = settings or get_settings() self._repo: GitRepo | None = None @property def repo(self) -> GitRepo: """Get the GitPython Repo instance.""" if self._repo is None: try: self._repo = GitRepo(self.workspace_path) except InvalidGitRepositoryError: raise GitError(f"Not a git repository: {self.workspace_path}") except NoSuchPathError: raise GitError(f"Path does not exist: {self.workspace_path}") return self._repo def _refresh_repo(self) -> None: """Refresh the repo instance after operations that change it.""" self._repo = None # Clone operations async def clone( self, repo_url: str, branch: str | None = None, depth: int | None = None, auth_token: str | None = None, ) -> CloneResult: """ Clone a repository. 

        Args:
            repo_url: Repository URL to clone
            branch: Branch to checkout after clone
            depth: Shallow clone depth (None for full)
            auth_token: Optional auth token for HTTPS

        Returns:
            CloneResult with clone status
        """

        def _do_clone() -> CloneResult:
            try:
                # Build clone URL with auth if provided
                clone_url = repo_url
                if auth_token and repo_url.startswith("https://"):
                    # Insert token in URL: https://token@host/path
                    clone_url = re.sub(
                        r"^(https://)(.+)$",
                        rf"\1{auth_token}@\2",
                        repo_url,
                    )

                # Build clone arguments
                kwargs: dict[str, Any] = {
                    "url": clone_url,
                    "to_path": str(self.workspace_path),
                }
                if branch:
                    kwargs["branch"] = branch
                if depth:
                    kwargs["depth"] = depth

                # Disable interactive credential prompts; pass the environment
                # through to git so the clone can never block on a prompt
                env = os.environ.copy()
                env["GIT_TERMINAL_PROMPT"] = "0"
                kwargs["env"] = env

                logger.info(
                    f"Cloning repository: {sanitize_url_for_logging(repo_url)} "
                    f"-> {self.workspace_path}"
                )

                repo = GitRepo.clone_from(**kwargs)
                self._repo = repo

                return CloneResult(
                    success=True,
                    project_id="",  # Set by caller
                    workspace_path=str(self.workspace_path),
                    branch=repo.active_branch.name,
                    commit_sha=repo.head.commit.hexsha,
                )

            except GitCommandError as e:
                # Sanitize URLs in error messages to prevent credential leakage
                error_msg = sanitize_url_for_logging(str(e))
                logger.error(f"Clone failed: {error_msg}")
                raise CloneError(sanitize_url_for_logging(repo_url), error_msg)

        return await run_in_executor(_do_clone)

    # Status operations

    async def status(self, include_untracked: bool = True) -> StatusResult:
        """
        Get git status.

        Args:
            include_untracked: Include untracked files

        Returns:
            StatusResult with working tree status
        """

        def _get_status() -> StatusResult:
            repo = self.repo

            # Get staged changes
            staged = []
            for diff in repo.index.diff("HEAD"):
                change_type = self._diff_to_change_type(diff.change_type)
                path = diff.b_path or diff.a_path or ""
                staged.append(
                    FileChange(
                        path=path,
                        change_type=change_type,
                        old_path=diff.a_path if diff.renamed else None,
                    ).to_dict()
                )

            # Get unstaged changes
            unstaged = []
            for diff in repo.index.diff(None):
                change_type = self._diff_to_change_type(diff.change_type)
                path = diff.b_path or diff.a_path or ""
                unstaged.append(
                    FileChange(
                        path=path,
                        change_type=change_type,
                    ).to_dict()
                )

            # Get untracked files
            untracked = list(repo.untracked_files) if include_untracked else []

            # Get tracking info
            ahead = behind = 0
            try:
                tracking = repo.active_branch.tracking_branch()
                if tracking:
                    ahead = len(
                        list(
                            repo.iter_commits(
                                f"{tracking.name}..{repo.active_branch.name}"
                            )
                        )
                    )
                    behind = len(
                        list(
                            repo.iter_commits(
                                f"{repo.active_branch.name}..{tracking.name}"
                            )
                        )
                    )
            except Exception:
                pass  # No tracking branch

            is_clean = len(staged) == 0 and len(unstaged) == 0 and len(untracked) == 0

            return StatusResult(
                project_id="",  # Set by caller
                branch=repo.active_branch.name,
                commit_sha=repo.head.commit.hexsha,
                is_clean=is_clean,
                staged=staged,
                unstaged=unstaged,
                untracked=untracked,
                ahead=ahead,
                behind=behind,
            )

        return await run_in_executor(_get_status)

    # Branch operations

    async def create_branch(
        self,
        branch_name: str,
        from_ref: str | None = None,
        checkout: bool = True,
    ) -> BranchResult:
        """
        Create a new branch.

        Args:
            branch_name: Name for the new branch
            from_ref: Reference to create from (default: HEAD)
            checkout: Whether to checkout after creation

        Returns:
            BranchResult with creation status
        """

        def _create_branch() -> BranchResult:
            repo = self.repo

            # Check if branch already exists
            if branch_name in [b.name for b in repo.branches]:
                raise BranchExistsError(branch_name)

            try:
                # Get the starting point
                if from_ref:
                    start_point = repo.commit(from_ref)
                else:
                    start_point = repo.head.commit

                # Create branch
                new_branch = repo.create_head(branch_name, start_point)

                # Checkout if requested
                if checkout:
                    new_branch.checkout()

                return BranchResult(
                    success=True,
                    branch=branch_name,
                    commit_sha=new_branch.commit.hexsha,
                    is_current=checkout,
                )

            except GitCommandError as e:
                logger.error(f"Failed to create branch {branch_name}: {e}")
                raise GitError(f"Failed to create branch: {e}")

        return await run_in_executor(_create_branch)

    async def delete_branch(
        self,
        branch_name: str,
        force: bool = False,
    ) -> BranchResult:
        """
        Delete a branch.

        Args:
            branch_name: Branch to delete
            force: Force delete even if not merged

        Returns:
            BranchResult with deletion status
        """

        def _delete_branch() -> BranchResult:
            repo = self.repo

            if branch_name not in [b.name for b in repo.branches]:
                raise BranchNotFoundError(branch_name)

            if repo.active_branch.name == branch_name:
                raise GitError(f"Cannot delete current branch: {branch_name}")

            try:
                repo.delete_head(branch_name, force=force)
                return BranchResult(
                    success=True,
                    branch=branch_name,
                    is_current=False,
                )
            except GitCommandError as e:
                logger.error(f"Failed to delete branch {branch_name}: {e}")
                raise GitError(f"Failed to delete branch: {e}")

        return await run_in_executor(_delete_branch)

    async def list_branches(self, include_remote: bool = False) -> ListBranchesResult:
        """
        List branches.

        Args:
            include_remote: Include remote tracking branches

        Returns:
            ListBranchesResult with branch lists
        """

        def _list_branches() -> ListBranchesResult:
            repo = self.repo

            local_branches = []
            for branch in repo.branches:
                tracking = branch.tracking_branch()
                msg = branch.commit.message
                commit_msg = (
                    msg.decode("utf-8", errors="replace")
                    if isinstance(msg, bytes)
                    else msg
                ).split("\n")[0]
                local_branches.append(
                    BranchInfo(
                        name=branch.name,
                        is_current=branch == repo.active_branch,
                        is_remote=False,
                        tracking_branch=tracking.name if tracking else None,
                        commit_sha=branch.commit.hexsha,
                        commit_message=commit_msg,
                    ).to_dict()
                )

            remote_branches = []
            if include_remote:
                for remote in repo.remotes:
                    for ref in remote.refs:
                        # Skip HEAD refs
                        if ref.name.endswith("/HEAD"):
                            continue
                        msg = ref.commit.message
                        commit_msg = (
                            msg.decode("utf-8", errors="replace")
                            if isinstance(msg, bytes)
                            else msg
                        ).split("\n")[0]
                        remote_branches.append(
                            BranchInfo(
                                name=ref.name,
                                is_current=False,
                                is_remote=True,
                                commit_sha=ref.commit.hexsha,
                                commit_message=commit_msg,
                            ).to_dict()
                        )

            return ListBranchesResult(
                project_id="",  # Set by caller
                current_branch=repo.active_branch.name,
                local_branches=local_branches,
                remote_branches=remote_branches,
            )

        return await run_in_executor(_list_branches)

    async def checkout(
        self,
        ref: str,
        create_branch: bool = False,
        force: bool = False,
    ) -> CheckoutResult:
        """
        Checkout a branch or ref.

        Args:
            ref: Branch, tag, or commit to checkout
            create_branch: Create new branch with this name
            force: Force checkout (discard local changes)

        Returns:
            CheckoutResult with checkout status
        """

        def _checkout() -> CheckoutResult:
            repo = self.repo

            try:
                if create_branch:
                    # Create and checkout new branch
                    if ref in [b.name for b in repo.branches]:
                        raise BranchExistsError(ref)
                    new_branch = repo.create_head(ref)
                    new_branch.checkout(force=force)
                else:
                    # Checkout existing ref
                    if ref in [b.name for b in repo.branches]:
                        # Local branch
                        repo.heads[ref].checkout(force=force)
                    else:
                        # Try as a commit/tag
                        repo.git.checkout(ref, force=force)

                return CheckoutResult(
                    success=True,
                    ref=ref,
                    commit_sha=repo.head.commit.hexsha,
                )

            except GitCommandError as e:
                error_msg = str(e)
                if "would be overwritten" in error_msg:
                    raise DirtyWorkspaceError([])
                raise CheckoutError(ref, error_msg)

        return await run_in_executor(_checkout)

    # Commit operations

    async def commit(
        self,
        message: str,
        files: list[str] | None = None,
        author_name: str | None = None,
        author_email: str | None = None,
        allow_empty: bool = False,
    ) -> CommitResult:
        """
        Create a commit.

        Args:
            message: Commit message
            files: Specific files to commit (None = all staged)
            author_name: Author name override
            author_email: Author email override
            allow_empty: Allow empty commits

        Returns:
            CommitResult with commit info
        """

        def _commit() -> CommitResult:
            repo = self.repo

            try:
                # Stage files if specified
                if files:
                    repo.index.add(files)
                elif not allow_empty:
                    # Stage all modified/deleted
                    repo.git.add("-A")

                # Check if there's anything to commit
                if (
                    not allow_empty
                    and not repo.index.diff("HEAD")
                    and not repo.untracked_files
                ):
                    raise CommitError("Nothing to commit")

                # Build author override (fall back to the configured identity
                # when only one of name/email is supplied)
                author = None
                if author_name or author_email:
                    from git import Actor

                    author = Actor(
                        author_name or self.settings.git_author_name,
                        author_email or self.settings.git_author_email,
                    )

                # Create commit
                commit = repo.index.commit(
                    message,
                    author=author,
                    committer=author,
                )

                # Get stats
                stats = commit.stats.total
                files_changed = stats.get("files", 0)
                insertions = stats.get("insertions", 0)
                deletions = stats.get("deletions", 0)

                return CommitResult(
                    success=True,
                    commit_sha=commit.hexsha,
                    short_sha=commit.hexsha[:7],
                    message=message,
                    files_changed=files_changed,
                    insertions=insertions,
                    deletions=deletions,
                )

            except GitCommandError as e:
                logger.error(f"Commit failed: {e}")
                raise CommitError(str(e))

        return await run_in_executor(_commit)

    async def stage(self, files: list[str] | None = None) -> int:
        """
        Stage files for commit.

        Args:
            files: Files to stage (None = all)

        Returns:
            Number of files staged
        """

        def _stage() -> int:
            repo = self.repo
            if files:
                repo.index.add(files)
                return len(files)
            else:
                repo.git.add("-A")
                return len(repo.index.diff("HEAD")) + len(repo.untracked_files)

        return await run_in_executor(_stage)

    async def unstage(self, files: list[str] | None = None) -> int:
        """
        Unstage files.

        Args:
            files: Files to unstage (None = all)

        Returns:
            Number of files unstaged
        """

        def _unstage() -> int:
            repo = self.repo
            staged = list(repo.index.diff("HEAD"))

            if files:
                # Restore the index entries to HEAD (equivalent to
                # `git reset HEAD -- <files>`) without touching the working tree
                repo.git.reset("HEAD", "--", *files)
                return len(files)
            else:
                # Unstage all
                if staged:
                    repo.git.reset("HEAD")
                return len(staged)

        return await run_in_executor(_unstage)

    # Push/Pull operations

    async def push(
        self,
        branch: str | None = None,
        remote: str = "origin",
        force: bool = False,
        set_upstream: bool = True,
        auth_token: str | None = None,
    ) -> PushResult:
        """
        Push to remote.

        Args:
            branch: Branch to push (None = current)
            remote: Remote name
            force: Force push
            set_upstream: Set upstream tracking
            auth_token: Auth token for HTTPS

        Returns:
            PushResult with push status
        """

        def _push() -> PushResult:
            repo = self.repo
            push_branch = branch or repo.active_branch.name

            # Check force push policy
            if force and not self.settings.enable_force_push:
                raise PushError(push_branch, "Force push is disabled")

            try:
                # Build push info
                push_info_list: list[Any] = []

                if remote not in [r.name for r in repo.remotes]:
                    raise PushError(push_branch, f"Remote not found: {remote}")

                remote_obj = repo.remote(remote)

                # Configure auth if provided
                if auth_token:
                    # Set credential helper temporarily
                    pass  # TODO: Implement token-based auth

                # Build refspec
                refspec = f"{push_branch}:{push_branch}"

                if set_upstream:
                    push_info_list = remote_obj.push(
                        refspec=refspec,
                        force=force,
                        set_upstream=True,
                    )
                else:
                    push_info_list = remote_obj.push(
                        refspec=refspec,
                        force=force,
                    )

                # Check for errors
                for info in push_info_list:
                    if info.flags & info.ERROR:
                        raise PushError(push_branch, info.summary)

                # Count commits pushed (approximate)
                commits_pushed = 0
                try:
                    tracking = repo.active_branch.tracking_branch()
                    if tracking:
                        commits_pushed = len(
                            list(repo.iter_commits(f"{tracking.name}..{push_branch}"))
                        )
                except Exception:
                    pass

                return PushResult(
                    success=True,
                    branch=push_branch,
                    remote=remote,
                    commits_pushed=commits_pushed,
                )

            except GitCommandError as e:
                error_msg = str(e)
                if "rejected" in error_msg:
                    raise PushError(
                        push_branch,
                        "Push rejected - pull and merge first or force push",
                    )
                raise PushError(push_branch, error_msg)

        return await run_in_executor(_push)

    async def pull(
        self,
        branch: str | None = None,
        remote: str = "origin",
        rebase: bool = False,
        auth_token: str | None = None,
    ) -> PullResult:
        """
        Pull from remote.

        Args:
            branch: Branch to pull (None = current)
            remote: Remote name
            rebase: Rebase instead of merge
            auth_token: Auth token for HTTPS

        Returns:
            PullResult with pull status
        """

        def _pull() -> PullResult:
            repo = self.repo
            pull_branch = branch or repo.active_branch.name

            try:
                if remote not in [r.name for r in repo.remotes]:
                    raise PullError(pull_branch, f"Remote not found: {remote}")

                remote_obj = repo.remote(remote)

                # Fetch first to check for conflicts
                remote_obj.fetch()

                # Get commits before pull
                head_before = repo.head.commit.hexsha

                # Perform pull
                if rebase:
                    repo.git.pull("--rebase", remote, pull_branch)
                else:
                    repo.git.pull(remote, pull_branch)

                # Count new commits
                commits_received = len(list(repo.iter_commits(f"{head_before}..HEAD")))

                # A non-fast-forward merge pull creates a merge commit whose
                # parents include the pre-pull HEAD; otherwise it was a fast-forward
                new_head = repo.head.commit
                merge_commit_created = len(new_head.parents) > 1 and head_before in [
                    p.hexsha for p in new_head.parents
                ]
                fast_forward = commits_received > 0 and not merge_commit_created

                return PullResult(
                    success=True,
                    branch=pull_branch,
                    commits_received=commits_received,
                    fast_forward=fast_forward,
                )

            except GitCommandError as e:
                error_msg = str(e)
                if "conflict" in error_msg.lower():
                    # Get conflicting files - keys are paths directly
                    conflicts = [
                        str(path) for path in repo.index.unmerged_blobs().keys()
                    ]
                    raise MergeConflictError(conflicts)
                raise PullError(pull_branch, error_msg)

        return await run_in_executor(_pull)

    async def fetch(
        self,
        remote: str = "origin",
        prune: bool = False,
    ) -> bool:
        """
        Fetch from remote.

        Args:
            remote: Remote name
            prune: Prune deleted remote branches

        Returns:
            True if successful
        """

        def _fetch() -> bool:
            repo = self.repo

            try:
                if remote not in [r.name for r in repo.remotes]:
                    raise GitError(f"Remote not found: {remote}")

                remote_obj = repo.remote(remote)
                remote_obj.fetch(prune=prune)
                return True

            except GitCommandError as e:
                logger.error(f"Fetch failed: {e}")
                raise GitError(f"Fetch failed: {e}")

        return await run_in_executor(_fetch)

    # Diff operations

    async def diff(
        self,
        base: str | None = None,
        head: str | None = None,
        files: list[str] | None = None,
        context_lines: int = 3,
    ) -> DiffResult:
        """
        Get diff between refs.

        Args:
            base: Base reference (None = working tree)
            head: Head reference (None = HEAD)
            files: Specific files to diff
            context_lines: Context lines to include

        Returns:
            DiffResult with diff info
        """

        def _diff() -> DiffResult:
            repo = self.repo
            file_diffs = []
            total_additions = 0
            total_deletions = 0

            try:
                # Determine what to diff; context_lines is passed through to
                # git diff as --unified=N
                if base is None and head is None:
                    # Working tree vs staged
                    diffs = repo.index.diff(
                        None, create_patch=True, unified=context_lines
                    )
                elif base is None:
                    # Working tree vs specified ref
                    diffs = repo.commit(head).diff(
                        None, create_patch=True, unified=context_lines
                    )
                elif head is None:
                    # Specified ref vs HEAD
                    diffs = repo.commit(base).diff(
                        "HEAD", create_patch=True, unified=context_lines
                    )
                else:
                    # Between two refs
                    diffs = repo.commit(base).diff(
                        head, create_patch=True, unified=context_lines
                    )

                for diff in diffs:
                    # Filter by files if specified
                    if files and diff.a_path not in files and diff.b_path not in files:
                        continue

                    change_type = self._diff_to_change_type(diff.change_type)
                    path = diff.b_path or diff.a_path or ""

                    # Parse hunks from patch
                    hunks = []
                    additions = 0
                    deletions = 0

                    if diff.diff:
                        # Handle both bytes and str
                        raw_diff = diff.diff
                        patch_text = (
                            raw_diff.decode("utf-8", errors="replace")
                            if isinstance(raw_diff, bytes)
                            else raw_diff
                        )

                        # Parse hunks (simplified)
                        for line in patch_text.split("\n"):
                            if line.startswith("+") and not line.startswith("+++"):
                                additions += 1
                            elif line.startswith("-") and not line.startswith("---"):
                                deletions += 1

                        # Add as single hunk for now
                        hunks.append(
                            DiffHunk(
                                old_start=1,
                                old_lines=deletions,
                                new_start=1,
                                new_lines=additions,
                                content=patch_text[: self.settings.git_max_diff_lines],
                            )
                        )

                    file_diffs.append(
                        FileDiff(
                            path=path,
                            change_type=change_type,
                            old_path=diff.a_path if diff.renamed else None,
                            hunks=hunks,
                            additions=additions,
                            deletions=deletions,
                            is_binary=diff.diff is None and not diff.deleted_file,
                        ).to_dict()
                    )

                    total_additions += additions
                    total_deletions += deletions

                return DiffResult(
                    project_id="",  # Set by caller
                    base=base,
                    head=head,
                    files=file_diffs,
                    total_additions=total_additions,
                    total_deletions=total_deletions,
                    files_changed=len(file_diffs),
                )

            except GitCommandError as e:
                raise GitError(f"Diff failed: {e}")

        return await run_in_executor(_diff)

    # Log operations

    async def log(
        self,
        ref: str | None = None,
        limit: int = 20,
        skip: int = 0,
        path: str | None = None,
    ) -> LogResult:
        """
        Get commit log.

        Args:
            ref: Reference to start from
            limit: Max commits to return
            skip: Commits to skip
            path: Filter by path

        Returns:
            LogResult with commit history
        """

        def _log() -> LogResult:
            repo = self.repo
            commits = []

            try:
                kwargs: dict[str, Any] = {
                    "max_count": limit,
                    "skip": skip,
                }
                if path:
                    kwargs["paths"] = path

                if ref:
                    iterator = repo.iter_commits(ref, **kwargs)
                else:
                    iterator = repo.iter_commits(**kwargs)

                for commit in iterator:
                    # Handle message that can be bytes
                    msg = commit.message
                    message_str = (
                        msg.decode("utf-8", errors="replace")
                        if isinstance(msg, bytes)
                        else msg
                    )
                    commits.append(
                        CommitInfo(
                            sha=commit.hexsha,
                            short_sha=commit.hexsha[:7],
                            message=message_str,
                            author_name=commit.author.name or "Unknown",
                            author_email=commit.author.email or "",
                            authored_date=datetime.fromtimestamp(
                                commit.authored_date, tz=UTC
                            ),
                            committer_name=commit.committer.name or "Unknown",
                            committer_email=commit.committer.email or "",
                            committed_date=datetime.fromtimestamp(
                                commit.committed_date, tz=UTC
                            ),
                            parents=[p.hexsha for p in commit.parents],
                        ).to_dict()
                    )

                return LogResult(
                    project_id="",  # Set by caller
                    commits=commits,
                    total_commits=len(commits),
                )

            except GitCommandError as e:
                raise GitError(f"Log failed: {e}")

        return await run_in_executor(_log)

    # Reset operations

    async def reset(
        self,
        ref: str = "HEAD",
        mode: str = "mixed",
        files: list[str] | None = None,
    ) -> bool:
        """
        Reset to a ref.

        Args:
            ref: Reference to reset to
            mode: Reset mode (soft, mixed, hard)
            files: Specific files to reset

        Returns:
            True if successful
        """

        def _reset() -> bool:
            repo = self.repo

            try:
                if files:
                    # Reset specific files
                    repo.index.reset(commit=ref, paths=files)
                else:
                    # Full reset
                    if mode == "soft":
                        repo.head.reset(ref, index=False, working_tree=False)
                    elif mode == "mixed":
                        repo.head.reset(ref, index=True, working_tree=False)
                    elif mode == "hard":
                        repo.head.reset(ref, index=True, working_tree=True)
                    else:
                        raise GitError(f"Invalid reset mode: {mode}")

                return True

            except GitCommandError as e:
                raise GitError(f"Reset failed: {e}")

        return await run_in_executor(_reset)

    # Stash operations

    async def stash(self, message: str | None = None) -> str | None:
        """
        Stash changes.

        Args:
            message: Optional stash message

        Returns:
            Stash reference or None if nothing to stash
        """

        def _stash() -> str | None:
            repo = self.repo

            try:
                if message:
                    result = repo.git.stash("push", "-m", message)
                else:
                    result = repo.git.stash("push")

                if "No local changes to save" in result:
                    return None

                return repo.git.stash("list").split("\n")[0].split(":")[0]

            except GitCommandError as e:
                raise GitError(f"Stash failed: {e}")

        return await run_in_executor(_stash)

    async def stash_pop(self, stash_ref: str | None = None) -> bool:
        """
        Pop stashed changes.

        Args:
            stash_ref: Specific stash to pop

        Returns:
            True if successful
        """

        def _stash_pop() -> bool:
            repo = self.repo

            try:
                if stash_ref:
                    repo.git.stash("pop", stash_ref)
                else:
                    repo.git.stash("pop")
                return True

            except GitCommandError as e:
                if "conflict" in str(e).lower():
                    raise MergeConflictError([])
                raise GitError(f"Stash pop failed: {e}")

        return await run_in_executor(_stash_pop)

    # Utility methods

    def _diff_to_change_type(self, change_type: str | None) -> FileChangeType:
        """Convert GitPython change type to our enum."""
        if change_type is None:
            return FileChangeType.MODIFIED
        mapping = {
            "A": FileChangeType.ADDED,
            "M": FileChangeType.MODIFIED,
            "D": FileChangeType.DELETED,
            "R": FileChangeType.RENAMED,
            "C": FileChangeType.COPIED,
        }
        return mapping.get(change_type, FileChangeType.MODIFIED)

    async def is_valid_ref(self, ref: str) -> bool:
        """Check if a reference is valid."""

        def _check() -> bool:
            try:
                self.repo.commit(ref)
                return True
            except Exception:
                return False

        return await run_in_executor(_check)

    async def get_remote_url(self, remote: str = "origin") -> str | None:
        """Get the URL for a remote."""

        def _get_url() -> str | None:
            repo = self.repo
            if remote in [r.name for r in repo.remotes]:
                return repo.remote(remote).url
            return None

        return await run_in_executor(_get_url)

    async def set_config(self, key: str, value: str, global_: bool = False) -> None:
        """Set a git config value."""

        def _set_config() -> None:
            repo = self.repo
            with repo.config_writer("global" if global_ else "repository") as cw:
                section, option = key.rsplit(".", 1)
                cw.set_value(section, option, value)

        await run_in_executor(_set_config)

    async def get_config(self, key: str) -> str | None:
        """Get a git config value."""

        def _get_config() -> str | None:
            repo = self.repo
            try:
                cr = repo.config_reader()
                section, option = key.rsplit(".", 1)
                value = cr.get_value(section, option)
                return str(value) if value is not None else None
            except Exception:
                return None

        return await run_in_executor(_get_config)
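

# ---------------------------------------------------------------------------
# Example usage — a minimal, illustrative sketch only. The workspace path and
# repository URL below are placeholders, and the printed attributes (branch,
# commit_sha, is_clean) are the fields this module populates on CloneResult
# and StatusResult; the real service sets project_id on the results itself.
# ---------------------------------------------------------------------------
async def _example() -> None:
    wrapper = GitWrapper("/tmp/git-wrapper-demo")  # placeholder workspace

    clone = await wrapper.clone(
        "https://github.com/example/repo.git",  # placeholder repository URL
        depth=1,
    )
    print(f"Cloned branch {clone.branch} at {clone.commit_sha[:7]}")

    status = await wrapper.status()
    print(f"Working tree clean: {status.is_clean}")


if __name__ == "__main__":
    asyncio.run(_example())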