**feat(git-ops): enhance MCP server with Git provider updates and SSRF protection**

- Added `mcp-git-ops` service to `docker-compose.dev.yml` with health checks and configurations.
- Integrated SSRF protection in repository URL validation for enhanced security.
- Expanded `pyproject.toml` mypy settings and adjusted code to meet stricter type checking.
- Improved workspace management and GitWrapper operations with error handling refinements.
- Updated input validation, branching, and repository operations to align with new error structure.
- Shut down thread pool executor gracefully during server cleanup.
This commit is contained in:
2026-01-07 09:17:00 +01:00
parent 1779239c07
commit 76d7de5334
11 changed files with 781 additions and 181 deletions

View File

@@ -47,6 +47,7 @@ help:
@echo " cd backend && make help - Backend-specific commands" @echo " cd backend && make help - Backend-specific commands"
@echo " cd mcp-servers/llm-gateway && make - LLM Gateway commands" @echo " cd mcp-servers/llm-gateway && make - LLM Gateway commands"
@echo " cd mcp-servers/knowledge-base && make - Knowledge Base commands" @echo " cd mcp-servers/knowledge-base && make - Knowledge Base commands"
@echo " cd mcp-servers/git-ops && make - Git Operations commands"
@echo " cd frontend && npm run - Frontend-specific commands" @echo " cd frontend && npm run - Frontend-specific commands"
# ============================================================================ # ============================================================================
@@ -138,6 +139,9 @@ test-mcp:
@echo "" @echo ""
@echo "=== Knowledge Base ===" @echo "=== Knowledge Base ==="
@cd mcp-servers/knowledge-base && uv run pytest tests/ -v @cd mcp-servers/knowledge-base && uv run pytest tests/ -v
@echo ""
@echo "=== Git Operations ==="
@cd mcp-servers/git-ops && IS_TEST=True uv run pytest tests/ -v
test-frontend: test-frontend:
@echo "Running frontend tests..." @echo "Running frontend tests..."
@@ -158,6 +162,9 @@ test-cov:
@echo "" @echo ""
@echo "=== Knowledge Base Coverage ===" @echo "=== Knowledge Base Coverage ==="
@cd mcp-servers/knowledge-base && uv run pytest tests/ -v --cov=. --cov-report=term-missing @cd mcp-servers/knowledge-base && uv run pytest tests/ -v --cov=. --cov-report=term-missing
@echo ""
@echo "=== Git Operations Coverage ==="
@cd mcp-servers/git-ops && IS_TEST=True uv run pytest tests/ -v --cov=. --cov-report=term-missing
test-integration: test-integration:
@echo "Running MCP integration tests..." @echo "Running MCP integration tests..."
@@ -178,6 +185,9 @@ format-all:
@echo "Formatting Knowledge Base..." @echo "Formatting Knowledge Base..."
@cd mcp-servers/knowledge-base && make format @cd mcp-servers/knowledge-base && make format
@echo "" @echo ""
@echo "Formatting Git Operations..."
@cd mcp-servers/git-ops && make format
@echo ""
@echo "Formatting frontend..." @echo "Formatting frontend..."
@cd frontend && npm run format @cd frontend && npm run format
@echo "" @echo ""
@@ -197,6 +207,9 @@ validate:
@echo "Validating Knowledge Base..." @echo "Validating Knowledge Base..."
@cd mcp-servers/knowledge-base && make validate @cd mcp-servers/knowledge-base && make validate
@echo "" @echo ""
@echo "Validating Git Operations..."
@cd mcp-servers/git-ops && make validate
@echo ""
@echo "All validations passed!" @echo "All validations passed!"
validate-all: validate validate-all: validate

View File

@@ -96,6 +96,38 @@ services:
- app-network - app-network
restart: unless-stopped restart: unless-stopped
mcp-git-ops:
build:
context: ./mcp-servers/git-ops
dockerfile: Dockerfile
ports:
- "8003:8003"
env_file:
- .env
environment:
# GIT_OPS_ prefix required by pydantic-settings config
- GIT_OPS_HOST=0.0.0.0
- GIT_OPS_PORT=8003
- GIT_OPS_REDIS_URL=redis://redis:6379/3
- GIT_OPS_GITEA_BASE_URL=${GITEA_BASE_URL}
- GIT_OPS_GITEA_TOKEN=${GITEA_TOKEN}
- GIT_OPS_GITHUB_TOKEN=${GITHUB_TOKEN}
- ENVIRONMENT=development
volumes:
- git_workspaces_dev:/workspaces
depends_on:
redis:
condition: service_healthy
healthcheck:
test: ["CMD", "python", "-c", "import httpx; httpx.get('http://localhost:8003/health').raise_for_status()"]
interval: 30s
timeout: 10s
retries: 3
start_period: 10s
networks:
- app-network
restart: unless-stopped
backend: backend:
build: build:
context: ./backend context: ./backend
@@ -119,6 +151,7 @@ services:
# MCP Server URLs # MCP Server URLs
- LLM_GATEWAY_URL=http://mcp-llm-gateway:8001 - LLM_GATEWAY_URL=http://mcp-llm-gateway:8001
- KNOWLEDGE_BASE_URL=http://mcp-knowledge-base:8002 - KNOWLEDGE_BASE_URL=http://mcp-knowledge-base:8002
- GIT_OPS_URL=http://mcp-git-ops:8003
depends_on: depends_on:
db: db:
condition: service_healthy condition: service_healthy
@@ -128,6 +161,8 @@ services:
condition: service_healthy condition: service_healthy
mcp-knowledge-base: mcp-knowledge-base:
condition: service_healthy condition: service_healthy
mcp-git-ops:
condition: service_healthy
healthcheck: healthcheck:
test: ["CMD", "curl", "-f", "http://localhost:8000/health"] test: ["CMD", "curl", "-f", "http://localhost:8000/health"]
interval: 10s interval: 10s
@@ -155,6 +190,7 @@ services:
# MCP Server URLs (agents need access to MCP) # MCP Server URLs (agents need access to MCP)
- LLM_GATEWAY_URL=http://mcp-llm-gateway:8001 - LLM_GATEWAY_URL=http://mcp-llm-gateway:8001
- KNOWLEDGE_BASE_URL=http://mcp-knowledge-base:8002 - KNOWLEDGE_BASE_URL=http://mcp-knowledge-base:8002
- GIT_OPS_URL=http://mcp-git-ops:8003
depends_on: depends_on:
db: db:
condition: service_healthy condition: service_healthy
@@ -164,6 +200,8 @@ services:
condition: service_healthy condition: service_healthy
mcp-knowledge-base: mcp-knowledge-base:
condition: service_healthy condition: service_healthy
mcp-git-ops:
condition: service_healthy
networks: networks:
- app-network - app-network
command: ["celery", "-A", "app.celery_app", "worker", "-Q", "agent", "-l", "info", "-c", "4"] command: ["celery", "-A", "app.celery_app", "worker", "-Q", "agent", "-l", "info", "-c", "4"]
@@ -181,11 +219,14 @@ services:
- DATABASE_URL=${DATABASE_URL} - DATABASE_URL=${DATABASE_URL}
- REDIS_URL=redis://redis:6379/0 - REDIS_URL=redis://redis:6379/0
- CELERY_QUEUE=git - CELERY_QUEUE=git
- GIT_OPS_URL=http://mcp-git-ops:8003
depends_on: depends_on:
db: db:
condition: service_healthy condition: service_healthy
redis: redis:
condition: service_healthy condition: service_healthy
mcp-git-ops:
condition: service_healthy
networks: networks:
- app-network - app-network
command: ["celery", "-A", "app.celery_app", "worker", "-Q", "git", "-l", "info", "-c", "2"] command: ["celery", "-A", "app.celery_app", "worker", "-Q", "git", "-l", "info", "-c", "2"]
@@ -260,6 +301,7 @@ services:
volumes: volumes:
postgres_data_dev: postgres_data_dev:
redis_data_dev: redis_data_dev:
git_workspaces_dev:
frontend_dev_modules: frontend_dev_modules:
frontend_dev_next: frontend_dev_next:

View File

@@ -0,0 +1,88 @@
.PHONY: help install install-dev lint lint-fix format format-check type-check test test-cov validate clean run
# Ensure commands in this project don't inherit an external Python virtualenv
# (prevents uv warnings about mismatched VIRTUAL_ENV when running from repo root)
unexport VIRTUAL_ENV
# Default target
help:
@echo "Git Operations MCP Server - Development Commands"
@echo ""
@echo "Setup:"
@echo " make install - Install production dependencies"
@echo " make install-dev - Install development dependencies"
@echo ""
@echo "Quality Checks:"
@echo " make lint - Run Ruff linter"
@echo " make lint-fix - Run Ruff linter with auto-fix"
@echo " make format - Format code with Ruff"
@echo " make format-check - Check if code is formatted"
@echo " make type-check - Run mypy type checker"
@echo ""
@echo "Testing:"
@echo " make test - Run pytest"
@echo " make test-cov - Run pytest with coverage"
@echo ""
@echo "All-in-one:"
@echo " make validate - Run all checks (lint + format + types)"
@echo ""
@echo "Running:"
@echo " make run - Run the server locally"
@echo ""
@echo "Cleanup:"
@echo " make clean - Remove cache and build artifacts"
# Setup
install:
@echo "Installing production dependencies..."
@uv pip install -e .
install-dev:
@echo "Installing development dependencies..."
@uv pip install -e ".[dev]"
# Quality checks
lint:
@echo "Running Ruff linter..."
@uv run ruff check .
lint-fix:
@echo "Running Ruff linter with auto-fix..."
@uv run ruff check --fix .
format:
@echo "Formatting code..."
@uv run ruff format .
format-check:
@echo "Checking code formatting..."
@uv run ruff format --check .
type-check:
@echo "Running mypy..."
@uv run python -m mypy server.py config.py models.py exceptions.py git_wrapper.py workspace.py providers/ --explicit-package-bases
# Testing
test:
@echo "Running tests..."
@IS_TEST=True uv run pytest tests/ -v
test-cov:
@echo "Running tests with coverage..."
@IS_TEST=True uv run pytest tests/ -v --cov=. --cov-report=term-missing --cov-report=html
# All-in-one validation
validate: lint format-check type-check
@echo "All validations passed!"
# Running
run:
@echo "Starting Git Operations server..."
@uv run python server.py
# Cleanup
clean:
@echo "Cleaning up..."
@rm -rf __pycache__ .pytest_cache .mypy_cache .ruff_cache .coverage htmlcov
@find . -type d -name "__pycache__" -exec rm -rf {} + 2>/dev/null || true
@find . -type f -name "*.pyc" -delete 2>/dev/null || true

