forked from cardosofelipe/fast-next-template
**feat(git-ops): enhance MCP server with Git provider updates and SSRF protection**
- Added `mcp-git-ops` service to `docker-compose.dev.yml` with health checks and configurations. - Integrated SSRF protection in repository URL validation for enhanced security. - Expanded `pyproject.toml` mypy settings and adjusted code to meet stricter type checking. - Improved workspace management and GitWrapper operations with error handling refinements. - Updated input validation, branching, and repository operations to align with new error structure. - Shut down thread pool executor gracefully during server cleanup.
This commit is contained in:
13
Makefile
13
Makefile
@@ -47,6 +47,7 @@ help:
|
||||
@echo " cd backend && make help - Backend-specific commands"
|
||||
@echo " cd mcp-servers/llm-gateway && make - LLM Gateway commands"
|
||||
@echo " cd mcp-servers/knowledge-base && make - Knowledge Base commands"
|
||||
@echo " cd mcp-servers/git-ops && make - Git Operations commands"
|
||||
@echo " cd frontend && npm run - Frontend-specific commands"
|
||||
|
||||
# ============================================================================
|
||||
@@ -138,6 +139,9 @@ test-mcp:
|
||||
@echo ""
|
||||
@echo "=== Knowledge Base ==="
|
||||
@cd mcp-servers/knowledge-base && uv run pytest tests/ -v
|
||||
@echo ""
|
||||
@echo "=== Git Operations ==="
|
||||
@cd mcp-servers/git-ops && IS_TEST=True uv run pytest tests/ -v
|
||||
|
||||
test-frontend:
|
||||
@echo "Running frontend tests..."
|
||||
@@ -158,6 +162,9 @@ test-cov:
|
||||
@echo ""
|
||||
@echo "=== Knowledge Base Coverage ==="
|
||||
@cd mcp-servers/knowledge-base && uv run pytest tests/ -v --cov=. --cov-report=term-missing
|
||||
@echo ""
|
||||
@echo "=== Git Operations Coverage ==="
|
||||
@cd mcp-servers/git-ops && IS_TEST=True uv run pytest tests/ -v --cov=. --cov-report=term-missing
|
||||
|
||||
test-integration:
|
||||
@echo "Running MCP integration tests..."
|
||||
@@ -178,6 +185,9 @@ format-all:
|
||||
@echo "Formatting Knowledge Base..."
|
||||
@cd mcp-servers/knowledge-base && make format
|
||||
@echo ""
|
||||
@echo "Formatting Git Operations..."
|
||||
@cd mcp-servers/git-ops && make format
|
||||
@echo ""
|
||||
@echo "Formatting frontend..."
|
||||
@cd frontend && npm run format
|
||||
@echo ""
|
||||
@@ -197,6 +207,9 @@ validate:
|
||||
@echo "Validating Knowledge Base..."
|
||||
@cd mcp-servers/knowledge-base && make validate
|
||||
@echo ""
|
||||
@echo "Validating Git Operations..."
|
||||
@cd mcp-servers/git-ops && make validate
|
||||
@echo ""
|
||||
@echo "All validations passed!"
|
||||
|
||||
validate-all: validate
|
||||
|
||||
@@ -96,6 +96,38 @@ services:
|
||||
- app-network
|
||||
restart: unless-stopped
|
||||
|
||||
mcp-git-ops:
|
||||
build:
|
||||
context: ./mcp-servers/git-ops
|
||||
dockerfile: Dockerfile
|
||||
ports:
|
||||
- "8003:8003"
|
||||
env_file:
|
||||
- .env
|
||||
environment:
|
||||
# GIT_OPS_ prefix required by pydantic-settings config
|
||||
- GIT_OPS_HOST=0.0.0.0
|
||||
- GIT_OPS_PORT=8003
|
||||
- GIT_OPS_REDIS_URL=redis://redis:6379/3
|
||||
- GIT_OPS_GITEA_BASE_URL=${GITEA_BASE_URL}
|
||||
- GIT_OPS_GITEA_TOKEN=${GITEA_TOKEN}
|
||||
- GIT_OPS_GITHUB_TOKEN=${GITHUB_TOKEN}
|
||||
- ENVIRONMENT=development
|
||||
volumes:
|
||||
- git_workspaces_dev:/workspaces
|
||||
depends_on:
|
||||
redis:
|
||||
condition: service_healthy
|
||||
healthcheck:
|
||||
test: ["CMD", "python", "-c", "import httpx; httpx.get('http://localhost:8003/health').raise_for_status()"]
|
||||
interval: 30s
|
||||
timeout: 10s
|
||||
retries: 3
|
||||
start_period: 10s
|
||||
networks:
|
||||
- app-network
|
||||
restart: unless-stopped
|
||||
|
||||
backend:
|
||||
build:
|
||||
context: ./backend
|
||||
@@ -119,6 +151,7 @@ services:
|
||||
# MCP Server URLs
|
||||
- LLM_GATEWAY_URL=http://mcp-llm-gateway:8001
|
||||
- KNOWLEDGE_BASE_URL=http://mcp-knowledge-base:8002
|
||||
- GIT_OPS_URL=http://mcp-git-ops:8003
|
||||
depends_on:
|
||||
db:
|
||||
condition: service_healthy
|
||||
@@ -128,6 +161,8 @@ services:
|
||||
condition: service_healthy
|
||||
mcp-knowledge-base:
|
||||
condition: service_healthy
|
||||
mcp-git-ops:
|
||||
condition: service_healthy
|
||||
healthcheck:
|
||||
test: ["CMD", "curl", "-f", "http://localhost:8000/health"]
|
||||
interval: 10s
|
||||
@@ -155,6 +190,7 @@ services:
|
||||
# MCP Server URLs (agents need access to MCP)
|
||||
- LLM_GATEWAY_URL=http://mcp-llm-gateway:8001
|
||||
- KNOWLEDGE_BASE_URL=http://mcp-knowledge-base:8002
|
||||
- GIT_OPS_URL=http://mcp-git-ops:8003
|
||||
depends_on:
|
||||
db:
|
||||
condition: service_healthy
|
||||
@@ -164,6 +200,8 @@ services:
|
||||
condition: service_healthy
|
||||
mcp-knowledge-base:
|
||||
condition: service_healthy
|
||||
mcp-git-ops:
|
||||
condition: service_healthy
|
||||
networks:
|
||||
- app-network
|
||||
command: ["celery", "-A", "app.celery_app", "worker", "-Q", "agent", "-l", "info", "-c", "4"]
|
||||
@@ -181,11 +219,14 @@ services:
|
||||
- DATABASE_URL=${DATABASE_URL}
|
||||
- REDIS_URL=redis://redis:6379/0
|
||||
- CELERY_QUEUE=git
|
||||
- GIT_OPS_URL=http://mcp-git-ops:8003
|
||||
depends_on:
|
||||
db:
|
||||
condition: service_healthy
|
||||
redis:
|
||||
condition: service_healthy
|
||||
mcp-git-ops:
|
||||
condition: service_healthy
|
||||
networks:
|
||||
- app-network
|
||||
command: ["celery", "-A", "app.celery_app", "worker", "-Q", "git", "-l", "info", "-c", "2"]
|
||||
@@ -260,6 +301,7 @@ services:
|
||||
volumes:
|
||||
postgres_data_dev:
|
||||
redis_data_dev:
|
||||
git_workspaces_dev:
|
||||
frontend_dev_modules:
|
||||
frontend_dev_next:
|
||||
|
||||
|
||||
88
mcp-servers/git-ops/Makefile
Normal file
88
mcp-servers/git-ops/Makefile
Normal file
@@ -0,0 +1,88 @@
|
||||
.PHONY: help install install-dev lint lint-fix format format-check type-check test test-cov validate clean run
|
||||
|
||||
# Ensure commands in this project don't inherit an external Python virtualenv
|
||||
# (prevents uv warnings about mismatched VIRTUAL_ENV when running from repo root)
|
||||
unexport VIRTUAL_ENV
|
||||
|
||||
# Default target
|
||||
help:
|
||||
@echo "Git Operations MCP Server - Development Commands"
|
||||
@echo ""
|
||||
@echo "Setup:"
|
||||
@echo " make install - Install production dependencies"
|
||||
@echo " make install-dev - Install development dependencies"
|
||||
@echo ""
|
||||
@echo "Quality Checks:"
|
||||
@echo " make lint - Run Ruff linter"
|
||||
@echo " make lint-fix - Run Ruff linter with auto-fix"
|
||||
@echo " make format - Format code with Ruff"
|
||||
@echo " make format-check - Check if code is formatted"
|
||||
@echo " make type-check - Run mypy type checker"
|
||||
@echo ""
|
||||
@echo "Testing:"
|
||||
@echo " make test - Run pytest"
|
||||
@echo " make test-cov - Run pytest with coverage"
|
||||
@echo ""
|
||||
@echo "All-in-one:"
|
||||
@echo " make validate - Run all checks (lint + format + types)"
|
||||
@echo ""
|
||||
@echo "Running:"
|
||||
@echo " make run - Run the server locally"
|
||||
@echo ""
|
||||
@echo "Cleanup:"
|
||||
@echo " make clean - Remove cache and build artifacts"
|
||||
|
||||
# Setup
|
||||
install:
|
||||
@echo "Installing production dependencies..."
|
||||
@uv pip install -e .
|
||||
|
||||
install-dev:
|
||||
@echo "Installing development dependencies..."
|
||||
@uv pip install -e ".[dev]"
|
||||
|
||||
# Quality checks
|
||||
lint:
|
||||
@echo "Running Ruff linter..."
|
||||
@uv run ruff check .
|
||||
|
||||
lint-fix:
|
||||
@echo "Running Ruff linter with auto-fix..."
|
||||
@uv run ruff check --fix .
|
||||
|
||||
format:
|
||||
@echo "Formatting code..."
|
||||
@uv run ruff format .
|
||||
|
||||
format-check:
|
||||
@echo "Checking code formatting..."
|
||||
@uv run ruff format --check .
|
||||
|
||||
type-check:
|
||||
@echo "Running mypy..."
|
||||
@uv run python -m mypy server.py config.py models.py exceptions.py git_wrapper.py workspace.py providers/ --explicit-package-bases
|
||||
|
||||
# Testing
|
||||
test:
|
||||
@echo "Running tests..."
|
||||
@IS_TEST=True uv run pytest tests/ -v
|
||||
|
||||
test-cov:
|
||||
@echo "Running tests with coverage..."
|
||||
@IS_TEST=True uv run pytest tests/ -v --cov=. --cov-report=term-missing --cov-report=html
|
||||
|
||||
# All-in-one validation
|
||||
validate: lint format-check type-check
|
||||
@echo "All validations passed!"
|
||||
|
||||
# Running
|
||||
run:
|
||||
@echo "Starting Git Operations server..."
|
||||
@uv run python server.py
|
||||
|
||||
# Cleanup
|
||||
clean:
|
||||
@echo "Cleaning up..."
|
||||
@rm -rf __pycache__ .pytest_cache .mypy_cache .ruff_cache .coverage htmlcov
|
||||
@find . -type d -name "__pycache__" -exec rm -rf {} + 2>/dev/null || true
|
||||
@find . -type f -name "*.pyc" -delete 2>/dev/null || true
|
||||
@@ -73,7 +73,7 @@ class GitOpsError(Exception):
|
||||
|
||||
def to_dict(self) -> dict[str, Any]:
|
||||
"""Convert to dictionary for MCP response."""
|
||||
result = {
|
||||
result: dict[str, Any] = {
|
||||
"error": self.message,
|
||||
"code": self.code.value,
|
||||
}
|
||||
@@ -325,9 +325,7 @@ class PRNotFoundError(PRError):
|
||||
class APIError(ProviderError):
|
||||
"""Provider API error."""
|
||||
|
||||
def __init__(
|
||||
self, provider: str, status_code: int, message: str
|
||||
) -> None:
|
||||
def __init__(self, provider: str, status_code: int, message: str) -> None:
|
||||
super().__init__(
|
||||
f"{provider} API error ({status_code}): {message}",
|
||||
ErrorCode.API_ERROR,
|
||||
|
||||
@@ -52,6 +52,21 @@ from models import (
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def sanitize_url_for_logging(url: str) -> str:
|
||||
"""
|
||||
Remove any credentials from a URL before logging.
|
||||
|
||||
Handles URLs like:
|
||||
- https://token@github.com/owner/repo.git
|
||||
- https://user:password@github.com/owner/repo.git
|
||||
- git@github.com:owner/repo.git (unchanged, no credentials)
|
||||
"""
|
||||
# Pattern to match https://[credentials@]host/path
|
||||
sanitized = re.sub(r"(https?://)([^@]+@)", r"\1***@", url)
|
||||
return sanitized
|
||||
|
||||
|
||||
# Thread pool for blocking git operations
|
||||
_executor: ThreadPoolExecutor | None = None
|
||||
|
||||
@@ -81,7 +96,7 @@ class GitWrapper:
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
workspace_path: Path,
|
||||
workspace_path: Path | str,
|
||||
settings: Settings | None = None,
|
||||
) -> None:
|
||||
"""
|
||||
@@ -91,7 +106,9 @@ class GitWrapper:
|
||||
workspace_path: Path to the git workspace
|
||||
settings: Optional settings override
|
||||
"""
|
||||
self.workspace_path = workspace_path
|
||||
self.workspace_path = (
|
||||
Path(workspace_path) if isinstance(workspace_path, str) else workspace_path
|
||||
)
|
||||
self.settings = settings or get_settings()
|
||||
self._repo: GitRepo | None = None
|
||||
|
||||
@@ -175,8 +192,10 @@ class GitWrapper:
|
||||
)
|
||||
|
||||
except GitCommandError as e:
|
||||
logger.error(f"Clone failed: {e}")
|
||||
raise CloneError(repo_url, str(e))
|
||||
# Sanitize URLs in error messages to prevent credential leakage
|
||||
error_msg = sanitize_url_for_logging(str(e))
|
||||
logger.error(f"Clone failed: {error_msg}")
|
||||
raise CloneError(sanitize_url_for_logging(repo_url), error_msg)
|
||||
|
||||
return await run_in_executor(_do_clone)
|
||||
|
||||
@@ -200,9 +219,10 @@ class GitWrapper:
|
||||
staged = []
|
||||
for diff in repo.index.diff("HEAD"):
|
||||
change_type = self._diff_to_change_type(diff.change_type)
|
||||
path = diff.b_path or diff.a_path or ""
|
||||
staged.append(
|
||||
FileChange(
|
||||
path=diff.b_path or diff.a_path,
|
||||
path=path,
|
||||
change_type=change_type,
|
||||
old_path=diff.a_path if diff.renamed else None,
|
||||
).to_dict()
|
||||
@@ -212,9 +232,10 @@ class GitWrapper:
|
||||
unstaged = []
|
||||
for diff in repo.index.diff(None):
|
||||
change_type = self._diff_to_change_type(diff.change_type)
|
||||
path = diff.b_path or diff.a_path or ""
|
||||
unstaged.append(
|
||||
FileChange(
|
||||
path=diff.b_path or diff.a_path,
|
||||
path=path,
|
||||
change_type=change_type,
|
||||
).to_dict()
|
||||
)
|
||||
@@ -228,10 +249,18 @@ class GitWrapper:
|
||||
tracking = repo.active_branch.tracking_branch()
|
||||
if tracking:
|
||||
ahead = len(
|
||||
list(repo.iter_commits(f"{tracking.name}..{repo.active_branch.name}"))
|
||||
list(
|
||||
repo.iter_commits(
|
||||
f"{tracking.name}..{repo.active_branch.name}"
|
||||
)
|
||||
)
|
||||
)
|
||||
behind = len(
|
||||
list(repo.iter_commits(f"{repo.active_branch.name}..{tracking.name}"))
|
||||
list(
|
||||
repo.iter_commits(
|
||||
f"{repo.active_branch.name}..{tracking.name}"
|
||||
)
|
||||
)
|
||||
)
|
||||
except Exception:
|
||||
pass # No tracking branch
|
||||
@@ -361,6 +390,12 @@ class GitWrapper:
|
||||
local_branches = []
|
||||
for branch in repo.branches:
|
||||
tracking = branch.tracking_branch()
|
||||
msg = branch.commit.message
|
||||
commit_msg = (
|
||||
msg.decode("utf-8", errors="replace")
|
||||
if isinstance(msg, bytes)
|
||||
else msg
|
||||
).split("\n")[0]
|
||||
local_branches.append(
|
||||
BranchInfo(
|
||||
name=branch.name,
|
||||
@@ -368,7 +403,7 @@ class GitWrapper:
|
||||
is_remote=False,
|
||||
tracking_branch=tracking.name if tracking else None,
|
||||
commit_sha=branch.commit.hexsha,
|
||||
commit_message=branch.commit.message.split("\n")[0],
|
||||
commit_message=commit_msg,
|
||||
).to_dict()
|
||||
)
|
||||
|
||||
@@ -379,13 +414,19 @@ class GitWrapper:
|
||||
# Skip HEAD refs
|
||||
if ref.name.endswith("/HEAD"):
|
||||
continue
|
||||
msg = ref.commit.message
|
||||
commit_msg = (
|
||||
msg.decode("utf-8", errors="replace")
|
||||
if isinstance(msg, bytes)
|
||||
else msg
|
||||
).split("\n")[0]
|
||||
remote_branches.append(
|
||||
BranchInfo(
|
||||
name=ref.name,
|
||||
is_current=False,
|
||||
is_remote=True,
|
||||
commit_sha=ref.commit.hexsha,
|
||||
commit_message=ref.commit.message.split("\n")[0],
|
||||
commit_message=commit_msg,
|
||||
).to_dict()
|
||||
)
|
||||
|
||||
@@ -485,7 +526,11 @@ class GitWrapper:
|
||||
repo.git.add("-A")
|
||||
|
||||
# Check if there's anything to commit
|
||||
if not allow_empty and not repo.index.diff("HEAD") and not repo.untracked_files:
|
||||
if (
|
||||
not allow_empty
|
||||
and not repo.index.diff("HEAD")
|
||||
and not repo.untracked_files
|
||||
):
|
||||
raise CommitError("Nothing to commit")
|
||||
|
||||
# Build author
|
||||
@@ -613,7 +658,7 @@ class GitWrapper:
|
||||
|
||||
try:
|
||||
# Build push info
|
||||
push_info_list = []
|
||||
push_info_list: list[Any] = []
|
||||
|
||||
if remote not in [r.name for r in repo.remotes]:
|
||||
raise PushError(push_branch, f"Remote not found: {remote}")
|
||||
@@ -716,9 +761,7 @@ class GitWrapper:
|
||||
repo.git.pull(remote, pull_branch)
|
||||
|
||||
# Count new commits
|
||||
commits_received = len(
|
||||
list(repo.iter_commits(f"{head_before}..HEAD"))
|
||||
)
|
||||
commits_received = len(list(repo.iter_commits(f"{head_before}..HEAD")))
|
||||
|
||||
# Check if fast-forward
|
||||
fast_forward = commits_received > 0 and not repo.head.commit.parents
|
||||
@@ -733,10 +776,9 @@ class GitWrapper:
|
||||
except GitCommandError as e:
|
||||
error_msg = str(e)
|
||||
if "conflict" in error_msg.lower():
|
||||
# Get conflicting files
|
||||
# Get conflicting files - keys are paths directly
|
||||
conflicts = [
|
||||
item.a_path
|
||||
for item in repo.index.unmerged_blobs().keys()
|
||||
str(path) for path in repo.index.unmerged_blobs().keys()
|
||||
]
|
||||
raise MergeConflictError(conflicts)
|
||||
raise PullError(pull_branch, error_msg)
|
||||
@@ -823,7 +865,7 @@ class GitWrapper:
|
||||
continue
|
||||
|
||||
change_type = self._diff_to_change_type(diff.change_type)
|
||||
path = diff.b_path or diff.a_path
|
||||
path = diff.b_path or diff.a_path or ""
|
||||
|
||||
# Parse hunks from patch
|
||||
hunks = []
|
||||
@@ -831,7 +873,13 @@ class GitWrapper:
|
||||
deletions = 0
|
||||
|
||||
if diff.diff:
|
||||
patch_text = diff.diff.decode("utf-8", errors="replace")
|
||||
# Handle both bytes and str
|
||||
raw_diff = diff.diff
|
||||
patch_text = (
|
||||
raw_diff.decode("utf-8", errors="replace")
|
||||
if isinstance(raw_diff, bytes)
|
||||
else raw_diff
|
||||
)
|
||||
# Parse hunks (simplified)
|
||||
for line in patch_text.split("\n"):
|
||||
if line.startswith("+") and not line.startswith("+++"):
|
||||
@@ -921,18 +969,25 @@ class GitWrapper:
|
||||
iterator = repo.iter_commits(**kwargs)
|
||||
|
||||
for commit in iterator:
|
||||
# Handle message that can be bytes
|
||||
msg = commit.message
|
||||
message_str = (
|
||||
msg.decode("utf-8", errors="replace")
|
||||
if isinstance(msg, bytes)
|
||||
else msg
|
||||
)
|
||||
commits.append(
|
||||
CommitInfo(
|
||||
sha=commit.hexsha,
|
||||
short_sha=commit.hexsha[:7],
|
||||
message=commit.message,
|
||||
author_name=commit.author.name,
|
||||
author_email=commit.author.email,
|
||||
message=message_str,
|
||||
author_name=commit.author.name or "Unknown",
|
||||
author_email=commit.author.email or "",
|
||||
authored_date=datetime.fromtimestamp(
|
||||
commit.authored_date, tz=UTC
|
||||
),
|
||||
committer_name=commit.committer.name,
|
||||
committer_email=commit.committer.email,
|
||||
committer_name=commit.committer.name or "Unknown",
|
||||
committer_email=commit.committer.email or "",
|
||||
committed_date=datetime.fromtimestamp(
|
||||
commit.committed_date, tz=UTC
|
||||
),
|
||||
@@ -1052,8 +1107,10 @@ class GitWrapper:
|
||||
|
||||
# Utility methods
|
||||
|
||||
def _diff_to_change_type(self, change_type: str) -> FileChangeType:
|
||||
def _diff_to_change_type(self, change_type: str | None) -> FileChangeType:
|
||||
"""Convert GitPython change type to our enum."""
|
||||
if change_type is None:
|
||||
return FileChangeType.MODIFIED
|
||||
mapping = {
|
||||
"A": FileChangeType.ADDED,
|
||||
"M": FileChangeType.MODIFIED,
|
||||
@@ -1105,7 +1162,8 @@ class GitWrapper:
|
||||
try:
|
||||
cr = repo.config_reader()
|
||||
section, option = key.rsplit(".", 1)
|
||||
return cr.get_value(section, option)
|
||||
value = cr.get_value(section, option)
|
||||
return str(value) if value is not None else None
|
||||
except Exception:
|
||||
return None
|
||||
|
||||
|
||||
@@ -257,7 +257,9 @@ class WorkspaceInfo:
|
||||
"last_accessed": self.last_accessed.isoformat(),
|
||||
"size_bytes": self.size_bytes,
|
||||
"lock_holder": self.lock_holder,
|
||||
"lock_expires": self.lock_expires.isoformat() if self.lock_expires else None,
|
||||
"lock_expires": self.lock_expires.isoformat()
|
||||
if self.lock_expires
|
||||
else None,
|
||||
}
|
||||
|
||||
|
||||
@@ -270,7 +272,9 @@ class CloneRequest(BaseModel):
|
||||
project_id: str = Field(..., description="Project ID for scoping")
|
||||
agent_id: str = Field(..., description="Agent ID making the request")
|
||||
repo_url: str = Field(..., description="Repository URL to clone")
|
||||
branch: str | None = Field(default=None, description="Branch to checkout after clone")
|
||||
branch: str | None = Field(
|
||||
default=None, description="Branch to checkout after clone"
|
||||
)
|
||||
depth: int | None = Field(
|
||||
default=None, ge=1, description="Shallow clone depth (None = full clone)"
|
||||
)
|
||||
@@ -407,7 +411,9 @@ class PushRequest(BaseModel):
|
||||
|
||||
project_id: str = Field(..., description="Project ID for scoping")
|
||||
agent_id: str = Field(..., description="Agent ID making the request")
|
||||
branch: str | None = Field(default=None, description="Branch to push (None = current)")
|
||||
branch: str | None = Field(
|
||||
default=None, description="Branch to push (None = current)"
|
||||
)
|
||||
remote: str = Field(default="origin", description="Remote name")
|
||||
force: bool = Field(default=False, description="Force push")
|
||||
set_upstream: bool = Field(default=True, description="Set upstream tracking")
|
||||
@@ -428,7 +434,9 @@ class PullRequest(BaseModel):
|
||||
|
||||
project_id: str = Field(..., description="Project ID for scoping")
|
||||
agent_id: str = Field(..., description="Agent ID making the request")
|
||||
branch: str | None = Field(default=None, description="Branch to pull (None = current)")
|
||||
branch: str | None = Field(
|
||||
default=None, description="Branch to pull (None = current)"
|
||||
)
|
||||
remote: str = Field(default="origin", description="Remote name")
|
||||
rebase: bool = Field(default=False, description="Rebase instead of merge")
|
||||
|
||||
@@ -451,7 +459,9 @@ class DiffRequest(BaseModel):
|
||||
|
||||
project_id: str = Field(..., description="Project ID for scoping")
|
||||
agent_id: str = Field(..., description="Agent ID making the request")
|
||||
base: str | None = Field(default=None, description="Base reference (None = working tree)")
|
||||
base: str | None = Field(
|
||||
default=None, description="Base reference (None = working tree)"
|
||||
)
|
||||
head: str | None = Field(default=None, description="Head reference (None = HEAD)")
|
||||
files: list[str] | None = Field(default=None, description="Specific files to diff")
|
||||
context_lines: int = Field(default=3, ge=0, description="Context lines")
|
||||
@@ -463,9 +473,7 @@ class DiffResult(BaseModel):
|
||||
project_id: str = Field(..., description="Project ID")
|
||||
base: str | None = Field(default=None, description="Base reference")
|
||||
head: str | None = Field(default=None, description="Head reference")
|
||||
files: list[dict[str, Any]] = Field(
|
||||
default_factory=list, description="File diffs"
|
||||
)
|
||||
files: list[dict[str, Any]] = Field(default_factory=list, description="File diffs")
|
||||
total_additions: int = Field(default=0, description="Total lines added")
|
||||
total_deletions: int = Field(default=0, description="Total lines removed")
|
||||
files_changed: int = Field(default=0, description="Number of files changed")
|
||||
@@ -549,9 +557,7 @@ class ListPRsResult(BaseModel):
|
||||
"""Result of listing pull requests."""
|
||||
|
||||
success: bool = Field(..., description="Whether list succeeded")
|
||||
pull_requests: list[dict[str, Any]] = Field(
|
||||
default_factory=list, description="PRs"
|
||||
)
|
||||
pull_requests: list[dict[str, Any]] = Field(default_factory=list, description="PRs")
|
||||
total_count: int = Field(default=0, description="Total matching PRs")
|
||||
error: str | None = Field(default=None, description="Error message if failed")
|
||||
|
||||
@@ -578,7 +584,9 @@ class MergePRResult(BaseModel):
|
||||
|
||||
success: bool = Field(..., description="Whether merge succeeded")
|
||||
merge_commit_sha: str | None = Field(default=None, description="Merge commit SHA")
|
||||
branch_deleted: bool = Field(default=False, description="Whether branch was deleted")
|
||||
branch_deleted: bool = Field(
|
||||
default=False, description="Whether branch was deleted"
|
||||
)
|
||||
error: str | None = Field(default=None, description="Error message if failed")
|
||||
|
||||
|
||||
@@ -626,7 +634,9 @@ class LockWorkspaceRequest(BaseModel):
|
||||
|
||||
project_id: str = Field(..., description="Project ID")
|
||||
agent_id: str = Field(..., description="Agent ID requesting lock")
|
||||
timeout: int = Field(default=300, ge=10, le=3600, description="Lock timeout seconds")
|
||||
timeout: int = Field(
|
||||
default=300, ge=10, le=3600, description="Lock timeout seconds"
|
||||
)
|
||||
|
||||
|
||||
class LockWorkspaceResult(BaseModel):
|
||||
@@ -634,7 +644,9 @@ class LockWorkspaceResult(BaseModel):
|
||||
|
||||
success: bool = Field(..., description="Whether lock acquired")
|
||||
lock_holder: str | None = Field(default=None, description="Current lock holder")
|
||||
lock_expires: str | None = Field(default=None, description="Lock expiry ISO timestamp")
|
||||
lock_expires: str | None = Field(
|
||||
default=None, description="Lock expiry ISO timestamp"
|
||||
)
|
||||
error: str | None = Field(default=None, description="Error message if failed")
|
||||
|
||||
|
||||
|
||||
@@ -44,9 +44,7 @@ class BaseProvider(ABC):
|
||||
# Repository operations
|
||||
|
||||
@abstractmethod
|
||||
async def get_repo_info(
|
||||
self, owner: str, repo: str
|
||||
) -> dict[str, Any]:
|
||||
async def get_repo_info(self, owner: str, repo: str) -> dict[str, Any]:
|
||||
"""
|
||||
Get repository information.
|
||||
|
||||
@@ -60,9 +58,7 @@ class BaseProvider(ABC):
|
||||
...
|
||||
|
||||
@abstractmethod
|
||||
async def get_default_branch(
|
||||
self, owner: str, repo: str
|
||||
) -> str:
|
||||
async def get_default_branch(self, owner: str, repo: str) -> str:
|
||||
"""
|
||||
Get the default branch for a repository.
|
||||
|
||||
@@ -112,9 +108,7 @@ class BaseProvider(ABC):
|
||||
...
|
||||
|
||||
@abstractmethod
|
||||
async def get_pr(
|
||||
self, owner: str, repo: str, pr_number: int
|
||||
) -> GetPRResult:
|
||||
async def get_pr(self, owner: str, repo: str, pr_number: int) -> GetPRResult:
|
||||
"""
|
||||
Get a pull request by number.
|
||||
|
||||
@@ -209,9 +203,7 @@ class BaseProvider(ABC):
|
||||
...
|
||||
|
||||
@abstractmethod
|
||||
async def close_pr(
|
||||
self, owner: str, repo: str, pr_number: int
|
||||
) -> UpdatePRResult:
|
||||
async def close_pr(self, owner: str, repo: str, pr_number: int) -> UpdatePRResult:
|
||||
"""
|
||||
Close a pull request without merging.
|
||||
|
||||
@@ -228,9 +220,7 @@ class BaseProvider(ABC):
|
||||
# Branch operations via API (for operations that need to bypass local git)
|
||||
|
||||
@abstractmethod
|
||||
async def delete_remote_branch(
|
||||
self, owner: str, repo: str, branch: str
|
||||
) -> bool:
|
||||
async def delete_remote_branch(self, owner: str, repo: str, branch: str) -> bool:
|
||||
"""
|
||||
Delete a remote branch via API.
|
||||
|
||||
@@ -379,9 +369,7 @@ class BaseProvider(ABC):
|
||||
return ssh_match.group(1), ssh_match.group(2)
|
||||
|
||||
# Handle HTTPS URLs: https://host/owner/repo.git
|
||||
https_match = re.match(
|
||||
r"https?://[^/]+/([^/]+)/([^/]+?)(?:\.git)?$", repo_url
|
||||
)
|
||||
https_match = re.match(r"https?://[^/]+/([^/]+)/([^/]+?)(?:\.git)?$", repo_url)
|
||||
if https_match:
|
||||
return https_match.group(1), https_match.group(2)
|
||||
|
||||
|
||||
@@ -116,9 +116,7 @@ class GitHubProvider(BaseProvider):
|
||||
if response.status_code == 403:
|
||||
# Check for rate limiting
|
||||
if "rate limit" in response.text.lower():
|
||||
raise APIError(
|
||||
"github", 403, "GitHub API rate limit exceeded"
|
||||
)
|
||||
raise APIError("github", 403, "GitHub API rate limit exceeded")
|
||||
raise AuthenticationError(
|
||||
"github", "Insufficient permissions for this operation"
|
||||
)
|
||||
|
||||
@@ -101,7 +101,7 @@ exclude_lines = [
|
||||
"if TYPE_CHECKING:",
|
||||
"if __name__ == .__main__.:",
|
||||
]
|
||||
fail_under = 65 # TODO: Increase to 80% once more tool tests are added
|
||||
fail_under = 78
|
||||
show_missing = true
|
||||
|
||||
[tool.mypy]
|
||||
@@ -111,6 +111,8 @@ warn_unused_ignores = false
|
||||
disallow_untyped_defs = true
|
||||
ignore_missing_imports = true
|
||||
plugins = ["pydantic.mypy"]
|
||||
files = ["server.py", "config.py", "models.py", "exceptions.py", "git_wrapper.py", "workspace.py", "providers/"]
|
||||
exclude = ["tests/"]
|
||||
|
||||
[[tool.mypy.overrides]]
|
||||
module = "tests.*"
|
||||
|
||||
@@ -56,13 +56,56 @@ def _validate_branch(value: str) -> str | None:
|
||||
|
||||
|
||||
def _validate_url(value: str) -> str | None:
|
||||
"""Validate repository URL format."""
|
||||
"""
|
||||
Validate repository URL format with SSRF protection.
|
||||
|
||||
Validates the URL format and optionally checks against allowed hosts.
|
||||
"""
|
||||
if not isinstance(value, str):
|
||||
return "Repository URL must be a string"
|
||||
if not value:
|
||||
return "Repository URL is required"
|
||||
if not URL_PATTERN.match(value):
|
||||
return "Invalid repository URL: must be a valid HTTPS or SSH git URL"
|
||||
|
||||
# SSRF protection: check allowed hosts if configured
|
||||
if _settings and _settings.allowed_hosts:
|
||||
from urllib.parse import urlparse
|
||||
|
||||
parsed = urlparse(value)
|
||||
hostname = parsed.hostname
|
||||
|
||||
# Block localhost and loopback addresses
|
||||
blocked_hosts = {
|
||||
"localhost",
|
||||
"127.0.0.1",
|
||||
"::1",
|
||||
"0.0.0.0",
|
||||
"169.254.169.254", # Cloud metadata endpoint
|
||||
}
|
||||
|
||||
if hostname and hostname.lower() in blocked_hosts:
|
||||
return f"Repository URL not allowed: blocked host '{hostname}'"
|
||||
|
||||
# Block private IP ranges (simplified check)
|
||||
if hostname:
|
||||
import ipaddress
|
||||
|
||||
try:
|
||||
ip = ipaddress.ip_address(hostname)
|
||||
if ip.is_private or ip.is_loopback or ip.is_link_local:
|
||||
return (
|
||||
f"Repository URL not allowed: private IP address '{hostname}'"
|
||||
)
|
||||
except ValueError:
|
||||
pass # Not an IP address, continue with hostname check
|
||||
|
||||
# Check against allowed hosts list
|
||||
if hostname and hostname.lower() not in [
|
||||
h.lower() for h in _settings.allowed_hosts
|
||||
]:
|
||||
return f"Repository URL not allowed: host '{hostname}' not in allowed list"
|
||||
|
||||
return None
|
||||
|
||||
|
||||
@@ -92,7 +135,8 @@ def _get_provider_for_url(repo_url: str) -> BaseProvider | None:
|
||||
if "github.com" in url_lower or (
|
||||
_settings.github_api_url
|
||||
and _settings.github_api_url != "https://api.github.com"
|
||||
and _settings.github_api_url.replace("https://", "").replace("/api/v3", "") in url_lower
|
||||
and _settings.github_api_url.replace("https://", "").replace("/api/v3", "")
|
||||
in url_lower
|
||||
):
|
||||
return _github_provider
|
||||
|
||||
@@ -166,6 +210,13 @@ async def lifespan(_app: FastAPI) -> AsyncIterator[None]:
|
||||
if _github_provider:
|
||||
await _github_provider.close()
|
||||
|
||||
# Shutdown thread pool executor for git operations
|
||||
from git_wrapper import _executor
|
||||
|
||||
if _executor:
|
||||
logger.info("Shutting down git operations thread pool...")
|
||||
_executor.shutdown(wait=True)
|
||||
|
||||
logger.info("Git Operations MCP Server shut down")
|
||||
|
||||
|
||||
@@ -479,7 +530,9 @@ async def clone_repository(
|
||||
project_id: str = Field(..., description="Project ID for scoping"),
|
||||
agent_id: str = Field(..., description="Agent ID making the request"),
|
||||
repo_url: str = Field(..., description="Repository URL to clone"),
|
||||
branch: str | None = Field(default=None, description="Branch to checkout after clone"),
|
||||
branch: str | None = Field(
|
||||
default=None, description="Branch to checkout after clone"
|
||||
),
|
||||
depth: int | None = Field(default=None, ge=1, description="Shallow clone depth"),
|
||||
) -> dict[str, Any]:
|
||||
"""
|
||||
@@ -490,11 +543,23 @@ async def clone_repository(
|
||||
try:
|
||||
# Validate inputs
|
||||
if error := _validate_id(project_id, "project_id"):
|
||||
return {"success": False, "error": error, "code": ErrorCode.INVALID_REQUEST.value}
|
||||
return {
|
||||
"success": False,
|
||||
"error": error,
|
||||
"code": ErrorCode.INVALID_REQUEST.value,
|
||||
}
|
||||
if error := _validate_id(agent_id, "agent_id"):
|
||||
return {"success": False, "error": error, "code": ErrorCode.INVALID_REQUEST.value}
|
||||
return {
|
||||
"success": False,
|
||||
"error": error,
|
||||
"code": ErrorCode.INVALID_REQUEST.value,
|
||||
}
|
||||
if error := _validate_url(repo_url):
|
||||
return {"success": False, "error": error, "code": ErrorCode.INVALID_REQUEST.value}
|
||||
return {
|
||||
"success": False,
|
||||
"error": error,
|
||||
"code": ErrorCode.INVALID_REQUEST.value,
|
||||
}
|
||||
|
||||
# Create workspace
|
||||
workspace = await _workspace_manager.create_workspace(project_id, repo_url) # type: ignore[union-attr]
|
||||
@@ -504,7 +569,9 @@ async def clone_repository(
|
||||
|
||||
# Clone repository
|
||||
git = GitWrapper(workspace.path, _settings)
|
||||
result = await git.clone(repo_url, branch=branch, depth=depth, auth_token=auth_token)
|
||||
result = await git.clone(
|
||||
repo_url, branch=branch, depth=depth, auth_token=auth_token
|
||||
)
|
||||
|
||||
# Update workspace metadata
|
||||
await _workspace_manager.update_workspace_branch(project_id, result.branch) # type: ignore[union-attr]
|
||||
@@ -522,14 +589,20 @@ async def clone_repository(
|
||||
return {"success": False, "error": e.message, "code": e.code.value}
|
||||
except Exception as e:
|
||||
logger.error(f"Unexpected clone error: {e}")
|
||||
return {"success": False, "error": str(e), "code": ErrorCode.INTERNAL_ERROR.value}
|
||||
return {
|
||||
"success": False,
|
||||
"error": str(e),
|
||||
"code": ErrorCode.INTERNAL_ERROR.value,
|
||||
}
|
||||
|
||||
|
||||
@mcp.tool()
|
||||
async def git_status(
|
||||
project_id: str = Field(..., description="Project ID for scoping"),
|
||||
agent_id: str = Field(..., description="Agent ID making the request"),
|
||||
include_untracked: bool = Field(default=True, description="Include untracked files"),
|
||||
include_untracked: bool = Field(
|
||||
default=True, description="Include untracked files"
|
||||
),
|
||||
) -> dict[str, Any]:
|
||||
"""
|
||||
Get git status for a project workspace.
|
||||
@@ -538,13 +611,25 @@ async def git_status(
|
||||
"""
|
||||
try:
|
||||
if error := _validate_id(project_id, "project_id"):
|
||||
return {"success": False, "error": error, "code": ErrorCode.INVALID_REQUEST.value}
|
||||
return {
|
||||
"success": False,
|
||||
"error": error,
|
||||
"code": ErrorCode.INVALID_REQUEST.value,
|
||||
}
|
||||
if error := _validate_id(agent_id, "agent_id"):
|
||||
return {"success": False, "error": error, "code": ErrorCode.INVALID_REQUEST.value}
|
||||
return {
|
||||
"success": False,
|
||||
"error": error,
|
||||
"code": ErrorCode.INVALID_REQUEST.value,
|
||||
}
|
||||
|
||||
workspace = await _workspace_manager.get_workspace(project_id) # type: ignore[union-attr]
|
||||
if not workspace:
|
||||
return {"success": False, "error": f"Workspace not found for project: {project_id}", "code": ErrorCode.WORKSPACE_NOT_FOUND.value}
|
||||
return {
|
||||
"success": False,
|
||||
"error": f"Workspace not found for project: {project_id}",
|
||||
"code": ErrorCode.WORKSPACE_NOT_FOUND.value,
|
||||
}
|
||||
|
||||
git = GitWrapper(workspace.path, _settings)
|
||||
result = await git.status(include_untracked=include_untracked)
|
||||
@@ -567,7 +652,11 @@ async def git_status(
|
||||
return {"success": False, "error": e.message, "code": e.code.value}
|
||||
except Exception as e:
|
||||
logger.error(f"Unexpected status error: {e}")
|
||||
return {"success": False, "error": str(e), "code": ErrorCode.INTERNAL_ERROR.value}
|
||||
return {
|
||||
"success": False,
|
||||
"error": str(e),
|
||||
"code": ErrorCode.INTERNAL_ERROR.value,
|
||||
}
|
||||
|
||||
|
||||
# MCP Tools - Branch Operations
|
||||
@@ -578,7 +667,9 @@ async def create_branch(
|
||||
project_id: str = Field(..., description="Project ID for scoping"),
|
||||
agent_id: str = Field(..., description="Agent ID making the request"),
|
||||
branch_name: str = Field(..., description="Name for the new branch"),
|
||||
from_ref: str | None = Field(default=None, description="Reference to create from (default: HEAD)"),
|
||||
from_ref: str | None = Field(
|
||||
default=None, description="Reference to create from (default: HEAD)"
|
||||
),
|
||||
checkout: bool = Field(default=True, description="Checkout after creation"),
|
||||
) -> dict[str, Any]:
|
||||
"""
|
||||
@@ -586,18 +677,36 @@ async def create_branch(
|
||||
"""
|
||||
try:
|
||||
if error := _validate_id(project_id, "project_id"):
|
||||
return {"success": False, "error": error, "code": ErrorCode.INVALID_REQUEST.value}
|
||||
return {
|
||||
"success": False,
|
||||
"error": error,
|
||||
"code": ErrorCode.INVALID_REQUEST.value,
|
||||
}
|
||||
if error := _validate_id(agent_id, "agent_id"):
|
||||
return {"success": False, "error": error, "code": ErrorCode.INVALID_REQUEST.value}
|
||||
return {
|
||||
"success": False,
|
||||
"error": error,
|
||||
"code": ErrorCode.INVALID_REQUEST.value,
|
||||
}
|
||||
if error := _validate_branch(branch_name):
|
||||
return {"success": False, "error": error, "code": ErrorCode.INVALID_REQUEST.value}
|
||||
return {
|
||||
"success": False,
|
||||
"error": error,
|
||||
"code": ErrorCode.INVALID_REQUEST.value,
|
||||
}
|
||||
|
||||
workspace = await _workspace_manager.get_workspace(project_id) # type: ignore[union-attr]
|
||||
if not workspace:
|
||||
return {"success": False, "error": f"Workspace not found for project: {project_id}", "code": ErrorCode.WORKSPACE_NOT_FOUND.value}
|
||||
return {
|
||||
"success": False,
|
||||
"error": f"Workspace not found for project: {project_id}",
|
||||
"code": ErrorCode.WORKSPACE_NOT_FOUND.value,
|
||||
}
|
||||
|
||||
git = GitWrapper(workspace.path, _settings)
|
||||
result = await git.create_branch(branch_name, from_ref=from_ref, checkout=checkout)
|
||||
result = await git.create_branch(
|
||||
branch_name, from_ref=from_ref, checkout=checkout
|
||||
)
|
||||
|
||||
if checkout:
|
||||
await _workspace_manager.update_workspace_branch(project_id, branch_name) # type: ignore[union-attr]
|
||||
@@ -614,7 +723,11 @@ async def create_branch(
|
||||
return {"success": False, "error": e.message, "code": e.code.value}
|
||||
except Exception as e:
|
||||
logger.error(f"Unexpected create branch error: {e}")
|
||||
return {"success": False, "error": str(e), "code": ErrorCode.INTERNAL_ERROR.value}
|
||||
return {
|
||||
"success": False,
|
||||
"error": str(e),
|
||||
"code": ErrorCode.INTERNAL_ERROR.value,
|
||||
}
|
||||
|
||||
|
||||
@mcp.tool()
|
||||
@@ -628,13 +741,25 @@ async def list_branches(
|
||||
"""
|
||||
try:
|
||||
if error := _validate_id(project_id, "project_id"):
|
||||
return {"success": False, "error": error, "code": ErrorCode.INVALID_REQUEST.value}
|
||||
return {
|
||||
"success": False,
|
||||
"error": error,
|
||||
"code": ErrorCode.INVALID_REQUEST.value,
|
||||
}
|
||||
if error := _validate_id(agent_id, "agent_id"):
|
||||
return {"success": False, "error": error, "code": ErrorCode.INVALID_REQUEST.value}
|
||||
return {
|
||||
"success": False,
|
||||
"error": error,
|
||||
"code": ErrorCode.INVALID_REQUEST.value,
|
||||
}
|
||||
|
||||
workspace = await _workspace_manager.get_workspace(project_id) # type: ignore[union-attr]
|
||||
if not workspace:
|
||||
return {"success": False, "error": f"Workspace not found for project: {project_id}", "code": ErrorCode.WORKSPACE_NOT_FOUND.value}
|
||||
return {
|
||||
"success": False,
|
||||
"error": f"Workspace not found for project: {project_id}",
|
||||
"code": ErrorCode.WORKSPACE_NOT_FOUND.value,
|
||||
}
|
||||
|
||||
git = GitWrapper(workspace.path, _settings)
|
||||
result = await git.list_branches(include_remote=include_remote)
|
||||
@@ -652,7 +777,11 @@ async def list_branches(
|
||||
return {"success": False, "error": e.message, "code": e.code.value}
|
||||
except Exception as e:
|
||||
logger.error(f"Unexpected list branches error: {e}")
|
||||
return {"success": False, "error": str(e), "code": ErrorCode.INTERNAL_ERROR.value}
|
||||
return {
|
||||
"success": False,
|
||||
"error": str(e),
|
||||
"code": ErrorCode.INTERNAL_ERROR.value,
|
||||
}
|
||||
|
||||
|
||||
@mcp.tool()
|
||||
@@ -668,13 +797,25 @@ async def checkout(
|
||||
"""
|
||||
try:
|
||||
if error := _validate_id(project_id, "project_id"):
|
||||
return {"success": False, "error": error, "code": ErrorCode.INVALID_REQUEST.value}
|
||||
return {
|
||||
"success": False,
|
||||
"error": error,
|
||||
"code": ErrorCode.INVALID_REQUEST.value,
|
||||
}
|
||||
if error := _validate_id(agent_id, "agent_id"):
|
||||
return {"success": False, "error": error, "code": ErrorCode.INVALID_REQUEST.value}
|
||||
return {
|
||||
"success": False,
|
||||
"error": error,
|
||||
"code": ErrorCode.INVALID_REQUEST.value,
|
||||
}
|
||||
|
||||
workspace = await _workspace_manager.get_workspace(project_id) # type: ignore[union-attr]
|
||||
if not workspace:
|
||||
return {"success": False, "error": f"Workspace not found for project: {project_id}", "code": ErrorCode.WORKSPACE_NOT_FOUND.value}
|
||||
return {
|
||||
"success": False,
|
||||
"error": f"Workspace not found for project: {project_id}",
|
||||
"code": ErrorCode.WORKSPACE_NOT_FOUND.value,
|
||||
}
|
||||
|
||||
git = GitWrapper(workspace.path, _settings)
|
||||
result = await git.checkout(ref, create_branch=create_branch, force=force)
|
||||
@@ -692,7 +833,11 @@ async def checkout(
|
||||
return {"success": False, "error": e.message, "code": e.code.value}
|
||||
except Exception as e:
|
||||
logger.error(f"Unexpected checkout error: {e}")
|
||||
return {"success": False, "error": str(e), "code": ErrorCode.INTERNAL_ERROR.value}
|
||||
return {
|
||||
"success": False,
|
||||
"error": str(e),
|
||||
"code": ErrorCode.INTERNAL_ERROR.value,
|
||||
}
|
||||
|
||||
|
||||
# MCP Tools - Commit Operations
|
||||
@@ -703,7 +848,9 @@ async def commit(
|
||||
project_id: str = Field(..., description="Project ID for scoping"),
|
||||
agent_id: str = Field(..., description="Agent ID making the request"),
|
||||
message: str = Field(..., description="Commit message"),
|
||||
files: list[str] | None = Field(default=None, description="Specific files to commit (None = all staged)"),
|
||||
files: list[str] | None = Field(
|
||||
default=None, description="Specific files to commit (None = all staged)"
|
||||
),
|
||||
author_name: str | None = Field(default=None, description="Author name override"),
|
||||
author_email: str | None = Field(default=None, description="Author email override"),
|
||||
) -> dict[str, Any]:
|
||||
@@ -714,13 +861,25 @@ async def commit(
|
||||
"""
|
||||
try:
|
||||
if error := _validate_id(project_id, "project_id"):
|
||||
return {"success": False, "error": error, "code": ErrorCode.INVALID_REQUEST.value}
|
||||
return {
|
||||
"success": False,
|
||||
"error": error,
|
||||
"code": ErrorCode.INVALID_REQUEST.value,
|
||||
}
|
||||
if error := _validate_id(agent_id, "agent_id"):
|
||||
return {"success": False, "error": error, "code": ErrorCode.INVALID_REQUEST.value}
|
||||
return {
|
||||
"success": False,
|
||||
"error": error,
|
||||
"code": ErrorCode.INVALID_REQUEST.value,
|
||||
}
|
||||
|
||||
workspace = await _workspace_manager.get_workspace(project_id) # type: ignore[union-attr]
|
||||
if not workspace:
|
||||
return {"success": False, "error": f"Workspace not found for project: {project_id}", "code": ErrorCode.WORKSPACE_NOT_FOUND.value}
|
||||
return {
|
||||
"success": False,
|
||||
"error": f"Workspace not found for project: {project_id}",
|
||||
"code": ErrorCode.WORKSPACE_NOT_FOUND.value,
|
||||
}
|
||||
|
||||
git = GitWrapper(workspace.path, _settings)
|
||||
result = await git.commit(
|
||||
@@ -745,14 +904,20 @@ async def commit(
|
||||
return {"success": False, "error": e.message, "code": e.code.value}
|
||||
except Exception as e:
|
||||
logger.error(f"Unexpected commit error: {e}")
|
||||
return {"success": False, "error": str(e), "code": ErrorCode.INTERNAL_ERROR.value}
|
||||
return {
|
||||
"success": False,
|
||||
"error": str(e),
|
||||
"code": ErrorCode.INTERNAL_ERROR.value,
|
||||
}
|
||||
|
||||
|
||||
@mcp.tool()
|
||||
async def push(
|
||||
project_id: str = Field(..., description="Project ID for scoping"),
|
||||
agent_id: str = Field(..., description="Agent ID making the request"),
|
||||
branch: str | None = Field(default=None, description="Branch to push (None = current)"),
|
||||
branch: str | None = Field(
|
||||
default=None, description="Branch to push (None = current)"
|
||||
),
|
||||
remote: str = Field(default="origin", description="Remote name"),
|
||||
force: bool = Field(default=False, description="Force push"),
|
||||
set_upstream: bool = Field(default=True, description="Set upstream tracking"),
|
||||
@@ -762,13 +927,25 @@ async def push(
|
||||
"""
|
||||
try:
|
||||
if error := _validate_id(project_id, "project_id"):
|
||||
return {"success": False, "error": error, "code": ErrorCode.INVALID_REQUEST.value}
|
||||
return {
|
||||
"success": False,
|
||||
"error": error,
|
||||
"code": ErrorCode.INVALID_REQUEST.value,
|
||||
}
|
||||
if error := _validate_id(agent_id, "agent_id"):
|
||||
return {"success": False, "error": error, "code": ErrorCode.INVALID_REQUEST.value}
|
||||
return {
|
||||
"success": False,
|
||||
"error": error,
|
||||
"code": ErrorCode.INVALID_REQUEST.value,
|
||||
}
|
||||
|
||||
workspace = await _workspace_manager.get_workspace(project_id) # type: ignore[union-attr]
|
||||
if not workspace:
|
||||
return {"success": False, "error": f"Workspace not found for project: {project_id}", "code": ErrorCode.WORKSPACE_NOT_FOUND.value}
|
||||
return {
|
||||
"success": False,
|
||||
"error": f"Workspace not found for project: {project_id}",
|
||||
"code": ErrorCode.WORKSPACE_NOT_FOUND.value,
|
||||
}
|
||||
|
||||
# Get auth token appropriate for the repository URL
|
||||
auth_token = _get_auth_token_for_url(workspace.repo_url or "")
|
||||
@@ -794,14 +971,20 @@ async def push(
|
||||
return {"success": False, "error": e.message, "code": e.code.value}
|
||||
except Exception as e:
|
||||
logger.error(f"Unexpected push error: {e}")
|
||||
return {"success": False, "error": str(e), "code": ErrorCode.INTERNAL_ERROR.value}
|
||||
return {
|
||||
"success": False,
|
||||
"error": str(e),
|
||||
"code": ErrorCode.INTERNAL_ERROR.value,
|
||||
}
|
||||
|
||||
|
||||
@mcp.tool()
|
||||
async def pull(
|
||||
project_id: str = Field(..., description="Project ID for scoping"),
|
||||
agent_id: str = Field(..., description="Agent ID making the request"),
|
||||
branch: str | None = Field(default=None, description="Branch to pull (None = current)"),
|
||||
branch: str | None = Field(
|
||||
default=None, description="Branch to pull (None = current)"
|
||||
),
|
||||
remote: str = Field(default="origin", description="Remote name"),
|
||||
rebase: bool = Field(default=False, description="Rebase instead of merge"),
|
||||
) -> dict[str, Any]:
|
||||
@@ -810,13 +993,25 @@ async def pull(
|
||||
"""
|
||||
try:
|
||||
if error := _validate_id(project_id, "project_id"):
|
||||
return {"success": False, "error": error, "code": ErrorCode.INVALID_REQUEST.value}
|
||||
return {
|
||||
"success": False,
|
||||
"error": error,
|
||||
"code": ErrorCode.INVALID_REQUEST.value,
|
||||
}
|
||||
if error := _validate_id(agent_id, "agent_id"):
|
||||
return {"success": False, "error": error, "code": ErrorCode.INVALID_REQUEST.value}
|
||||
return {
|
||||
"success": False,
|
||||
"error": error,
|
||||
"code": ErrorCode.INVALID_REQUEST.value,
|
||||
}
|
||||
|
||||
workspace = await _workspace_manager.get_workspace(project_id) # type: ignore[union-attr]
|
||||
if not workspace:
|
||||
return {"success": False, "error": f"Workspace not found for project: {project_id}", "code": ErrorCode.WORKSPACE_NOT_FOUND.value}
|
||||
return {
|
||||
"success": False,
|
||||
"error": f"Workspace not found for project: {project_id}",
|
||||
"code": ErrorCode.WORKSPACE_NOT_FOUND.value,
|
||||
}
|
||||
|
||||
git = GitWrapper(workspace.path, _settings)
|
||||
result = await git.pull(branch=branch, remote=remote, rebase=rebase)
|
||||
@@ -834,7 +1029,11 @@ async def pull(
|
||||
return {"success": False, "error": e.message, "code": e.code.value}
|
||||
except Exception as e:
|
||||
logger.error(f"Unexpected pull error: {e}")
|
||||
return {"success": False, "error": str(e), "code": ErrorCode.INTERNAL_ERROR.value}
|
||||
return {
|
||||
"success": False,
|
||||
"error": str(e),
|
||||
"code": ErrorCode.INTERNAL_ERROR.value,
|
||||
}
|
||||
|
||||
|
||||
# MCP Tools - Diff and Log
|
||||
@@ -854,16 +1053,30 @@ async def diff(
|
||||
"""
|
||||
try:
|
||||
if error := _validate_id(project_id, "project_id"):
|
||||
return {"success": False, "error": error, "code": ErrorCode.INVALID_REQUEST.value}
|
||||
return {
|
||||
"success": False,
|
||||
"error": error,
|
||||
"code": ErrorCode.INVALID_REQUEST.value,
|
||||
}
|
||||
if error := _validate_id(agent_id, "agent_id"):
|
||||
return {"success": False, "error": error, "code": ErrorCode.INVALID_REQUEST.value}
|
||||
return {
|
||||
"success": False,
|
||||
"error": error,
|
||||
"code": ErrorCode.INVALID_REQUEST.value,
|
||||
}
|
||||
|
||||
workspace = await _workspace_manager.get_workspace(project_id) # type: ignore[union-attr]
|
||||
if not workspace:
|
||||
return {"success": False, "error": f"Workspace not found for project: {project_id}", "code": ErrorCode.WORKSPACE_NOT_FOUND.value}
|
||||
return {
|
||||
"success": False,
|
||||
"error": f"Workspace not found for project: {project_id}",
|
||||
"code": ErrorCode.WORKSPACE_NOT_FOUND.value,
|
||||
}
|
||||
|
||||
git = GitWrapper(workspace.path, _settings)
|
||||
result = await git.diff(base=base, head=head, files=files, context_lines=context_lines)
|
||||
result = await git.diff(
|
||||
base=base, head=head, files=files, context_lines=context_lines
|
||||
)
|
||||
|
||||
return {
|
||||
"success": True,
|
||||
@@ -881,7 +1094,11 @@ async def diff(
|
||||
return {"success": False, "error": e.message, "code": e.code.value}
|
||||
except Exception as e:
|
||||
logger.error(f"Unexpected diff error: {e}")
|
||||
return {"success": False, "error": str(e), "code": ErrorCode.INTERNAL_ERROR.value}
|
||||
return {
|
||||
"success": False,
|
||||
"error": str(e),
|
||||
"code": ErrorCode.INTERNAL_ERROR.value,
|
||||
}
|
||||
|
||||
|
||||
@mcp.tool()
|
||||
@@ -898,13 +1115,25 @@ async def log(
|
||||
"""
|
||||
try:
|
||||
if error := _validate_id(project_id, "project_id"):
|
||||
return {"success": False, "error": error, "code": ErrorCode.INVALID_REQUEST.value}
|
||||
return {
|
||||
"success": False,
|
||||
"error": error,
|
||||
"code": ErrorCode.INVALID_REQUEST.value,
|
||||
}
|
||||
if error := _validate_id(agent_id, "agent_id"):
|
||||
return {"success": False, "error": error, "code": ErrorCode.INVALID_REQUEST.value}
|
||||
return {
|
||||
"success": False,
|
||||
"error": error,
|
||||
"code": ErrorCode.INVALID_REQUEST.value,
|
||||
}
|
||||
|
||||
workspace = await _workspace_manager.get_workspace(project_id) # type: ignore[union-attr]
|
||||
if not workspace:
|
||||
return {"success": False, "error": f"Workspace not found for project: {project_id}", "code": ErrorCode.WORKSPACE_NOT_FOUND.value}
|
||||
return {
|
||||
"success": False,
|
||||
"error": f"Workspace not found for project: {project_id}",
|
||||
"code": ErrorCode.WORKSPACE_NOT_FOUND.value,
|
||||
}
|
||||
|
||||
git = GitWrapper(workspace.path, _settings)
|
||||
result = await git.log(ref=ref, limit=limit, skip=skip, path=path)
|
||||
@@ -921,7 +1150,11 @@ async def log(
|
||||
return {"success": False, "error": e.message, "code": e.code.value}
|
||||
except Exception as e:
|
||||
logger.error(f"Unexpected log error: {e}")
|
||||
return {"success": False, "error": str(e), "code": ErrorCode.INTERNAL_ERROR.value}
|
||||
return {
|
||||
"success": False,
|
||||
"error": str(e),
|
||||
"code": ErrorCode.INTERNAL_ERROR.value,
|
||||
}
|
||||
|
||||
|
||||
# MCP Tools - Pull Request Operations
|
||||
@@ -938,27 +1171,49 @@ async def create_pull_request(
|
||||
draft: bool = Field(default=False, description="Create as draft"),
|
||||
labels: list[str] | None = Field(default=None, description="Labels to add"),
|
||||
assignees: list[str] | None = Field(default=None, description="Users to assign"),
|
||||
reviewers: list[str] | None = Field(default=None, description="Users to request review from"),
|
||||
reviewers: list[str] | None = Field(
|
||||
default=None, description="Users to request review from"
|
||||
),
|
||||
) -> dict[str, Any]:
|
||||
"""
|
||||
Create a pull request on the remote provider.
|
||||
"""
|
||||
try:
|
||||
if error := _validate_id(project_id, "project_id"):
|
||||
return {"success": False, "error": error, "code": ErrorCode.INVALID_REQUEST.value}
|
||||
return {
|
||||
"success": False,
|
||||
"error": error,
|
||||
"code": ErrorCode.INVALID_REQUEST.value,
|
||||
}
|
||||
if error := _validate_id(agent_id, "agent_id"):
|
||||
return {"success": False, "error": error, "code": ErrorCode.INVALID_REQUEST.value}
|
||||
return {
|
||||
"success": False,
|
||||
"error": error,
|
||||
"code": ErrorCode.INVALID_REQUEST.value,
|
||||
}
|
||||
|
||||
workspace = await _workspace_manager.get_workspace(project_id) # type: ignore[union-attr]
|
||||
if not workspace:
|
||||
return {"success": False, "error": f"Workspace not found for project: {project_id}", "code": ErrorCode.WORKSPACE_NOT_FOUND.value}
|
||||
return {
|
||||
"success": False,
|
||||
"error": f"Workspace not found for project: {project_id}",
|
||||
"code": ErrorCode.WORKSPACE_NOT_FOUND.value,
|
||||
}
|
||||
|
||||
if not workspace.repo_url:
|
||||
return {"success": False, "error": "Workspace has no repository URL", "code": ErrorCode.INVALID_REQUEST.value}
|
||||
return {
|
||||
"success": False,
|
||||
"error": "Workspace has no repository URL",
|
||||
"code": ErrorCode.INVALID_REQUEST.value,
|
||||
}
|
||||
|
||||
provider = _get_provider_for_url(workspace.repo_url)
|
||||
if not provider:
|
||||
return {"success": False, "error": "No provider configured for this repository", "code": ErrorCode.PROVIDER_NOT_FOUND.value}
|
||||
return {
|
||||
"success": False,
|
||||
"error": "No provider configured for this repository",
|
||||
"code": ErrorCode.PROVIDER_NOT_FOUND.value,
|
||||
}
|
||||
|
||||
owner, repo = provider.parse_repo_url(workspace.repo_url)
|
||||
|
||||
@@ -987,7 +1242,11 @@ async def create_pull_request(
|
||||
return {"success": False, "error": e.message, "code": e.code.value}
|
||||
except Exception as e:
|
||||
logger.error(f"Unexpected create PR error: {e}")
|
||||
return {"success": False, "error": str(e), "code": ErrorCode.INTERNAL_ERROR.value}
|
||||
return {
|
||||
"success": False,
|
||||
"error": str(e),
|
||||
"code": ErrorCode.INTERNAL_ERROR.value,
|
||||
}
|
||||
|
||||
|
||||
@mcp.tool()
|
||||
@@ -1001,20 +1260,40 @@ async def get_pull_request(
|
||||
"""
|
||||
try:
|
||||
if error := _validate_id(project_id, "project_id"):
|
||||
return {"success": False, "error": error, "code": ErrorCode.INVALID_REQUEST.value}
|
||||
return {
|
||||
"success": False,
|
||||
"error": error,
|
||||
"code": ErrorCode.INVALID_REQUEST.value,
|
||||
}
|
||||
if error := _validate_id(agent_id, "agent_id"):
|
||||
return {"success": False, "error": error, "code": ErrorCode.INVALID_REQUEST.value}
|
||||
return {
|
||||
"success": False,
|
||||
"error": error,
|
||||
"code": ErrorCode.INVALID_REQUEST.value,
|
||||
}
|
||||
|
||||
workspace = await _workspace_manager.get_workspace(project_id) # type: ignore[union-attr]
|
||||
if not workspace:
|
||||
return {"success": False, "error": f"Workspace not found for project: {project_id}", "code": ErrorCode.WORKSPACE_NOT_FOUND.value}
|
||||
return {
|
||||
"success": False,
|
||||
"error": f"Workspace not found for project: {project_id}",
|
||||
"code": ErrorCode.WORKSPACE_NOT_FOUND.value,
|
||||
}
|
||||
|
||||
if not workspace.repo_url:
|
||||
return {"success": False, "error": "Workspace has no repository URL", "code": ErrorCode.INVALID_REQUEST.value}
|
||||
return {
|
||||
"success": False,
|
||||
"error": "Workspace has no repository URL",
|
||||
"code": ErrorCode.INVALID_REQUEST.value,
|
||||
}
|
||||
|
||||
provider = _get_provider_for_url(workspace.repo_url)
|
||||
if not provider:
|
||||
return {"success": False, "error": "No provider configured for this repository", "code": ErrorCode.PROVIDER_NOT_FOUND.value}
|
||||
return {
|
||||
"success": False,
|
||||
"error": "No provider configured for this repository",
|
||||
"code": ErrorCode.PROVIDER_NOT_FOUND.value,
|
||||
}
|
||||
|
||||
owner, repo = provider.parse_repo_url(workspace.repo_url)
|
||||
result = await provider.get_pr(owner, repo, pr_number)
|
||||
@@ -1030,14 +1309,20 @@ async def get_pull_request(
|
||||
return {"success": False, "error": e.message, "code": e.code.value}
|
||||
except Exception as e:
|
||||
logger.error(f"Unexpected get PR error: {e}")
|
||||
return {"success": False, "error": str(e), "code": ErrorCode.INTERNAL_ERROR.value}
|
||||
return {
|
||||
"success": False,
|
||||
"error": str(e),
|
||||
"code": ErrorCode.INTERNAL_ERROR.value,
|
||||
}
|
||||
|
||||
|
||||
@mcp.tool()
|
||||
async def list_pull_requests(
|
||||
project_id: str = Field(..., description="Project ID for scoping"),
|
||||
agent_id: str = Field(..., description="Agent ID making the request"),
|
||||
state: str | None = Field(default=None, description="Filter by state: open, closed, merged"),
|
||||
state: str | None = Field(
|
||||
default=None, description="Filter by state: open, closed, merged"
|
||||
),
|
||||
author: str | None = Field(default=None, description="Filter by author"),
|
||||
limit: int = Field(default=20, ge=1, le=100, description="Max PRs"),
|
||||
) -> dict[str, Any]:
|
||||
@@ -1046,20 +1331,40 @@ async def list_pull_requests(
|
||||
"""
|
||||
try:
|
||||
if error := _validate_id(project_id, "project_id"):
|
||||
return {"success": False, "error": error, "code": ErrorCode.INVALID_REQUEST.value}
|
||||
return {
|
||||
"success": False,
|
||||
"error": error,
|
||||
"code": ErrorCode.INVALID_REQUEST.value,
|
||||
}
|
||||
if error := _validate_id(agent_id, "agent_id"):
|
||||
return {"success": False, "error": error, "code": ErrorCode.INVALID_REQUEST.value}
|
||||
return {
|
||||
"success": False,
|
||||
"error": error,
|
||||
"code": ErrorCode.INVALID_REQUEST.value,
|
||||
}
|
||||
|
||||
workspace = await _workspace_manager.get_workspace(project_id) # type: ignore[union-attr]
|
||||
if not workspace:
|
||||
return {"success": False, "error": f"Workspace not found for project: {project_id}", "code": ErrorCode.WORKSPACE_NOT_FOUND.value}
|
||||
return {
|
||||
"success": False,
|
||||
"error": f"Workspace not found for project: {project_id}",
|
||||
"code": ErrorCode.WORKSPACE_NOT_FOUND.value,
|
||||
}
|
||||
|
||||
if not workspace.repo_url:
|
||||
return {"success": False, "error": "Workspace has no repository URL", "code": ErrorCode.INVALID_REQUEST.value}
|
||||
return {
|
||||
"success": False,
|
||||
"error": "Workspace has no repository URL",
|
||||
"code": ErrorCode.INVALID_REQUEST.value,
|
||||
}
|
||||
|
||||
provider = _get_provider_for_url(workspace.repo_url)
|
||||
if not provider:
|
||||
return {"success": False, "error": "No provider configured for this repository", "code": ErrorCode.PROVIDER_NOT_FOUND.value}
|
||||
return {
|
||||
"success": False,
|
||||
"error": "No provider configured for this repository",
|
||||
"code": ErrorCode.PROVIDER_NOT_FOUND.value,
|
||||
}
|
||||
|
||||
# Parse state
|
||||
pr_state = None
|
||||
@@ -1067,10 +1372,16 @@ async def list_pull_requests(
|
||||
try:
|
||||
pr_state = PRState(state.lower())
|
||||
except ValueError:
|
||||
return {"success": False, "error": f"Invalid state: {state}. Valid: open, closed, merged", "code": ErrorCode.INVALID_REQUEST.value}
|
||||
return {
|
||||
"success": False,
|
||||
"error": f"Invalid state: {state}. Valid: open, closed, merged",
|
||||
"code": ErrorCode.INVALID_REQUEST.value,
|
||||
}
|
||||
|
||||
owner, repo = provider.parse_repo_url(workspace.repo_url)
|
||||
result = await provider.list_prs(owner, repo, state=pr_state, author=author, limit=limit)
|
||||
result = await provider.list_prs(
|
||||
owner, repo, state=pr_state, author=author, limit=limit
|
||||
)
|
||||
|
||||
return {
|
||||
"success": result.success,
|
||||
@@ -1084,7 +1395,11 @@ async def list_pull_requests(
|
||||
return {"success": False, "error": e.message, "code": e.code.value}
|
||||
except Exception as e:
|
||||
logger.error(f"Unexpected list PRs error: {e}")
|
||||
return {"success": False, "error": str(e), "code": ErrorCode.INTERNAL_ERROR.value}
|
||||
return {
|
||||
"success": False,
|
||||
"error": str(e),
|
||||
"code": ErrorCode.INTERNAL_ERROR.value,
|
||||
}
|
||||
|
||||
|
||||
@mcp.tool()
|
||||
@@ -1092,39 +1407,71 @@ async def merge_pull_request(
|
||||
project_id: str = Field(..., description="Project ID for scoping"),
|
||||
agent_id: str = Field(..., description="Agent ID making the request"),
|
||||
pr_number: int = Field(..., description="Pull request number"),
|
||||
merge_strategy: str = Field(default="merge", description="Strategy: merge, squash, rebase"),
|
||||
commit_message: str | None = Field(default=None, description="Custom commit message"),
|
||||
delete_branch: bool = Field(default=True, description="Delete source branch after merge"),
|
||||
merge_strategy: str = Field(
|
||||
default="merge", description="Strategy: merge, squash, rebase"
|
||||
),
|
||||
commit_message: str | None = Field(
|
||||
default=None, description="Custom commit message"
|
||||
),
|
||||
delete_branch: bool = Field(
|
||||
default=True, description="Delete source branch after merge"
|
||||
),
|
||||
) -> dict[str, Any]:
|
||||
"""
|
||||
Merge a pull request.
|
||||
"""
|
||||
try:
|
||||
if error := _validate_id(project_id, "project_id"):
|
||||
return {"success": False, "error": error, "code": ErrorCode.INVALID_REQUEST.value}
|
||||
return {
|
||||
"success": False,
|
||||
"error": error,
|
||||
"code": ErrorCode.INVALID_REQUEST.value,
|
||||
}
|
||||
if error := _validate_id(agent_id, "agent_id"):
|
||||
return {"success": False, "error": error, "code": ErrorCode.INVALID_REQUEST.value}
|
||||
return {
|
||||
"success": False,
|
||||
"error": error,
|
||||
"code": ErrorCode.INVALID_REQUEST.value,
|
||||
}
|
||||
|
||||
workspace = await _workspace_manager.get_workspace(project_id) # type: ignore[union-attr]
|
||||
if not workspace:
|
||||
return {"success": False, "error": f"Workspace not found for project: {project_id}", "code": ErrorCode.WORKSPACE_NOT_FOUND.value}
|
||||
return {
|
||||
"success": False,
|
||||
"error": f"Workspace not found for project: {project_id}",
|
||||
"code": ErrorCode.WORKSPACE_NOT_FOUND.value,
|
||||
}
|
||||
|
||||
if not workspace.repo_url:
|
||||
return {"success": False, "error": "Workspace has no repository URL", "code": ErrorCode.INVALID_REQUEST.value}
|
||||
return {
|
||||
"success": False,
|
||||
"error": "Workspace has no repository URL",
|
||||
"code": ErrorCode.INVALID_REQUEST.value,
|
||||
}
|
||||
|
||||
provider = _get_provider_for_url(workspace.repo_url)
|
||||
if not provider:
|
||||
return {"success": False, "error": "No provider configured for this repository", "code": ErrorCode.PROVIDER_NOT_FOUND.value}
|
||||
return {
|
||||
"success": False,
|
||||
"error": "No provider configured for this repository",
|
||||
"code": ErrorCode.PROVIDER_NOT_FOUND.value,
|
||||
}
|
||||
|
||||
# Parse merge strategy
|
||||
try:
|
||||
strategy = MergeStrategy(merge_strategy.lower())
|
||||
except ValueError:
|
||||
return {"success": False, "error": f"Invalid strategy: {merge_strategy}. Valid: merge, squash, rebase", "code": ErrorCode.INVALID_REQUEST.value}
|
||||
return {
|
||||
"success": False,
|
||||
"error": f"Invalid strategy: {merge_strategy}. Valid: merge, squash, rebase",
|
||||
"code": ErrorCode.INVALID_REQUEST.value,
|
||||
}
|
||||
|
||||
owner, repo = provider.parse_repo_url(workspace.repo_url)
|
||||
result = await provider.merge_pr(
|
||||
owner, repo, pr_number,
|
||||
owner,
|
||||
repo,
|
||||
pr_number,
|
||||
merge_strategy=strategy,
|
||||
commit_message=commit_message,
|
||||
delete_branch=delete_branch,
|
||||
@@ -1142,7 +1489,11 @@ async def merge_pull_request(
|
||||
return {"success": False, "error": e.message, "code": e.code.value}
|
||||
except Exception as e:
|
||||
logger.error(f"Unexpected merge PR error: {e}")
|
||||
return {"success": False, "error": str(e), "code": ErrorCode.INTERNAL_ERROR.value}
|
||||
return {
|
||||
"success": False,
|
||||
"error": str(e),
|
||||
"code": ErrorCode.INTERNAL_ERROR.value,
|
||||
}
|
||||
|
||||
|
||||
# MCP Tools - Workspace Operations
|
||||
@@ -1158,13 +1509,25 @@ async def get_workspace(
|
||||
"""
|
||||
try:
|
||||
if error := _validate_id(project_id, "project_id"):
|
||||
return {"success": False, "error": error, "code": ErrorCode.INVALID_REQUEST.value}
|
||||
return {
|
||||
"success": False,
|
||||
"error": error,
|
||||
"code": ErrorCode.INVALID_REQUEST.value,
|
||||
}
|
||||
if error := _validate_id(agent_id, "agent_id"):
|
||||
return {"success": False, "error": error, "code": ErrorCode.INVALID_REQUEST.value}
|
||||
return {
|
||||
"success": False,
|
||||
"error": error,
|
||||
"code": ErrorCode.INVALID_REQUEST.value,
|
||||
}
|
||||
|
||||
workspace = await _workspace_manager.get_workspace(project_id) # type: ignore[union-attr]
|
||||
if not workspace:
|
||||
return {"success": False, "error": f"Workspace not found for project: {project_id}", "code": ErrorCode.WORKSPACE_NOT_FOUND.value}
|
||||
return {
|
||||
"success": False,
|
||||
"error": f"Workspace not found for project: {project_id}",
|
||||
"code": ErrorCode.WORKSPACE_NOT_FOUND.value,
|
||||
}
|
||||
|
||||
return {
|
||||
"success": True,
|
||||
@@ -1176,14 +1539,20 @@ async def get_workspace(
|
||||
return {"success": False, "error": e.message, "code": e.code.value}
|
||||
except Exception as e:
|
||||
logger.error(f"Unexpected get workspace error: {e}")
|
||||
return {"success": False, "error": str(e), "code": ErrorCode.INTERNAL_ERROR.value}
|
||||
return {
|
||||
"success": False,
|
||||
"error": str(e),
|
||||
"code": ErrorCode.INTERNAL_ERROR.value,
|
||||
}
|
||||
|
||||
|
||||
@mcp.tool()
|
||||
async def lock_workspace(
|
||||
project_id: str = Field(..., description="Project ID"),
|
||||
agent_id: str = Field(..., description="Agent ID requesting lock"),
|
||||
timeout: int = Field(default=300, ge=10, le=3600, description="Lock timeout in seconds"),
|
||||
timeout: int = Field(
|
||||
default=300, ge=10, le=3600, description="Lock timeout in seconds"
|
||||
),
|
||||
) -> dict[str, Any]:
|
||||
"""
|
||||
Acquire a lock on a workspace.
|
||||
@@ -1192,9 +1561,17 @@ async def lock_workspace(
|
||||
"""
|
||||
try:
|
||||
if error := _validate_id(project_id, "project_id"):
|
||||
return {"success": False, "error": error, "code": ErrorCode.INVALID_REQUEST.value}
|
||||
return {
|
||||
"success": False,
|
||||
"error": error,
|
||||
"code": ErrorCode.INVALID_REQUEST.value,
|
||||
}
|
||||
if error := _validate_id(agent_id, "agent_id"):
|
||||
return {"success": False, "error": error, "code": ErrorCode.INVALID_REQUEST.value}
|
||||
return {
|
||||
"success": False,
|
||||
"error": error,
|
||||
"code": ErrorCode.INVALID_REQUEST.value,
|
||||
}
|
||||
|
||||
success = await _workspace_manager.lock_workspace(project_id, agent_id, timeout) # type: ignore[union-attr]
|
||||
workspace = await _workspace_manager.get_workspace(project_id) # type: ignore[union-attr]
|
||||
@@ -1202,7 +1579,9 @@ async def lock_workspace(
|
||||
return {
|
||||
"success": success,
|
||||
"lock_holder": workspace.lock_holder if workspace else None,
|
||||
"lock_expires": workspace.lock_expires.isoformat() if workspace and workspace.lock_expires else None,
|
||||
"lock_expires": workspace.lock_expires.isoformat()
|
||||
if workspace and workspace.lock_expires
|
||||
else None,
|
||||
}
|
||||
|
||||
except GitOpsError as e:
|
||||
@@ -1210,7 +1589,11 @@ async def lock_workspace(
|
||||
return {"success": False, "error": e.message, "code": e.code.value}
|
||||
except Exception as e:
|
||||
logger.error(f"Unexpected lock workspace error: {e}")
|
||||
return {"success": False, "error": str(e), "code": ErrorCode.INTERNAL_ERROR.value}
|
||||
return {
|
||||
"success": False,
|
||||
"error": str(e),
|
||||
"code": ErrorCode.INTERNAL_ERROR.value,
|
||||
}
|
||||
|
||||
|
||||
@mcp.tool()
|
||||
@@ -1224,9 +1607,17 @@ async def unlock_workspace(
|
||||
"""
|
||||
try:
|
||||
if error := _validate_id(project_id, "project_id"):
|
||||
return {"success": False, "error": error, "code": ErrorCode.INVALID_REQUEST.value}
|
||||
return {
|
||||
"success": False,
|
||||
"error": error,
|
||||
"code": ErrorCode.INVALID_REQUEST.value,
|
||||
}
|
||||
if error := _validate_id(agent_id, "agent_id"):
|
||||
return {"success": False, "error": error, "code": ErrorCode.INVALID_REQUEST.value}
|
||||
return {
|
||||
"success": False,
|
||||
"error": error,
|
||||
"code": ErrorCode.INVALID_REQUEST.value,
|
||||
}
|
||||
|
||||
success = await _workspace_manager.unlock_workspace(project_id, agent_id, force) # type: ignore[union-attr]
|
||||
|
||||
@@ -1237,7 +1628,11 @@ async def unlock_workspace(
|
||||
return {"success": False, "error": e.message, "code": e.code.value}
|
||||
except Exception as e:
|
||||
logger.error(f"Unexpected unlock workspace error: {e}")
|
||||
return {"success": False, "error": str(e), "code": ErrorCode.INTERNAL_ERROR.value}
|
||||
return {
|
||||
"success": False,
|
||||
"error": str(e),
|
||||
"code": ErrorCode.INTERNAL_ERROR.value,
|
||||
}
|
||||
|
||||
|
||||
# Register all tools
|
||||
|
||||
@@ -13,7 +13,7 @@ from datetime import UTC, datetime, timedelta
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
|
||||
import aiofiles
|
||||
import aiofiles # type: ignore[import-untyped]
|
||||
from filelock import FileLock, Timeout
|
||||
|
||||
from config import Settings, get_settings
|
||||
@@ -54,10 +54,25 @@ class WorkspaceManager:
|
||||
self.base_path.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
def _get_workspace_path(self, project_id: str) -> Path:
|
||||
"""Get the path for a project workspace."""
|
||||
"""Get the path for a project workspace with path traversal protection."""
|
||||
# Sanitize project ID for filesystem
|
||||
safe_id = "".join(c if c.isalnum() or c in "-_" else "_" for c in project_id)
|
||||
return self.base_path / safe_id
|
||||
|
||||
# Reject reserved names
|
||||
reserved_names = {".", "..", "con", "prn", "aux", "nul"}
|
||||
if safe_id.lower() in reserved_names:
|
||||
raise ValueError(f"Invalid project ID: reserved name '{project_id}'")
|
||||
|
||||
# Construct path and verify it's within base_path (prevent path traversal)
|
||||
workspace_path = (self.base_path / safe_id).resolve()
|
||||
base_resolved = self.base_path.resolve()
|
||||
|
||||
if not workspace_path.is_relative_to(base_resolved):
|
||||
raise ValueError(
|
||||
f"Invalid project ID: path traversal detected '{project_id}'"
|
||||
)
|
||||
|
||||
return workspace_path
|
||||
|
||||
def _get_lock_path(self, project_id: str) -> Path:
|
||||
"""Get the lock file path for a workspace."""
|
||||
@@ -230,10 +245,7 @@ class WorkspaceManager:
|
||||
raise WorkspaceNotFoundError(project_id)
|
||||
|
||||
# Check if already locked by someone else
|
||||
if (
|
||||
workspace.state == WorkspaceState.LOCKED
|
||||
and workspace.lock_holder != holder
|
||||
):
|
||||
if workspace.state == WorkspaceState.LOCKED and workspace.lock_holder != holder:
|
||||
# Check if lock expired
|
||||
if workspace.lock_expires and workspace.lock_expires > datetime.now(UTC):
|
||||
raise WorkspaceLockedError(project_id, workspace.lock_holder)
|
||||
@@ -275,11 +287,7 @@ class WorkspaceManager:
|
||||
raise WorkspaceNotFoundError(project_id)
|
||||
|
||||
# Verify holder
|
||||
if (
|
||||
not force
|
||||
and workspace.lock_holder
|
||||
and workspace.lock_holder != holder
|
||||
):
|
||||
if not force and workspace.lock_holder and workspace.lock_holder != holder:
|
||||
raise WorkspaceLockedError(project_id, workspace.lock_holder)
|
||||
|
||||
# Clear lock
|
||||
@@ -341,7 +349,7 @@ class WorkspaceManager:
|
||||
return True
|
||||
|
||||
size_bytes = await self._calculate_size(workspace_path)
|
||||
size_gb = size_bytes / (1024 ** 3)
|
||||
size_gb = size_bytes / (1024**3)
|
||||
max_size_gb = self.settings.workspace_max_size_gb
|
||||
|
||||
if size_gb > max_size_gb:
|
||||
@@ -362,7 +370,7 @@ class WorkspaceManager:
|
||||
Returns:
|
||||
List of WorkspaceInfo
|
||||
"""
|
||||
workspaces = []
|
||||
workspaces: list[WorkspaceInfo] = []
|
||||
|
||||
if not self.base_path.exists():
|
||||
return workspaces
|
||||
@@ -532,9 +540,7 @@ class WorkspaceLock:
|
||||
self.holder,
|
||||
)
|
||||
except Exception as e:
|
||||
logger.warning(
|
||||
f"Failed to release lock for {self.project_id}: {e}"
|
||||
)
|
||||
logger.warning(f"Failed to release lock for {self.project_id}: {e}")
|
||||
|
||||
|
||||
class FileLockManager:
|
||||
|
||||
Reference in New Issue
Block a user