View File

@@ -73,7 +73,7 @@ class GitOpsError(Exception):
def to_dict(self) -> dict[str, Any]: def to_dict(self) -> dict[str, Any]:
"""Convert to dictionary for MCP response.""" """Convert to dictionary for MCP response."""
result = { result: dict[str, Any] = {
"error": self.message, "error": self.message,
"code": self.code.value, "code": self.code.value,
} }
@@ -325,9 +325,7 @@ class PRNotFoundError(PRError):
class APIError(ProviderError): class APIError(ProviderError):
"""Provider API error.""" """Provider API error."""
def __init__( def __init__(self, provider: str, status_code: int, message: str) -> None:
self, provider: str, status_code: int, message: str
) -> None:
super().__init__( super().__init__(
f"{provider} API error ({status_code}): {message}", f"{provider} API error ({status_code}): {message}",
ErrorCode.API_ERROR, ErrorCode.API_ERROR,

View File

@@ -52,6 +52,21 @@ from models import (
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
def sanitize_url_for_logging(url: str) -> str:
"""
Remove any credentials from a URL before logging.
Handles URLs like:
- https://token@github.com/owner/repo.git
- https://user:password@github.com/owner/repo.git
- git@github.com:owner/repo.git (unchanged, no credentials)
"""
# Pattern to match https://[credentials@]host/path
sanitized = re.sub(r"(https?://)([^@]+@)", r"\1***@", url)
return sanitized
# Thread pool for blocking git operations # Thread pool for blocking git operations
_executor: ThreadPoolExecutor | None = None _executor: ThreadPoolExecutor | None = None
@@ -81,7 +96,7 @@ class GitWrapper:
def __init__( def __init__(
self, self,
workspace_path: Path, workspace_path: Path | str,
settings: Settings | None = None, settings: Settings | None = None,
) -> None: ) -> None:
""" """
@@ -91,7 +106,9 @@ class GitWrapper:
workspace_path: Path to the git workspace workspace_path: Path to the git workspace
settings: Optional settings override settings: Optional settings override
""" """
self.workspace_path = workspace_path self.workspace_path = (
Path(workspace_path) if isinstance(workspace_path, str) else workspace_path
)
self.settings = settings or get_settings() self.settings = settings or get_settings()
self._repo: GitRepo | None = None self._repo: GitRepo | None = None
@@ -175,8 +192,10 @@ class GitWrapper:
) )
except GitCommandError as e: except GitCommandError as e:
logger.error(f"Clone failed: {e}") # Sanitize URLs in error messages to prevent credential leakage
raise CloneError(repo_url, str(e)) error_msg = sanitize_url_for_logging(str(e))
logger.error(f"Clone failed: {error_msg}")
raise CloneError(sanitize_url_for_logging(repo_url), error_msg)
return await run_in_executor(_do_clone) return await run_in_executor(_do_clone)
@@ -200,9 +219,10 @@ class GitWrapper:
staged = [] staged = []
for diff in repo.index.diff("HEAD"): for diff in repo.index.diff("HEAD"):
change_type = self._diff_to_change_type(diff.change_type) change_type = self._diff_to_change_type(diff.change_type)
path = diff.b_path or diff.a_path or ""
staged.append( staged.append(
FileChange( FileChange(
path=diff.b_path or diff.a_path, path=path,
change_type=change_type, change_type=change_type,
old_path=diff.a_path if diff.renamed else None, old_path=diff.a_path if diff.renamed else None,
).to_dict() ).to_dict()
@@ -212,9 +232,10 @@ class GitWrapper:
unstaged = [] unstaged = []
for diff in repo.index.diff(None): for diff in repo.index.diff(None):
change_type = self._diff_to_change_type(diff.change_type) change_type = self._diff_to_change_type(diff.change_type)
path = diff.b_path or diff.a_path or ""
unstaged.append( unstaged.append(
FileChange( FileChange(
path=diff.b_path or diff.a_path, path=path,
change_type=change_type, change_type=change_type,
).to_dict() ).to_dict()
) )
@@ -228,10 +249,18 @@ class GitWrapper:
tracking = repo.active_branch.tracking_branch() tracking = repo.active_branch.tracking_branch()
if tracking: if tracking:
ahead = len( ahead = len(
list(repo.iter_commits(f"{tracking.name}..{repo.active_branch.name}")) list(
repo.iter_commits(
f"{tracking.name}..{repo.active_branch.name}"
)
)
) )
behind = len( behind = len(
list(repo.iter_commits(f"{repo.active_branch.name}..{tracking.name}")) list(
repo.iter_commits(
f"{repo.active_branch.name}..{tracking.name}"
)
)
) )
except Exception: except Exception:
pass # No tracking branch pass # No tracking branch
@@ -361,6 +390,12 @@ class GitWrapper:
local_branches = [] local_branches = []
for branch in repo.branches: for branch in repo.branches:
tracking = branch.tracking_branch() tracking = branch.tracking_branch()
msg = branch.commit.message
commit_msg = (
msg.decode("utf-8", errors="replace")
if isinstance(msg, bytes)
else msg
).split("\n")[0]
local_branches.append( local_branches.append(
BranchInfo( BranchInfo(
name=branch.name, name=branch.name,
@@ -368,7 +403,7 @@ class GitWrapper:
is_remote=False, is_remote=False,
tracking_branch=tracking.name if tracking else None, tracking_branch=tracking.name if tracking else None,
commit_sha=branch.commit.hexsha, commit_sha=branch.commit.hexsha,
commit_message=branch.commit.message.split("\n")[0], commit_message=commit_msg,
).to_dict() ).to_dict()
) )
@@ -379,13 +414,19 @@ class GitWrapper:
# Skip HEAD refs # Skip HEAD refs
if ref.name.endswith("/HEAD"): if ref.name.endswith("/HEAD"):
continue continue
msg = ref.commit.message
commit_msg = (
msg.decode("utf-8", errors="replace")
if isinstance(msg, bytes)
else msg
).split("\n")[0]
remote_branches.append( remote_branches.append(
BranchInfo( BranchInfo(
name=ref.name, name=ref.name,
is_current=False, is_current=False,
is_remote=True, is_remote=True,
commit_sha=ref.commit.hexsha, commit_sha=ref.commit.hexsha,
commit_message=ref.commit.message.split("\n")[0], commit_message=commit_msg,
).to_dict() ).to_dict()
) )
@@ -485,7 +526,11 @@ class GitWrapper:
repo.git.add("-A") repo.git.add("-A")
# Check if there's anything to commit # Check if there's anything to commit
if not allow_empty and not repo.index.diff("HEAD") and not repo.untracked_files: if (
not allow_empty
and not repo.index.diff("HEAD")
and not repo.untracked_files
):
raise CommitError("Nothing to commit") raise CommitError("Nothing to commit")
# Build author # Build author
@@ -613,7 +658,7 @@ class GitWrapper:
try: try:
# Build push info # Build push info
push_info_list = [] push_info_list: list[Any] = []
if remote not in [r.name for r in repo.remotes]: if remote not in [r.name for r in repo.remotes]:
raise PushError(push_branch, f"Remote not found: {remote}") raise PushError(push_branch, f"Remote not found: {remote}")
@@ -716,9 +761,7 @@ class GitWrapper:
repo.git.pull(remote, pull_branch) repo.git.pull(remote, pull_branch)
# Count new commits # Count new commits
commits_received = len( commits_received = len(list(repo.iter_commits(f"{head_before}..HEAD")))
list(repo.iter_commits(f"{head_before}..HEAD"))
)
# Check if fast-forward # Check if fast-forward
fast_forward = commits_received > 0 and not repo.head.commit.parents fast_forward = commits_received > 0 and not repo.head.commit.parents
@@ -733,10 +776,9 @@ class GitWrapper:
except GitCommandError as e: except GitCommandError as e:
error_msg = str(e) error_msg = str(e)
if "conflict" in error_msg.lower(): if "conflict" in error_msg.lower():
# Get conflicting files # Get conflicting files - keys are paths directly
conflicts = [ conflicts = [
item.a_path str(path) for path in repo.index.unmerged_blobs().keys()
for item in repo.index.unmerged_blobs().keys()
] ]
raise MergeConflictError(conflicts) raise MergeConflictError(conflicts)
raise PullError(pull_branch, error_msg) raise PullError(pull_branch, error_msg)
@@ -823,7 +865,7 @@ class GitWrapper:
continue continue
change_type = self._diff_to_change_type(diff.change_type) change_type = self._diff_to_change_type(diff.change_type)
path = diff.b_path or diff.a_path path = diff.b_path or diff.a_path or ""
# Parse hunks from patch # Parse hunks from patch
hunks = [] hunks = []
@@ -831,7 +873,13 @@ class GitWrapper:
deletions = 0 deletions = 0
if diff.diff: if diff.diff:
patch_text = diff.diff.decode("utf-8", errors="replace") # Handle both bytes and str
raw_diff = diff.diff
patch_text = (
raw_diff.decode("utf-8", errors="replace")
if isinstance(raw_diff, bytes)
else raw_diff
)
# Parse hunks (simplified) # Parse hunks (simplified)
for line in patch_text.split("\n"): for line in patch_text.split("\n"):
if line.startswith("+") and not line.startswith("+++"): if line.startswith("+") and not line.startswith("+++"):
@@ -921,18 +969,25 @@ class GitWrapper:
iterator = repo.iter_commits(**kwargs) iterator = repo.iter_commits(**kwargs)
for commit in iterator: for commit in iterator:
# Handle message that can be bytes
msg = commit.message
message_str = (
msg.decode("utf-8", errors="replace")
if isinstance(msg, bytes)
else msg
)
commits.append( commits.append(
CommitInfo( CommitInfo(
sha=commit.hexsha, sha=commit.hexsha,
short_sha=commit.hexsha[:7], short_sha=commit.hexsha[:7],
message=commit.message, message=message_str,
author_name=commit.author.name, author_name=commit.author.name or "Unknown",
author_email=commit.author.email, author_email=commit.author.email or "",
authored_date=datetime.fromtimestamp( authored_date=datetime.fromtimestamp(
commit.authored_date, tz=UTC commit.authored_date, tz=UTC
), ),
committer_name=commit.committer.name, committer_name=commit.committer.name or "Unknown",
committer_email=commit.committer.email, committer_email=commit.committer.email or "",
committed_date=datetime.fromtimestamp( committed_date=datetime.fromtimestamp(
commit.committed_date, tz=UTC commit.committed_date, tz=UTC
), ),
@@ -1052,8 +1107,10 @@ class GitWrapper:
# Utility methods # Utility methods
def _diff_to_change_type(self, change_type: str) -> FileChangeType: def _diff_to_change_type(self, change_type: str | None) -> FileChangeType:
"""Convert GitPython change type to our enum.""" """Convert GitPython change type to our enum."""
if change_type is None:
return FileChangeType.MODIFIED
mapping = { mapping = {
"A": FileChangeType.ADDED, "A": FileChangeType.ADDED,
"M": FileChangeType.MODIFIED, "M": FileChangeType.MODIFIED,
@@ -1105,7 +1162,8 @@ class GitWrapper:
try: try:
cr = repo.config_reader() cr = repo.config_reader()
section, option = key.rsplit(".", 1) section, option = key.rsplit(".", 1)
return cr.get_value(section, option) value = cr.get_value(section, option)
return str(value) if value is not None else None
except Exception: except Exception:
return None return None

View File

@@ -257,7 +257,9 @@ class WorkspaceInfo:
"last_accessed": self.last_accessed.isoformat(), "last_accessed": self.last_accessed.isoformat(),
"size_bytes": self.size_bytes, "size_bytes": self.size_bytes,
"lock_holder": self.lock_holder, "lock_holder": self.lock_holder,
"lock_expires": self.lock_expires.isoformat() if self.lock_expires else None, "lock_expires": self.lock_expires.isoformat()
if self.lock_expires
else None,
} }
@@ -270,7 +272,9 @@ class CloneRequest(BaseModel):
project_id: str = Field(..., description="Project ID for scoping") project_id: str = Field(..., description="Project ID for scoping")
agent_id: str = Field(..., description="Agent ID making the request") agent_id: str = Field(..., description="Agent ID making the request")
repo_url: str = Field(..., description="Repository URL to clone") repo_url: str = Field(..., description="Repository URL to clone")
branch: str | None = Field(default=None, description="Branch to checkout after clone") branch: str | None = Field(
default=None, description="Branch to checkout after clone"
)
depth: int | None = Field( depth: int | None = Field(
default=None, ge=1, description="Shallow clone depth (None = full clone)" default=None, ge=1, description="Shallow clone depth (None = full clone)"
) )
@@ -407,7 +411,9 @@ class PushRequest(BaseModel):
project_id: str = Field(..., description="Project ID for scoping") project_id: str = Field(..., description="Project ID for scoping")
agent_id: str = Field(..., description="Agent ID making the request") agent_id: str = Field(..., description="Agent ID making the request")
branch: str | None = Field(default=None, description="Branch to push (None = current)") branch: str | None = Field(
default=None, description="Branch to push (None = current)"
)
remote: str = Field(default="origin", description="Remote name") remote: str = Field(default="origin", description="Remote name")
force: bool = Field(default=False, description="Force push") force: bool = Field(default=False, description="Force push")
set_upstream: bool = Field(default=True, description="Set upstream tracking") set_upstream: bool = Field(default=True, description="Set upstream tracking")
@@ -428,7 +434,9 @@ class PullRequest(BaseModel):
project_id: str = Field(..., description="Project ID for scoping") project_id: str = Field(..., description="Project ID for scoping")
agent_id: str = Field(..., description="Agent ID making the request") agent_id: str = Field(..., description="Agent ID making the request")
branch: str | None = Field(default=None, description="Branch to pull (None = current)") branch: str | None = Field(
default=None, description="Branch to pull (None = current)"
)
remote: str = Field(default="origin", description="Remote name") remote: str = Field(default="origin", description="Remote name")
rebase: bool = Field(default=False, description="Rebase instead of merge") rebase: bool = Field(default=False, description="Rebase instead of merge")
@@ -451,7 +459,9 @@ class DiffRequest(BaseModel):
project_id: str = Field(..., description="Project ID for scoping") project_id: str = Field(..., description="Project ID for scoping")
agent_id: str = Field(..., description="Agent ID making the request") agent_id: str = Field(..., description="Agent ID making the request")
base: str | None = Field(default=None, description="Base reference (None = working tree)") base: str | None = Field(
default=None, description="Base reference (None = working tree)"
)
head: str | None = Field(default=None, description="Head reference (None = HEAD)") head: str | None = Field(default=None, description="Head reference (None = HEAD)")
files: list[str] | None = Field(default=None, description="Specific files to diff") files: list[str] | None = Field(default=None, description="Specific files to diff")
context_lines: int = Field(default=3, ge=0, description="Context lines") context_lines: int = Field(default=3, ge=0, description="Context lines")
@@ -463,9 +473,7 @@ class DiffResult(BaseModel):
project_id: str = Field(..., description="Project ID") project_id: str = Field(..., description="Project ID")
base: str | None = Field(default=None, description="Base reference") base: str | None = Field(default=None, description="Base reference")
head: str | None = Field(default=None, description="Head reference") head: str | None = Field(default=None, description="Head reference")
files: list[dict[str, Any]] = Field( files: list[dict[str, Any]] = Field(default_factory=list, description="File diffs")
default_factory=list, description="File diffs"
)
total_additions: int = Field(default=0, description="Total lines added") total_additions: int = Field(default=0, description="Total lines added")
total_deletions: int = Field(default=0, description="Total lines removed") total_deletions: int = Field(default=0, description="Total lines removed")
files_changed: int = Field(default=0, description="Number of files changed") files_changed: int = Field(default=0, description="Number of files changed")
@@ -549,9 +557,7 @@ class ListPRsResult(BaseModel):
"""Result of listing pull requests.""" """Result of listing pull requests."""
success: bool = Field(..., description="Whether list succeeded") success: bool = Field(..., description="Whether list succeeded")
pull_requests: list[dict[str, Any]] = Field( pull_requests: list[dict[str, Any]] = Field(default_factory=list, description="PRs")
default_factory=list, description="PRs"
)
total_count: int = Field(default=0, description="Total matching PRs") total_count: int = Field(default=0, description="Total matching PRs")
error: str | None = Field(default=None, description="Error message if failed") error: str | None = Field(default=None, description="Error message if failed")
@@ -578,7 +584,9 @@ class MergePRResult(BaseModel):
success: bool = Field(..., description="Whether merge succeeded") success: bool = Field(..., description="Whether merge succeeded")
merge_commit_sha: str | None = Field(default=None, description="Merge commit SHA") merge_commit_sha: str | None = Field(default=None, description="Merge commit SHA")
branch_deleted: bool = Field(default=False, description="Whether branch was deleted") branch_deleted: bool = Field(
default=False, description="Whether branch was deleted"
)
error: str | None = Field(default=None, description="Error message if failed") error: str | None = Field(default=None, description="Error message if failed")
@@ -626,7 +634,9 @@ class LockWorkspaceRequest(BaseModel):
project_id: str = Field(..., description="Project ID") project_id: str = Field(..., description="Project ID")
agent_id: str = Field(..., description="Agent ID requesting lock") agent_id: str = Field(..., description="Agent ID requesting lock")
timeout: int = Field(default=300, ge=10, le=3600, description="Lock timeout seconds") timeout: int = Field(
default=300, ge=10, le=3600, description="Lock timeout seconds"
)
class LockWorkspaceResult(BaseModel): class LockWorkspaceResult(BaseModel):
@@ -634,7 +644,9 @@ class LockWorkspaceResult(BaseModel):
success: bool = Field(..., description="Whether lock acquired") success: bool = Field(..., description="Whether lock acquired")
lock_holder: str | None = Field(default=None, description="Current lock holder") lock_holder: str | None = Field(default=None, description="Current lock holder")
lock_expires: str | None = Field(default=None, description="Lock expiry ISO timestamp") lock_expires: str | None = Field(
default=None, description="Lock expiry ISO timestamp"
)
error: str | None = Field(default=None, description="Error message if failed") error: str | None = Field(default=None, description="Error message if failed")

View File

@@ -44,9 +44,7 @@ class BaseProvider(ABC):
# Repository operations # Repository operations
@abstractmethod @abstractmethod
async def get_repo_info( async def get_repo_info(self, owner: str, repo: str) -> dict[str, Any]:
self, owner: str, repo: str
) -> dict[str, Any]:
""" """
Get repository information. Get repository information.
@@ -60,9 +58,7 @@ class BaseProvider(ABC):
... ...
@abstractmethod @abstractmethod
async def get_default_branch( async def get_default_branch(self, owner: str, repo: str) -> str:
self, owner: str, repo: str
) -> str:
""" """
Get the default branch for a repository. Get the default branch for a repository.
@@ -112,9 +108,7 @@ class BaseProvider(ABC):
... ...
@abstractmethod @abstractmethod
async def get_pr( async def get_pr(self, owner: str, repo: str, pr_number: int) -> GetPRResult:
self, owner: str, repo: str, pr_number: int
) -> GetPRResult:
""" """
Get a pull request by number. Get a pull request by number.
@@ -209,9 +203,7 @@ class BaseProvider(ABC):
... ...
@abstractmethod @abstractmethod
async def close_pr( async def close_pr(self, owner: str, repo: str, pr_number: int) -> UpdatePRResult:
self, owner: str, repo: str, pr_number: int
) -> UpdatePRResult:
""" """
Close a pull request without merging. Close a pull request without merging.
@@ -228,9 +220,7 @@ class BaseProvider(ABC):
# Branch operations via API (for operations that need to bypass local git) # Branch operations via API (for operations that need to bypass local git)
@abstractmethod @abstractmethod
async def delete_remote_branch( async def delete_remote_branch(self, owner: str, repo: str, branch: str) -> bool:
self, owner: str, repo: str, branch: str
) -> bool:
""" """
Delete a remote branch via API. Delete a remote branch via API.
@@ -379,9 +369,7 @@ class BaseProvider(ABC):
return ssh_match.group(1), ssh_match.group(2) return ssh_match.group(1), ssh_match.group(2)
# Handle HTTPS URLs: https://host/owner/repo.git # Handle HTTPS URLs: https://host/owner/repo.git
https_match = re.match( https_match = re.match(r"https?://[^/]+/([^/]+)/([^/]+?)(?:\.git)?$", repo_url)
r"https?://[^/]+/([^/]+)/([^/]+?)(?:\.git)?$", repo_url
)
if https_match: if https_match:
return https_match.group(1), https_match.group(2) return https_match.group(1), https_match.group(2)

View File

@@ -116,9 +116,7 @@ class GitHubProvider(BaseProvider):
if response.status_code == 403: if response.status_code == 403:
# Check for rate limiting # Check for rate limiting
if "rate limit" in response.text.lower(): if "rate limit" in response.text.lower():
raise APIError( raise APIError("github", 403, "GitHub API rate limit exceeded")
"github", 403, "GitHub API rate limit exceeded"
)
raise AuthenticationError( raise AuthenticationError(
"github", "Insufficient permissions for this operation" "github", "Insufficient permissions for this operation"
) )

View File

@@ -101,7 +101,7 @@ exclude_lines = [
"if TYPE_CHECKING:", "if TYPE_CHECKING:",
"if __name__ == .__main__.:", "if __name__ == .__main__.:",
] ]
fail_under = 65 # TODO: Increase to 80% once more tool tests are added fail_under = 78
show_missing = true show_missing = true
[tool.mypy] [tool.mypy]
@@ -111,6 +111,8 @@ warn_unused_ignores = false
disallow_untyped_defs = true disallow_untyped_defs = true
ignore_missing_imports = true ignore_missing_imports = true
plugins = ["pydantic.mypy"] plugins = ["pydantic.mypy"]
files = ["server.py", "config.py", "models.py", "exceptions.py", "git_wrapper.py", "workspace.py", "providers/"]
exclude = ["tests/"]
[[tool.mypy.overrides]] [[tool.mypy.overrides]]
module = "tests.*" module = "tests.*"

View File

@@ -56,13 +56,56 @@ def _validate_branch(value: str) -> str | None:
def _validate_url(value: str) -> str | None: def _validate_url(value: str) -> str | None:
"""Validate repository URL format.""" """
Validate repository URL format with SSRF protection.
Validates the URL format and optionally checks against allowed hosts.
"""
if not isinstance(value, str): if not isinstance(value, str):
return "Repository URL must be a string" return "Repository URL must be a string"
if not value: if not value:
return "Repository URL is required" return "Repository URL is required"
if not URL_PATTERN.match(value): if not URL_PATTERN.match(value):
return "Invalid repository URL: must be a valid HTTPS or SSH git URL" return "Invalid repository URL: must be a valid HTTPS or SSH git URL"
# SSRF protection: check allowed hosts if configured
if _settings and _settings.allowed_hosts:
from urllib.parse import urlparse
parsed = urlparse(value)
hostname = parsed.hostname
# Block localhost and loopback addresses
blocked_hosts = {
"localhost",
"127.0.0.1",
"::1",
"0.0.0.0",
"169.254.169.254", # Cloud metadata endpoint
}
if hostname and hostname.lower() in blocked_hosts:
return f"Repository URL not allowed: blocked host '{hostname}'"
# Block private IP ranges (simplified check)
if hostname:
import ipaddress
try:
ip = ipaddress.ip_address(hostname)
if ip.is_private or ip.is_loopback or ip.is_link_local:
return (
f"Repository URL not allowed: private IP address '{hostname}'"
)
except ValueError:
pass # Not an IP address, continue with hostname check
# Check against allowed hosts list
if hostname and hostname.lower() not in [
h.lower() for h in _settings.allowed_hosts
]:
return f"Repository URL not allowed: host '{hostname}' not in allowed list"
return None return None
@@ -92,7 +135,8 @@ def _get_provider_for_url(repo_url: str) -> BaseProvider | None:
if "github.com" in url_lower or ( if "github.com" in url_lower or (
_settings.github_api_url _settings.github_api_url
and _settings.github_api_url != "https://api.github.com" and _settings.github_api_url != "https://api.github.com"
and _settings.github_api_url.replace("https://", "").replace("/api/v3", "") in url_lower and _settings.github_api_url.replace("https://", "").replace("/api/v3", "")
in url_lower
): ):
return _github_provider return _github_provider
@@ -166,6 +210,13 @@ async def lifespan(_app: FastAPI) -> AsyncIterator[None]:
if _github_provider: if _github_provider:
await _github_provider.close() await _github_provider.close()
# Shutdown thread pool executor for git operations
from git_wrapper import _executor
if _executor:
logger.info("Shutting down git operations thread pool...")
_executor.shutdown(wait=True)
logger.info("Git Operations MCP Server shut down") logger.info("Git Operations MCP Server shut down")
@@ -479,7 +530,9 @@ async def clone_repository(
project_id: str = Field(..., description="Project ID for scoping"), project_id: str = Field(..., description="Project ID for scoping"),
agent_id: str = Field(..., description="Agent ID making the request"), agent_id: str = Field(..., description="Agent ID making the request"),
repo_url: str = Field(..., description="Repository URL to clone"), repo_url: str = Field(..., description="Repository URL to clone"),
branch: str | None = Field(default=None, description="Branch to checkout after clone"), branch: str | None = Field(
default=None, description="Branch to checkout after clone"
),
depth: int | None = Field(default=None, ge=1, description="Shallow clone depth"), depth: int | None = Field(default=None, ge=1, description="Shallow clone depth"),
) -> dict[str, Any]: ) -> dict[str, Any]:
""" """
@@ -490,11 +543,23 @@ async def clone_repository(
try: try:
# Validate inputs # Validate inputs
if error := _validate_id(project_id, "project_id"): if error := _validate_id(project_id, "project_id"):
return {"success": False, "error": error, "code": ErrorCode.INVALID_REQUEST.value} return {
"success": False,
"error": error,
"code": ErrorCode.INVALID_REQUEST.value,
}
if error := _validate_id(agent_id, "agent_id"): if error := _validate_id(agent_id, "agent_id"):
return {"success": False, "error": error, "code": ErrorCode.INVALID_REQUEST.value} return {
"success": False,
"error": error,
"code": ErrorCode.INVALID_REQUEST.value,
}
if error := _validate_url(repo_url): if error := _validate_url(repo_url):
return {"success": False, "error": error, "code": ErrorCode.INVALID_REQUEST.value} return {
"success": False,
"error": error,
"code": ErrorCode.INVALID_REQUEST.value,
}
# Create workspace # Create workspace
workspace = await _workspace_manager.create_workspace(project_id, repo_url) # type: ignore[union-attr] workspace = await _workspace_manager.create_workspace(project_id, repo_url) # type: ignore[union-attr]
@@ -504,7 +569,9 @@ async def clone_repository(
# Clone repository # Clone repository
git = GitWrapper(workspace.path, _settings) git = GitWrapper(workspace.path, _settings)
result = await git.clone(repo_url, branch=branch, depth=depth, auth_token=auth_token) result = await git.clone(
repo_url, branch=branch, depth=depth, auth_token=auth_token
)
# Update workspace metadata # Update workspace metadata
await _workspace_manager.update_workspace_branch(project_id, result.branch) # type: ignore[union-attr] await _workspace_manager.update_workspace_branch(project_id, result.branch) # type: ignore[union-attr]
@@ -522,14 +589,20 @@ async def clone_repository(
return {"success": False, "error": e.message, "code": e.code.value} return {"success": False, "error": e.message, "code": e.code.value}
except Exception as e: except Exception as e:
logger.error(f"Unexpected clone error: {e}") logger.error(f"Unexpected clone error: {e}")
return {"success": False, "error": str(e), "code": ErrorCode.INTERNAL_ERROR.value} return {
"success": False,
"error": str(e),
"code": ErrorCode.INTERNAL_ERROR.value,
}
@mcp.tool() @mcp.tool()
async def git_status( async def git_status(
project_id: str = Field(..., description="Project ID for scoping"), project_id: str = Field(..., description="Project ID for scoping"),
agent_id: str = Field(..., description="Agent ID making the request"), agent_id: str = Field(..., description="Agent ID making the request"),
include_untracked: bool = Field(default=True, description="Include untracked files"), include_untracked: bool = Field(
default=True, description="Include untracked files"
),
) -> dict[str, Any]: ) -> dict[str, Any]:
""" """
Get git status for a project workspace. Get git status for a project workspace.
@@ -538,13 +611,25 @@ async def git_status(
""" """
try: try:
if error := _validate_id(project_id, "project_id"): if error := _validate_id(project_id, "project_id"):
return {"success": False, "error": error, "code": ErrorCode.INVALID_REQUEST.value} return {
"success": False,
"error": error,
"code": ErrorCode.INVALID_REQUEST.value,
}
if error := _validate_id(agent_id, "agent_id"): if error := _validate_id(agent_id, "agent_id"):
return {"success": False, "error": error, "code": ErrorCode.INVALID_REQUEST.value} return {
"success": False,
"error": error,
"code": ErrorCode.INVALID_REQUEST.value,
}
workspace = await _workspace_manager.get_workspace(project_id) # type: ignore[union-attr] workspace = await _workspace_manager.get_workspace(project_id) # type: ignore[union-attr]
if not workspace: if not workspace:
return {"success": False, "error": f"Workspace not found for project: {project_id}", "code": ErrorCode.WORKSPACE_NOT_FOUND.value} return {
"success": False,
"error": f"Workspace not found for project: {project_id}",
"code": ErrorCode.WORKSPACE_NOT_FOUND.value,
}
git = GitWrapper(workspace.path, _settings) git = GitWrapper(workspace.path, _settings)
result = await git.status(include_untracked=include_untracked) result = await git.status(include_untracked=include_untracked)
@@ -567,7 +652,11 @@ async def git_status(
return {"success": False, "error": e.message, "code": e.code.value} return {"success": False, "error": e.message, "code": e.code.value}
except Exception as e: except Exception as e:
logger.error(f"Unexpected status error: {e}") logger.error(f"Unexpected status error: {e}")
return {"success": False, "error": str(e), "code": ErrorCode.INTERNAL_ERROR.value} return {
"success": False,
"error": str(e),
"code": ErrorCode.INTERNAL_ERROR.value,
}
# MCP Tools - Branch Operations # MCP Tools - Branch Operations
@@ -578,7 +667,9 @@ async def create_branch(
project_id: str = Field(..., description="Project ID for scoping"), project_id: str = Field(..., description="Project ID for scoping"),
agent_id: str = Field(..., description="Agent ID making the request"), agent_id: str = Field(..., description="Agent ID making the request"),
branch_name: str = Field(..., description="Name for the new branch"), branch_name: str = Field(..., description="Name for the new branch"),
from_ref: str | None = Field(default=None, description="Reference to create from (default: HEAD)"), from_ref: str | None = Field(
default=None, description="Reference to create from (default: HEAD)"
),
checkout: bool = Field(default=True, description="Checkout after creation"), checkout: bool = Field(default=True, description="Checkout after creation"),
) -> dict[str, Any]: ) -> dict[str, Any]:
""" """
@@ -586,18 +677,36 @@ async def create_branch(
""" """
try: try:
if error := _validate_id(project_id, "project_id"): if error := _validate_id(project_id, "project_id"):
return {"success": False, "error": error, "code": ErrorCode.INVALID_REQUEST.value} return {
"success": False,
"error": error,
"code": ErrorCode.INVALID_REQUEST.value,
}
if error := _validate_id(agent_id, "agent_id"): if error := _validate_id(agent_id, "agent_id"):
return {"success": False, "error": error, "code": ErrorCode.INVALID_REQUEST.value} return {
"success": False,
"error": error,
"code": ErrorCode.INVALID_REQUEST.value,
}
if error := _validate_branch(branch_name): if error := _validate_branch(branch_name):
return {"success": False, "error": error, "code": ErrorCode.INVALID_REQUEST.value} return {
"success": False,
"error": error,
"code": ErrorCode.INVALID_REQUEST.value,
}
workspace = await _workspace_manager.get_workspace(project_id) # type: ignore[union-attr] workspace = await _workspace_manager.get_workspace(project_id) # type: ignore[union-attr]
if not workspace: if not workspace:
return {"success": False, "error": f"Workspace not found for project: {project_id}", "code": ErrorCode.WORKSPACE_NOT_FOUND.value} return {
"success": False,
"error": f"Workspace not found for project: {project_id}",
"code": ErrorCode.WORKSPACE_NOT_FOUND.value,
}
git = GitWrapper(workspace.path, _settings) git = GitWrapper(workspace.path, _settings)
result = await git.create_branch(branch_name, from_ref=from_ref, checkout=checkout) result = await git.create_branch(
branch_name, from_ref=from_ref, checkout=checkout
)
if checkout: if checkout:
await _workspace_manager.update_workspace_branch(project_id, branch_name) # type: ignore[union-attr] await _workspace_manager.update_workspace_branch(project_id, branch_name) # type: ignore[union-attr]
@@ -614,7 +723,11 @@ async def create_branch(
return {"success": False, "error": e.message, "code": e.code.value} return {"success": False, "error": e.message, "code": e.code.value}
except Exception as e: except Exception as e:
logger.error(f"Unexpected create branch error: {e}") logger.error(f"Unexpected create branch error: {e}")
return {"success": False, "error": str(e), "code": ErrorCode.INTERNAL_ERROR.value} return {
"success": False,
"error": str(e),
"code": ErrorCode.INTERNAL_ERROR.value,
}
@mcp.tool() @mcp.tool()
@@ -628,13 +741,25 @@ async def list_branches(
""" """
try: try:
if error := _validate_id(project_id, "project_id"): if error := _validate_id(project_id, "project_id"):
return {"success": False, "error": error, "code": ErrorCode.INVALID_REQUEST.value} return {
"success": False,
"error": error,
"code": ErrorCode.INVALID_REQUEST.value,
}
if error := _validate_id(agent_id, "agent_id"): if error := _validate_id(agent_id, "agent_id"):
return {"success": False, "error": error, "code": ErrorCode.INVALID_REQUEST.value} return {
"success": False,
"error": error,
"code": ErrorCode.INVALID_REQUEST.value,
}
workspace = await _workspace_manager.get_workspace(project_id) # type: ignore[union-attr] workspace = await _workspace_manager.get_workspace(project_id) # type: ignore[union-attr]
if not workspace: if not workspace:
return {"success": False, "error": f"Workspace not found for project: {project_id}", "code": ErrorCode.WORKSPACE_NOT_FOUND.value} return {
"success": False,
"error": f"Workspace not found for project: {project_id}",
"code": ErrorCode.WORKSPACE_NOT_FOUND.value,
}
git = GitWrapper(workspace.path, _settings) git = GitWrapper(workspace.path, _settings)
result = await git.list_branches(include_remote=include_remote) result = await git.list_branches(include_remote=include_remote)
@@ -652,7 +777,11 @@ async def list_branches(
return {"success": False, "error": e.message, "code": e.code.value} return {"success": False, "error": e.message, "code": e.code.value}
except Exception as e: except Exception as e:
logger.error(f"Unexpected list branches error: {e}") logger.error(f"Unexpected list branches error: {e}")
return {"success": False, "error": str(e), "code": ErrorCode.INTERNAL_ERROR.value} return {
"success": False,
"error": str(e),
"code": ErrorCode.INTERNAL_ERROR.value,
}
@mcp.tool() @mcp.tool()
@@ -668,13 +797,25 @@ async def checkout(
""" """
try: try:
if error := _validate_id(project_id, "project_id"): if error := _validate_id(project_id, "project_id"):
return {"success": False, "error": error, "code": ErrorCode.INVALID_REQUEST.value} return {
"success": False,
"error": error,
"code": ErrorCode.INVALID_REQUEST.value,
}
if error := _validate_id(agent_id, "agent_id"): if error := _validate_id(agent_id, "agent_id"):
return {"success": False, "error": error, "code": ErrorCode.INVALID_REQUEST.value} return {
"success": False,
"error": error,
"code": ErrorCode.INVALID_REQUEST.value,
}
workspace = await _workspace_manager.get_workspace(project_id) # type: ignore[union-attr] workspace = await _workspace_manager.get_workspace(project_id) # type: ignore[union-attr]
if not workspace: if not workspace:
return {"success": False, "error": f"Workspace not found for project: {project_id}", "code": ErrorCode.WORKSPACE_NOT_FOUND.value} return {
"success": False,
"error": f"Workspace not found for project: {project_id}",
"code": ErrorCode.WORKSPACE_NOT_FOUND.value,
}
git = GitWrapper(workspace.path, _settings) git = GitWrapper(workspace.path, _settings)
result = await git.checkout(ref, create_branch=create_branch, force=force) result = await git.checkout(ref, create_branch=create_branch, force=force)
@@ -692,7 +833,11 @@ async def checkout(
return {"success": False, "error": e.message, "code": e.code.value} return {"success": False, "error": e.message, "code": e.code.value}
except Exception as e: except Exception as e:
logger.error(f"Unexpected checkout error: {e}") logger.error(f"Unexpected checkout error: {e}")
return {"success": False, "error": str(e), "code": ErrorCode.INTERNAL_ERROR.value} return {
"success": False,
"error": str(e),
"code": ErrorCode.INTERNAL_ERROR.value,
}
# MCP Tools - Commit Operations # MCP Tools - Commit Operations
@@ -703,7 +848,9 @@ async def commit(
project_id: str = Field(..., description="Project ID for scoping"), project_id: str = Field(..., description="Project ID for scoping"),
agent_id: str = Field(..., description="Agent ID making the request"), agent_id: str = Field(..., description="Agent ID making the request"),
message: str = Field(..., description="Commit message"), message: str = Field(..., description="Commit message"),
files: list[str] | None = Field(default=None, description="Specific files to commit (None = all staged)"), files: list[str] | None = Field(
default=None, description="Specific files to commit (None = all staged)"
),
author_name: str | None = Field(default=None, description="Author name override"), author_name: str | None = Field(default=None, description="Author name override"),
author_email: str | None = Field(default=None, description="Author email override"), author_email: str | None = Field(default=None, description="Author email override"),
) -> dict[str, Any]: ) -> dict[str, Any]:
@@ -714,13 +861,25 @@ async def commit(
""" """
try: try:
if error := _validate_id(project_id, "project_id"): if error := _validate_id(project_id, "project_id"):
return {"success": False, "error": error, "code": ErrorCode.INVALID_REQUEST.value} return {
"success": False,
"error": error,
"code": ErrorCode.INVALID_REQUEST.value,
}
if error := _validate_id(agent_id, "agent_id"): if error := _validate_id(agent_id, "agent_id"):
return {"success": False, "error": error, "code": ErrorCode.INVALID_REQUEST.value} return {
"success": False,
"error": error,
"code": ErrorCode.INVALID_REQUEST.value,
}
workspace = await _workspace_manager.get_workspace(project_id) # type: ignore[union-attr] workspace = await _workspace_manager.get_workspace(project_id) # type: ignore[union-attr]
if not workspace: if not workspace:
return {"success": False, "error": f"Workspace not found for project: {project_id}", "code": ErrorCode.WORKSPACE_NOT_FOUND.value} return {
"success": False,
"error": f"Workspace not found for project: {project_id}",
"code": ErrorCode.WORKSPACE_NOT_FOUND.value,
}
git = GitWrapper(workspace.path, _settings) git = GitWrapper(workspace.path, _settings)
result = await git.commit( result = await git.commit(
@@ -745,14 +904,20 @@ async def commit(
return {"success": False, "error": e.message, "code": e.code.value} return {"success": False, "error": e.message, "code": e.code.value}
except Exception as e: except Exception as e:
logger.error(f"Unexpected commit error: {e}") logger.error(f"Unexpected commit error: {e}")
return {"success": False, "error": str(e), "code": ErrorCode.INTERNAL_ERROR.value} return {
"success": False,
"error": str(e),
"code": ErrorCode.INTERNAL_ERROR.value,
}
@mcp.tool() @mcp.tool()
async def push( async def push(
project_id: str = Field(..., description="Project ID for scoping"), project_id: str = Field(..., description="Project ID for scoping"),
agent_id: str = Field(..., description="Agent ID making the request"), agent_id: str = Field(..., description="Agent ID making the request"),
branch: str | None = Field(default=None, description="Branch to push (None = current)"), branch: str | None = Field(
default=None, description="Branch to push (None = current)"
),
remote: str = Field(default="origin", description="Remote name"), remote: str = Field(default="origin", description="Remote name"),
force: bool = Field(default=False, description="Force push"), force: bool = Field(default=False, description="Force push"),
set_upstream: bool = Field(default=True, description="Set upstream tracking"), set_upstream: bool = Field(default=True, description="Set upstream tracking"),
@@ -762,13 +927,25 @@ async def push(
""" """
try: try:
if error := _validate_id(project_id, "project_id"): if error := _validate_id(project_id, "project_id"):
return {"success": False, "error": error, "code": ErrorCode.INVALID_REQUEST.value} return {
"success": False,
"error": error,
"code": ErrorCode.INVALID_REQUEST.value,
}
if error := _validate_id(agent_id, "agent_id"): if error := _validate_id(agent_id, "agent_id"):
return {"success": False, "error": error, "code": ErrorCode.INVALID_REQUEST.value} return {
"success": False,
"error": error,
"code": ErrorCode.INVALID_REQUEST.value,
}
workspace = await _workspace_manager.get_workspace(project_id) # type: ignore[union-attr] workspace = await _workspace_manager.get_workspace(project_id) # type: ignore[union-attr]
if not workspace: if not workspace:
return {"success": False, "error": f"Workspace not found for project: {project_id}", "code": ErrorCode.WORKSPACE_NOT_FOUND.value} return {
"success": False,
"error": f"Workspace not found for project: {project_id}",
"code": ErrorCode.WORKSPACE_NOT_FOUND.value,
}
# Get auth token appropriate for the repository URL # Get auth token appropriate for the repository URL
auth_token = _get_auth_token_for_url(workspace.repo_url or "") auth_token = _get_auth_token_for_url(workspace.repo_url or "")
@@ -794,14 +971,20 @@ async def push(
return {"success": False, "error": e.message, "code": e.code.value} return {"success": False, "error": e.message, "code": e.code.value}
except Exception as e: except Exception as e:
logger.error(f"Unexpected push error: {e}") logger.error(f"Unexpected push error: {e}")
return {"success": False, "error": str(e), "code": ErrorCode.INTERNAL_ERROR.value} return {
"success": False,
"error": str(e),
"code": ErrorCode.INTERNAL_ERROR.value,
}
@mcp.tool() @mcp.tool()
async def pull( async def pull(
project_id: str = Field(..., description="Project ID for scoping"), project_id: str = Field(..., description="Project ID for scoping"),
agent_id: str = Field(..., description="Agent ID making the request"), agent_id: str = Field(..., description="Agent ID making the request"),
branch: str | None = Field(default=None, description="Branch to pull (None = current)"), branch: str | None = Field(
default=None, description="Branch to pull (None = current)"
),
remote: str = Field(default="origin", description="Remote name"), remote: str = Field(default="origin", description="Remote name"),
rebase: bool = Field(default=False, description="Rebase instead of merge"), rebase: bool = Field(default=False, description="Rebase instead of merge"),
) -> dict[str, Any]: ) -> dict[str, Any]:
@@ -810,13 +993,25 @@ async def pull(
""" """
try: try:
if error := _validate_id(project_id, "project_id"): if error := _validate_id(project_id, "project_id"):
return {"success": False, "error": error, "code": ErrorCode.INVALID_REQUEST.value} return {
"success": False,
"error": error,
"code": ErrorCode.INVALID_REQUEST.value,
}
if error := _validate_id(agent_id, "agent_id"): if error := _validate_id(agent_id, "agent_id"):
return {"success": False, "error": error, "code": ErrorCode.INVALID_REQUEST.value} return {
"success": False,
"error": error,
"code": ErrorCode.INVALID_REQUEST.value,
}
workspace = await _workspace_manager.get_workspace(project_id) # type: ignore[union-attr] workspace = await _workspace_manager.get_workspace(project_id) # type: ignore[union-attr]
if not workspace: if not workspace:
return {"success": False, "error": f"Workspace not found for project: {project_id}", "code": ErrorCode.WORKSPACE_NOT_FOUND.value} return {
"success": False,
"error": f"Workspace not found for project: {project_id}",
"code": ErrorCode.WORKSPACE_NOT_FOUND.value,
}
git = GitWrapper(workspace.path, _settings) git = GitWrapper(workspace.path, _settings)
result = await git.pull(branch=branch, remote=remote, rebase=rebase) result = await git.pull(branch=branch, remote=remote, rebase=rebase)
@@ -834,7 +1029,11 @@ async def pull(
return {"success": False, "error": e.message, "code": e.code.value} return {"success": False, "error": e.message, "code": e.code.value}
except Exception as e: except Exception as e:
logger.error(f"Unexpected pull error: {e}") logger.error(f"Unexpected pull error: {e}")
return {"success": False, "error": str(e), "code": ErrorCode.INTERNAL_ERROR.value} return {
"success": False,
"error": str(e),
"code": ErrorCode.INTERNAL_ERROR.value,
}
# MCP Tools - Diff and Log # MCP Tools - Diff and Log
@@ -854,16 +1053,30 @@ async def diff(
""" """
try: try:
if error := _validate_id(project_id, "project_id"): if error := _validate_id(project_id, "project_id"):
return {"success": False, "error": error, "code": ErrorCode.INVALID_REQUEST.value} return {
"success": False,
"error": error,
"code": ErrorCode.INVALID_REQUEST.value,
}
if error := _validate_id(agent_id, "agent_id"): if error := _validate_id(agent_id, "agent_id"):
return {"success": False, "error": error, "code": ErrorCode.INVALID_REQUEST.value} return {
"success": False,
"error": error,
"code": ErrorCode.INVALID_REQUEST.value,
}
workspace = await _workspace_manager.get_workspace(project_id) # type: ignore[union-attr] workspace = await _workspace_manager.get_workspace(project_id) # type: ignore[union-attr]
if not workspace: if not workspace:
return {"success": False, "error": f"Workspace not found for project: {project_id}", "code": ErrorCode.WORKSPACE_NOT_FOUND.value} return {
"success": False,
"error": f"Workspace not found for project: {project_id}",
"code": ErrorCode.WORKSPACE_NOT_FOUND.value,
}
git = GitWrapper(workspace.path, _settings) git = GitWrapper(workspace.path, _settings)
result = await git.diff(base=base, head=head, files=files, context_lines=context_lines) result = await git.diff(
base=base, head=head, files=files, context_lines=context_lines
)
return { return {
"success": True, "success": True,
@@ -881,7 +1094,11 @@ async def diff(
return {"success": False, "error": e.message, "code": e.code.value} return {"success": False, "error": e.message, "code": e.code.value}
except Exception as e: except Exception as e:
logger.error(f"Unexpected diff error: {e}") logger.error(f"Unexpected diff error: {e}")
return {"success": False, "error": str(e), "code": ErrorCode.INTERNAL_ERROR.value} return {
"success": False,
"error": str(e),
"code": ErrorCode.INTERNAL_ERROR.value,
}
@mcp.tool() @mcp.tool()
@@ -898,13 +1115,25 @@ async def log(
""" """
try: try:
if error := _validate_id(project_id, "project_id"): if error := _validate_id(project_id, "project_id"):
return {"success": False, "error": error, "code": ErrorCode.INVALID_REQUEST.value} return {
"success": False,
"error": error,
"code": ErrorCode.INVALID_REQUEST.value,
}
if error := _validate_id(agent_id, "agent_id"): if error := _validate_id(agent_id, "agent_id"):
return {"success": False, "error": error, "code": ErrorCode.INVALID_REQUEST.value} return {
"success": False,
"error": error,
"code": ErrorCode.INVALID_REQUEST.value,
}
workspace = await _workspace_manager.get_workspace(project_id) # type: ignore[union-attr] workspace = await _workspace_manager.get_workspace(project_id) # type: ignore[union-attr]
if not workspace: if not workspace:
return {"success": False, "error": f"Workspace not found for project: {project_id}", "code": ErrorCode.WORKSPACE_NOT_FOUND.value} return {
"success": False,
"error": f"Workspace not found for project: {project_id}",
"code": ErrorCode.WORKSPACE_NOT_FOUND.value,
}
git = GitWrapper(workspace.path, _settings) git = GitWrapper(workspace.path, _settings)
result = await git.log(ref=ref, limit=limit, skip=skip, path=path) result = await git.log(ref=ref, limit=limit, skip=skip, path=path)
@@ -921,7 +1150,11 @@ async def log(
return {"success": False, "error": e.message, "code": e.code.value} return {"success": False, "error": e.message, "code": e.code.value}
except Exception as e: except Exception as e:
logger.error(f"Unexpected log error: {e}") logger.error(f"Unexpected log error: {e}")
return {"success": False, "error": str(e), "code": ErrorCode.INTERNAL_ERROR.value} return {
"success": False,
"error": str(e),
"code": ErrorCode.INTERNAL_ERROR.value,
}
# MCP Tools - Pull Request Operations # MCP Tools - Pull Request Operations
@@ -938,27 +1171,49 @@ async def create_pull_request(
draft: bool = Field(default=False, description="Create as draft"), draft: bool = Field(default=False, description="Create as draft"),
labels: list[str] | None = Field(default=None, description="Labels to add"), labels: list[str] | None = Field(default=None, description="Labels to add"),
assignees: list[str] | None = Field(default=None, description="Users to assign"), assignees: list[str] | None = Field(default=None, description="Users to assign"),
reviewers: list[str] | None = Field(default=None, description="Users to request review from"), reviewers: list[str] | None = Field(
default=None, description="Users to request review from"
),
) -> dict[str, Any]: ) -> dict[str, Any]:
""" """
Create a pull request on the remote provider. Create a pull request on the remote provider.
""" """
try: try:
if error := _validate_id(project_id, "project_id"): if error := _validate_id(project_id, "project_id"):
return {"success": False, "error": error, "code": ErrorCode.INVALID_REQUEST.value} return {
"success": False,
"error": error,
"code": ErrorCode.INVALID_REQUEST.value,
}
if error := _validate_id(agent_id, "agent_id"): if error := _validate_id(agent_id, "agent_id"):
return {"success": False, "error": error, "code": ErrorCode.INVALID_REQUEST.value} return {
"success": False,
"error": error,
"code": ErrorCode.INVALID_REQUEST.value,
}
workspace = await _workspace_manager.get_workspace(project_id) # type: ignore[union-attr] workspace = await _workspace_manager.get_workspace(project_id) # type: ignore[union-attr]
if not workspace: if not workspace:
return {"success": False, "error": f"Workspace not found for project: {project_id}", "code": ErrorCode.WORKSPACE_NOT_FOUND.value} return {
"success": False,
"error": f"Workspace not found for project: {project_id}",
"code": ErrorCode.WORKSPACE_NOT_FOUND.value,
}
if not workspace.repo_url: if not workspace.repo_url:
return {"success": False, "error": "Workspace has no repository URL", "code": ErrorCode.INVALID_REQUEST.value} return {
"success": False,
"error": "Workspace has no repository URL",
"code": ErrorCode.INVALID_REQUEST.value,
}
provider = _get_provider_for_url(workspace.repo_url) provider = _get_provider_for_url(workspace.repo_url)
if not provider: if not provider:
return {"success": False, "error": "No provider configured for this repository", "code": ErrorCode.PROVIDER_NOT_FOUND.value} return {
"success": False,
"error": "No provider configured for this repository",
"code": ErrorCode.PROVIDER_NOT_FOUND.value,
}
owner, repo = provider.parse_repo_url(workspace.repo_url) owner, repo = provider.parse_repo_url(workspace.repo_url)
@@ -987,7 +1242,11 @@ async def create_pull_request(
return {"success": False, "error": e.message, "code": e.code.value} return {"success": False, "error": e.message, "code": e.code.value}
except Exception as e: except Exception as e:
logger.error(f"Unexpected create PR error: {e}") logger.error(f"Unexpected create PR error: {e}")
return {"success": False, "error": str(e), "code": ErrorCode.INTERNAL_ERROR.value} return {
"success": False,
"error": str(e),
"code": ErrorCode.INTERNAL_ERROR.value,
}
@mcp.tool() @mcp.tool()
@@ -1001,20 +1260,40 @@ async def get_pull_request(
""" """
try: try:
if error := _validate_id(project_id, "project_id"): if error := _validate_id(project_id, "project_id"):
return {"success": False, "error": error, "code": ErrorCode.INVALID_REQUEST.value} return {
"success": False,
"error": error,
"code": ErrorCode.INVALID_REQUEST.value,
}
if error := _validate_id(agent_id, "agent_id"): if error := _validate_id(agent_id, "agent_id"):
return {"success": False, "error": error, "code": ErrorCode.INVALID_REQUEST.value} return {
"success": False,
"error": error,
"code": ErrorCode.INVALID_REQUEST.value,
}
workspace = await _workspace_manager.get_workspace(project_id) # type: ignore[union-attr] workspace = await _workspace_manager.get_workspace(project_id) # type: ignore[union-attr]
if not workspace: if not workspace:
return {"success": False, "error": f"Workspace not found for project: {project_id}", "code": ErrorCode.WORKSPACE_NOT_FOUND.value} return {
"success": False,
"error": f"Workspace not found for project: {project_id}",
"code": ErrorCode.WORKSPACE_NOT_FOUND.value,
}
if not workspace.repo_url: if not workspace.repo_url:
return {"success": False, "error": "Workspace has no repository URL", "code": ErrorCode.INVALID_REQUEST.value} return {
"success": False,
"error": "Workspace has no repository URL",
"code": ErrorCode.INVALID_REQUEST.value,
}
provider = _get_provider_for_url(workspace.repo_url) provider = _get_provider_for_url(workspace.repo_url)
if not provider: if not provider:
return {"success": False, "error": "No provider configured for this repository", "code": ErrorCode.PROVIDER_NOT_FOUND.value} return {
"success": False,
"error": "No provider configured for this repository",
"code": ErrorCode.PROVIDER_NOT_FOUND.value,
}
owner, repo = provider.parse_repo_url(workspace.repo_url) owner, repo = provider.parse_repo_url(workspace.repo_url)
result = await provider.get_pr(owner, repo, pr_number) result = await provider.get_pr(owner, repo, pr_number)
@@ -1030,14 +1309,20 @@ async def get_pull_request(
return {"success": False, "error": e.message, "code": e.code.value} return {"success": False, "error": e.message, "code": e.code.value}
except Exception as e: except Exception as e:
logger.error(f"Unexpected get PR error: {e}") logger.error(f"Unexpected get PR error: {e}")
return {"success": False, "error": str(e), "code": ErrorCode.INTERNAL_ERROR.value} return {
"success": False,
"error": str(e),
"code": ErrorCode.INTERNAL_ERROR.value,
}
@mcp.tool() @mcp.tool()
async def list_pull_requests( async def list_pull_requests(
project_id: str = Field(..., description="Project ID for scoping"), project_id: str = Field(..., description="Project ID for scoping"),
agent_id: str = Field(..., description="Agent ID making the request"), agent_id: str = Field(..., description="Agent ID making the request"),
state: str | None = Field(default=None, description="Filter by state: open, closed, merged"), state: str | None = Field(
default=None, description="Filter by state: open, closed, merged"
),
author: str | None = Field(default=None, description="Filter by author"), author: str | None = Field(default=None, description="Filter by author"),
limit: int = Field(default=20, ge=1, le=100, description="Max PRs"), limit: int = Field(default=20, ge=1, le=100, description="Max PRs"),
) -> dict[str, Any]: ) -> dict[str, Any]:
@@ -1046,20 +1331,40 @@ async def list_pull_requests(
""" """
try: try:
if error := _validate_id(project_id, "project_id"): if error := _validate_id(project_id, "project_id"):
return {"success": False, "error": error, "code": ErrorCode.INVALID_REQUEST.value} return {
"success": False,
"error": error,
"code": ErrorCode.INVALID_REQUEST.value,
}
if error := _validate_id(agent_id, "agent_id"): if error := _validate_id(agent_id, "agent_id"):
return {"success": False, "error": error, "code": ErrorCode.INVALID_REQUEST.value} return {
"success": False,
"error": error,
"code": ErrorCode.INVALID_REQUEST.value,
}
workspace = await _workspace_manager.get_workspace(project_id) # type: ignore[union-attr] workspace = await _workspace_manager.get_workspace(project_id) # type: ignore[union-attr]
if not workspace: if not workspace:
return {"success": False, "error": f"Workspace not found for project: {project_id}", "code": ErrorCode.WORKSPACE_NOT_FOUND.value} return {
"success": False,
"error": f"Workspace not found for project: {project_id}",
"code": ErrorCode.WORKSPACE_NOT_FOUND.value,
}
if not workspace.repo_url: if not workspace.repo_url:
return {"success": False, "error": "Workspace has no repository URL", "code": ErrorCode.INVALID_REQUEST.value} return {
"success": False,
"error": "Workspace has no repository URL",
"code": ErrorCode.INVALID_REQUEST.value,
}
provider = _get_provider_for_url(workspace.repo_url) provider = _get_provider_for_url(workspace.repo_url)
if not provider: if not provider:
return {"success": False, "error": "No provider configured for this repository", "code": ErrorCode.PROVIDER_NOT_FOUND.value} return {
"success": False,
"error": "No provider configured for this repository",
"code": ErrorCode.PROVIDER_NOT_FOUND.value,
}
# Parse state # Parse state
pr_state = None pr_state = None
@@ -1067,10 +1372,16 @@ async def list_pull_requests(
try: try:
pr_state = PRState(state.lower()) pr_state = PRState(state.lower())
except ValueError: except ValueError:
return {"success": False, "error": f"Invalid state: {state}. Valid: open, closed, merged", "code": ErrorCode.INVALID_REQUEST.value} return {
"success": False,
"error": f"Invalid state: {state}. Valid: open, closed, merged",
"code": ErrorCode.INVALID_REQUEST.value,
}
owner, repo = provider.parse_repo_url(workspace.repo_url) owner, repo = provider.parse_repo_url(workspace.repo_url)
result = await provider.list_prs(owner, repo, state=pr_state, author=author, limit=limit) result = await provider.list_prs(
owner, repo, state=pr_state, author=author, limit=limit
)
return { return {
"success": result.success, "success": result.success,
@@ -1084,7 +1395,11 @@ async def list_pull_requests(
return {"success": False, "error": e.message, "code": e.code.value} return {"success": False, "error": e.message, "code": e.code.value}
except Exception as e: except Exception as e:
logger.error(f"Unexpected list PRs error: {e}") logger.error(f"Unexpected list PRs error: {e}")
return {"success": False, "error": str(e), "code": ErrorCode.INTERNAL_ERROR.value} return {
"success": False,
"error": str(e),
"code": ErrorCode.INTERNAL_ERROR.value,
}
@mcp.tool() @mcp.tool()
@@ -1092,39 +1407,71 @@ async def merge_pull_request(
project_id: str = Field(..., description="Project ID for scoping"), project_id: str = Field(..., description="Project ID for scoping"),
agent_id: str = Field(..., description="Agent ID making the request"), agent_id: str = Field(..., description="Agent ID making the request"),
pr_number: int = Field(..., description="Pull request number"), pr_number: int = Field(..., description="Pull request number"),
merge_strategy: str = Field(default="merge", description="Strategy: merge, squash, rebase"), merge_strategy: str = Field(
commit_message: str | None = Field(default=None, description="Custom commit message"), default="merge", description="Strategy: merge, squash, rebase"
delete_branch: bool = Field(default=True, description="Delete source branch after merge"), ),
commit_message: str | None = Field(
default=None, description="Custom commit message"
),
delete_branch: bool = Field(
default=True, description="Delete source branch after merge"
),
) -> dict[str, Any]: ) -> dict[str, Any]:
""" """
Merge a pull request. Merge a pull request.
""" """
try: try:
if error := _validate_id(project_id, "project_id"): if error := _validate_id(project_id, "project_id"):
return {"success": False, "error": error, "code": ErrorCode.INVALID_REQUEST.value} return {
"success": False,
"error": error,
"code": ErrorCode.INVALID_REQUEST.value,
}
if error := _validate_id(agent_id, "agent_id"): if error := _validate_id(agent_id, "agent_id"):
return {"success": False, "error": error, "code": ErrorCode.INVALID_REQUEST.value} return {
"success": False,
"error": error,
"code": ErrorCode.INVALID_REQUEST.value,
}
workspace = await _workspace_manager.get_workspace(project_id) # type: ignore[union-attr] workspace = await _workspace_manager.get_workspace(project_id) # type: ignore[union-attr]
if not workspace: if not workspace:
return {"success": False, "error": f"Workspace not found for project: {project_id}", "code": ErrorCode.WORKSPACE_NOT_FOUND.value} return {
"success": False,
"error": f"Workspace not found for project: {project_id}",
"code": ErrorCode.WORKSPACE_NOT_FOUND.value,
}
if not workspace.repo_url: if not workspace.repo_url:
return {"success": False, "error": "Workspace has no repository URL", "code": ErrorCode.INVALID_REQUEST.value} return {
"success": False,
"error": "Workspace has no repository URL",
"code": ErrorCode.INVALID_REQUEST.value,
}
provider = _get_provider_for_url(workspace.repo_url) provider = _get_provider_for_url(workspace.repo_url)
if not provider: if not provider:
return {"success": False, "error": "No provider configured for this repository", "code": ErrorCode.PROVIDER_NOT_FOUND.value} return {
"success": False,
"error": "No provider configured for this repository",
"code": ErrorCode.PROVIDER_NOT_FOUND.value,
}
# Parse merge strategy # Parse merge strategy
try: try:
strategy = MergeStrategy(merge_strategy.lower()) strategy = MergeStrategy(merge_strategy.lower())
except ValueError: except ValueError:
return {"success": False, "error": f"Invalid strategy: {merge_strategy}. Valid: merge, squash, rebase", "code": ErrorCode.INVALID_REQUEST.value} return {
"success": False,
"error": f"Invalid strategy: {merge_strategy}. Valid: merge, squash, rebase",
"code": ErrorCode.INVALID_REQUEST.value,
}
owner, repo = provider.parse_repo_url(workspace.repo_url) owner, repo = provider.parse_repo_url(workspace.repo_url)
result = await provider.merge_pr( result = await provider.merge_pr(
owner, repo, pr_number, owner,
repo,
pr_number,
merge_strategy=strategy, merge_strategy=strategy,
commit_message=commit_message, commit_message=commit_message,
delete_branch=delete_branch, delete_branch=delete_branch,
@@ -1142,7 +1489,11 @@ async def merge_pull_request(
return {"success": False, "error": e.message, "code": e.code.value} return {"success": False, "error": e.message, "code": e.code.value}
except Exception as e: except Exception as e:
logger.error(f"Unexpected merge PR error: {e}") logger.error(f"Unexpected merge PR error: {e}")
return {"success": False, "error": str(e), "code": ErrorCode.INTERNAL_ERROR.value} return {
"success": False,
"error": str(e),
"code": ErrorCode.INTERNAL_ERROR.value,
}
# MCP Tools - Workspace Operations # MCP Tools - Workspace Operations
@@ -1158,13 +1509,25 @@ async def get_workspace(
""" """
try: try:
if error := _validate_id(project_id, "project_id"): if error := _validate_id(project_id, "project_id"):
return {"success": False, "error": error, "code": ErrorCode.INVALID_REQUEST.value} return {
"success": False,
"error": error,
"code": ErrorCode.INVALID_REQUEST.value,
}
if error := _validate_id(agent_id, "agent_id"): if error := _validate_id(agent_id, "agent_id"):
return {"success": False, "error": error, "code": ErrorCode.INVALID_REQUEST.value} return {
"success": False,
"error": error,
"code": ErrorCode.INVALID_REQUEST.value,
}
workspace = await _workspace_manager.get_workspace(project_id) # type: ignore[union-attr] workspace = await _workspace_manager.get_workspace(project_id) # type: ignore[union-attr]
if not workspace: if not workspace:
return {"success": False, "error": f"Workspace not found for project: {project_id}", "code": ErrorCode.WORKSPACE_NOT_FOUND.value} return {
"success": False,
"error": f"Workspace not found for project: {project_id}",
"code": ErrorCode.WORKSPACE_NOT_FOUND.value,
}
return { return {
"success": True, "success": True,
@@ -1176,14 +1539,20 @@ async def get_workspace(
return {"success": False, "error": e.message, "code": e.code.value} return {"success": False, "error": e.message, "code": e.code.value}
except Exception as e: except Exception as e:
logger.error(f"Unexpected get workspace error: {e}") logger.error(f"Unexpected get workspace error: {e}")
return {"success": False, "error": str(e), "code": ErrorCode.INTERNAL_ERROR.value} return {
"success": False,
"error": str(e),
"code": ErrorCode.INTERNAL_ERROR.value,
}
@mcp.tool() @mcp.tool()
async def lock_workspace( async def lock_workspace(
project_id: str = Field(..., description="Project ID"), project_id: str = Field(..., description="Project ID"),
agent_id: str = Field(..., description="Agent ID requesting lock"), agent_id: str = Field(..., description="Agent ID requesting lock"),
timeout: int = Field(default=300, ge=10, le=3600, description="Lock timeout in seconds"), timeout: int = Field(
default=300, ge=10, le=3600, description="Lock timeout in seconds"
),
) -> dict[str, Any]: ) -> dict[str, Any]:
""" """
Acquire a lock on a workspace. Acquire a lock on a workspace.
@@ -1192,9 +1561,17 @@ async def lock_workspace(
""" """
try: try:
if error := _validate_id(project_id, "project_id"): if error := _validate_id(project_id, "project_id"):
return {"success": False, "error": error, "code": ErrorCode.INVALID_REQUEST.value} return {
"success": False,
"error": error,
"code": ErrorCode.INVALID_REQUEST.value,
}
if error := _validate_id(agent_id, "agent_id"): if error := _validate_id(agent_id, "agent_id"):
return {"success": False, "error": error, "code": ErrorCode.INVALID_REQUEST.value} return {
"success": False,
"error": error,
"code": ErrorCode.INVALID_REQUEST.value,
}
success = await _workspace_manager.lock_workspace(project_id, agent_id, timeout) # type: ignore[union-attr] success = await _workspace_manager.lock_workspace(project_id, agent_id, timeout) # type: ignore[union-attr]
workspace = await _workspace_manager.get_workspace(project_id) # type: ignore[union-attr] workspace = await _workspace_manager.get_workspace(project_id) # type: ignore[union-attr]
@@ -1202,7 +1579,9 @@ async def lock_workspace(
return { return {
"success": success, "success": success,
"lock_holder": workspace.lock_holder if workspace else None, "lock_holder": workspace.lock_holder if workspace else None,
"lock_expires": workspace.lock_expires.isoformat() if workspace and workspace.lock_expires else None, "lock_expires": workspace.lock_expires.isoformat()
if workspace and workspace.lock_expires
else None,
} }
except GitOpsError as e: except GitOpsError as e:
@@ -1210,7 +1589,11 @@ async def lock_workspace(
return {"success": False, "error": e.message, "code": e.code.value} return {"success": False, "error": e.message, "code": e.code.value}
except Exception as e: except Exception as e:
logger.error(f"Unexpected lock workspace error: {e}") logger.error(f"Unexpected lock workspace error: {e}")
return {"success": False, "error": str(e), "code": ErrorCode.INTERNAL_ERROR.value} return {
"success": False,
"error": str(e),
"code": ErrorCode.INTERNAL_ERROR.value,
}
@mcp.tool() @mcp.tool()
@@ -1224,9 +1607,17 @@ async def unlock_workspace(
""" """
try: try:
if error := _validate_id(project_id, "project_id"): if error := _validate_id(project_id, "project_id"):
return {"success": False, "error": error, "code": ErrorCode.INVALID_REQUEST.value} return {
"success": False,
"error": error,
"code": ErrorCode.INVALID_REQUEST.value,
}
if error := _validate_id(agent_id, "agent_id"): if error := _validate_id(agent_id, "agent_id"):
return {"success": False, "error": error, "code": ErrorCode.INVALID_REQUEST.value} return {
"success": False,
"error": error,
"code": ErrorCode.INVALID_REQUEST.value,
}
success = await _workspace_manager.unlock_workspace(project_id, agent_id, force) # type: ignore[union-attr] success = await _workspace_manager.unlock_workspace(project_id, agent_id, force) # type: ignore[union-attr]
@@ -1237,7 +1628,11 @@ async def unlock_workspace(
return {"success": False, "error": e.message, "code": e.code.value} return {"success": False, "error": e.message, "code": e.code.value}
except Exception as e: except Exception as e:
logger.error(f"Unexpected unlock workspace error: {e}") logger.error(f"Unexpected unlock workspace error: {e}")
return {"success": False, "error": str(e), "code": ErrorCode.INTERNAL_ERROR.value} return {
"success": False,
"error": str(e),
"code": ErrorCode.INTERNAL_ERROR.value,
}
# Register all tools # Register all tools

View File

@@ -13,7 +13,7 @@ from datetime import UTC, datetime, timedelta
from pathlib import Path from pathlib import Path
from typing import Any from typing import Any
import aiofiles import aiofiles # type: ignore[import-untyped]
from filelock import FileLock, Timeout from filelock import FileLock, Timeout
from config import Settings, get_settings from config import Settings, get_settings
@@ -54,10 +54,25 @@ class WorkspaceManager:
self.base_path.mkdir(parents=True, exist_ok=True) self.base_path.mkdir(parents=True, exist_ok=True)
def _get_workspace_path(self, project_id: str) -> Path: def _get_workspace_path(self, project_id: str) -> Path:
"""Get the path for a project workspace.""" """Get the path for a project workspace with path traversal protection."""
# Sanitize project ID for filesystem # Sanitize project ID for filesystem
safe_id = "".join(c if c.isalnum() or c in "-_" else "_" for c in project_id) safe_id = "".join(c if c.isalnum() or c in "-_" else "_" for c in project_id)
return self.base_path / safe_id
# Reject reserved names
reserved_names = {".", "..", "con", "prn", "aux", "nul"}
if safe_id.lower() in reserved_names:
raise ValueError(f"Invalid project ID: reserved name '{project_id}'")
# Construct path and verify it's within base_path (prevent path traversal)
workspace_path = (self.base_path / safe_id).resolve()
base_resolved = self.base_path.resolve()
if not workspace_path.is_relative_to(base_resolved):
raise ValueError(
f"Invalid project ID: path traversal detected '{project_id}'"
)
return workspace_path
def _get_lock_path(self, project_id: str) -> Path: def _get_lock_path(self, project_id: str) -> Path:
"""Get the lock file path for a workspace.""" """Get the lock file path for a workspace."""
@@ -230,10 +245,7 @@ class WorkspaceManager:
raise WorkspaceNotFoundError(project_id) raise WorkspaceNotFoundError(project_id)
# Check if already locked by someone else # Check if already locked by someone else
if ( if workspace.state == WorkspaceState.LOCKED and workspace.lock_holder != holder:
workspace.state == WorkspaceState.LOCKED
and workspace.lock_holder != holder
):
# Check if lock expired # Check if lock expired
if workspace.lock_expires and workspace.lock_expires > datetime.now(UTC): if workspace.lock_expires and workspace.lock_expires > datetime.now(UTC):
raise WorkspaceLockedError(project_id, workspace.lock_holder) raise WorkspaceLockedError(project_id, workspace.lock_holder)
@@ -275,11 +287,7 @@ class WorkspaceManager:
raise WorkspaceNotFoundError(project_id) raise WorkspaceNotFoundError(project_id)
# Verify holder # Verify holder
if ( if not force and workspace.lock_holder and workspace.lock_holder != holder:
not force
and workspace.lock_holder
and workspace.lock_holder != holder
):
raise WorkspaceLockedError(project_id, workspace.lock_holder) raise WorkspaceLockedError(project_id, workspace.lock_holder)
# Clear lock # Clear lock
@@ -362,7 +370,7 @@ class WorkspaceManager:
Returns: Returns:
List of WorkspaceInfo List of WorkspaceInfo
""" """
workspaces = [] workspaces: list[WorkspaceInfo] = []
if not self.base_path.exists(): if not self.base_path.exists():
return workspaces return workspaces
@@ -532,9 +540,7 @@ class WorkspaceLock:
self.holder, self.holder,
) )
except Exception as e: except Exception as e:
logger.warning( logger.warning(f"Failed to release lock for {self.project_id}: {e}")
f"Failed to release lock for {self.project_id}: {e}"
)
class FileLockManager: class FileLockManager: