Files
AegisGitea-MCP/src/aegis_gitea_mcp/tools/arguments.py
T
Latte 2d7f12d0d0 feat: safe full-API coverage via classified gitea_request dispatch
Add a deterministic (method, path) read/write classifier with an explicit
render-only override table that can only downgrade provably side-effect-free
POSTs (markdown/markup) to reads, never the reverse — so a mutating call cannot
slip past the write-mode gate. Add a known-Gitea-prefix gate: gitea_request now
fails closed on any path whose top segment is not a recognized /api/v1 route
instead of passing unknown paths through. Expose raw_relative_segments for the
authorization layer.

Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com>
2026-06-27 11:09:30 +02:00

690 lines
23 KiB
Python

"""Pydantic argument models for MCP tools."""
from __future__ import annotations
import re
from typing import Annotated, Any, Literal
from urllib.parse import urlsplit
from pydantic import (
AfterValidator,
BaseModel,
BeforeValidator,
ConfigDict,
Field,
field_validator,
model_validator,
)
# Allowed shape of an owner/org/repo name: 1-100 characters from [A-Za-z0-9._-].
# Security decision: the bare names "." and ".." are excluded, because these
# values are interpolated into URL paths (``/repos/{owner}/{repo}/...``) where a
# dot segment would escape the intended repository path.
# The exclusion is spelled out as three alternatives instead of a look-ahead so
# the pattern stays usable by regex engines without look-around support.
_REPO_PART_PATTERN = (
    r"^(?:"
    r"[A-Za-z0-9_-][A-Za-z0-9._-]{0,99}"  # first character is not a dot
    r"|\.[A-Za-z0-9_-][A-Za-z0-9._-]{0,98}"  # one leading dot, then a non-dot
    r"|\.\.[A-Za-z0-9._-]{1,98}"  # two leading dots plus at least one more char
    r")$"
)
def _validate_git_ref(value: str) -> str:
    """Check that a ref-like string (ref/sha/base/head) is safe to embed in a path.

    Slashes inside a ref (``feature/foo``, ``release/1.0``) are allowed; the
    value is rejected only for traversal segments, a leading slash, backslashes,
    control characters, whitespace, or URL delimiters.

    Args:
        value: Candidate ref, sha, base, or head value.

    Returns:
        ``value`` unchanged when every rule passes.

    Raises:
        ValueError: On the first rule (in the order listed below) that fails.
    """
    # Security decision: block path traversal and absolute references.
    # Rules are evaluated in order; the first violated rule decides the message.
    violations = (
        (".." in value.split("/"), "ref must not contain '..' path segments"),
        (value.startswith("/"), "ref must not start with '/'"),
        ("\\" in value, "ref must not contain backslashes"),
        ("\x00" in value, "ref must not contain null bytes"),
        (
            any(ord(ch) < 0x20 or ord(ch) == 0x7F for ch in value),
            "ref must not contain control characters",
        ),
        (any(ch.isspace() for ch in value), "ref must not contain whitespace"),
        ("?" in value or "#" in value, "ref must not contain '?' or '#'"),
    )
    for violated, message in violations:
        if violated:
            raise ValueError(message)
    return value
# String type for ref/sha/base/head fields: after Pydantic's own ``str``
# validation, the value is additionally passed through ``_validate_git_ref``.
GitRef = Annotated[str, AfterValidator(_validate_git_ref)]
def _validate_milestone(value: object) -> int | str:
    """Normalize a milestone reference given as a numeric id or a title.

    Integers are accepted as milestone ids when non-negative (``0`` clears the
    milestone on update); strings are accepted as titles after stripping
    surrounding whitespace. The function is meant to run as a
    ``BeforeValidator`` so that ``bool`` — an ``int`` subclass Pydantic would
    otherwise coerce to ``1``/``0`` — is rejected on the raw input.

    Args:
        value: Raw, not-yet-coerced input.

    Returns:
        The id unchanged, or the stripped title.

    Raises:
        ValueError: For bools, other non-int/str types, negative ids, and
            titles that are empty or longer than 256 characters once stripped.
    """
    # bool must be tested before int because isinstance(True, int) is true.
    if isinstance(value, bool) or not isinstance(value, (int, str)):
        raise ValueError("milestone must be a milestone id or title")
    if isinstance(value, str):
        cleaned = value.strip()
        if not cleaned:
            raise ValueError("milestone title must not be empty")
        if len(cleaned) > 256:
            raise ValueError("milestone title must not exceed 256 characters")
        return cleaned
    if value < 0:
        raise ValueError("milestone id must be >= 0")
    return value
# Milestone id-or-title type: ``_validate_milestone`` runs on the raw input
# before Pydantic coerces it to ``int | str``.
MilestoneRef = Annotated[int | str, BeforeValidator(_validate_milestone)]
class StrictBaseModel(BaseModel):
    """Shared base for tool argument models; undeclared fields are forbidden."""

    # extra="forbid" makes unknown keys a validation error instead of being ignored.
    model_config = ConfigDict(extra="forbid")
class ListRepositoriesArgs(StrictBaseModel):
    """Input schema for the list_repositories tool; it declares no fields."""
class RepositoryArgs(StrictBaseModel):
    """Repository locator shared by every repo-scoped tool.

    Attributes:
        owner: Required; must match ``_REPO_PART_PATTERN``.
        repo: Required; must match ``_REPO_PART_PATTERN``.
    """

    owner: str = Field(..., pattern=_REPO_PART_PATTERN)
    repo: str = Field(..., pattern=_REPO_PART_PATTERN)
class FileTreeArgs(RepositoryArgs):
    """Input schema for the get_file_tree tool.

    Attributes:
        ref: Validated git ref, 1-200 characters, default ``"main"``.
        recursive: Boolean flag, default ``False``.
    """

    ref: GitRef = Field(default="main", min_length=1, max_length=200)
    recursive: bool = Field(default=False)
class FileContentsArgs(RepositoryArgs):
    """Input schema for the get_file_contents tool.

    Attributes:
        filepath: Required path, 1-1024 characters; must be relative and free
            of ``..`` segments and null bytes.
        ref: Validated git ref, 1-200 characters, default ``"main"``.
    """

    filepath: str = Field(..., min_length=1, max_length=1024)
    ref: GitRef = Field(default="main", min_length=1, max_length=200)

    @model_validator(mode="after")
    def validate_filepath(self) -> FileContentsArgs:
        """Reject absolute paths, traversal segments, and null bytes.

        Raises:
            ValueError: When ``filepath`` violates one of the constraints.
        """
        # Backslashes are treated as separators for the purpose of the checks.
        posix_style = self.filepath.replace("\\", "/")
        is_absolute = posix_style.startswith("/")
        has_traversal = ".." in posix_style.split("/")
        # Security decision: block traversal and absolute paths.
        if is_absolute or has_traversal:
            raise ValueError("filepath must be a relative path without traversal")
        if "\x00" in posix_style:
            raise ValueError("filepath cannot contain null bytes")
        return self
class SearchCodeArgs(RepositoryArgs):
    """Input schema for the search_code tool.

    Attributes:
        query: Required search text, 1-256 characters.
        ref: Validated git ref, 1-200 characters, default ``"main"``.
        page: Page number, 1-10000, default 1.
        limit: Page size, 1-100, default 25.
    """

    query: str = Field(..., min_length=1, max_length=256)
    ref: GitRef = Field(default="main", min_length=1, max_length=200)
    page: int = Field(default=1, ge=1, le=10_000)
    limit: int = Field(default=25, ge=1, le=100)
class ListCommitsArgs(RepositoryArgs):
    """Input schema for the list_commits tool.

    Attributes:
        ref: Validated git ref, 1-200 characters, default ``"main"``.
        page: Page number, 1-10000, default 1.
        limit: Page size, 1-100, default 25.
    """

    ref: GitRef = Field(default="main", min_length=1, max_length=200)
    page: int = Field(default=1, ge=1, le=10_000)
    limit: int = Field(default=25, ge=1, le=100)
class CommitDiffArgs(RepositoryArgs):
    """Input schema for the get_commit_diff tool.

    Attributes:
        sha: Required, validated like a git ref, 7-64 characters.
    """

    sha: GitRef = Field(..., min_length=7, max_length=64)
class CompareRefsArgs(RepositoryArgs):
    """Input schema for the compare_refs tool.

    Attributes:
        base: Required validated git ref, 1-200 characters.
        head: Required validated git ref, 1-200 characters.
    """

    base: GitRef = Field(..., min_length=1, max_length=200)
    head: GitRef = Field(..., min_length=1, max_length=200)
class ListIssuesArgs(RepositoryArgs):
    """Input schema for the list_issues tool.

    Attributes:
        state: One of ``open``/``closed``/``all``; default ``open``.
        page: Page number, 1-10000, default 1.
        limit: Page size, 1-100, default 25.
        labels: At most 20 label strings; defaults to an empty list.
    """

    state: Literal["open", "closed", "all"] = Field(default="open")
    page: int = Field(default=1, ge=1, le=10_000)
    limit: int = Field(default=25, ge=1, le=100)
    labels: list[str] = Field(default_factory=list, max_length=20)
class IssueArgs(RepositoryArgs):
    """Input schema for the get_issue tool.

    Attributes:
        issue_number: Required integer >= 1.
    """

    issue_number: int = Field(..., ge=1)
class ListPullRequestsArgs(RepositoryArgs):
    """Input schema for the list_pull_requests tool.

    Attributes:
        state: One of ``open``/``closed``/``all``; default ``open``.
        page: Page number, 1-10000, default 1.
        limit: Page size, 1-100, default 25.
    """

    state: Literal["open", "closed", "all"] = Field(default="open")
    page: int = Field(default=1, ge=1, le=10_000)
    limit: int = Field(default=25, ge=1, le=100)
class PullRequestArgs(RepositoryArgs):
    """Input schema for the get_pull_request tool.

    Attributes:
        pull_number: Required integer >= 1.
    """

    pull_number: int = Field(..., ge=1)
class ListLabelsArgs(RepositoryArgs):
    """Input schema for the list_labels tool.

    Attributes:
        page: Page number, 1-10000, default 1.
        limit: Page size, 1-100, default 50.
    """

    page: int = Field(default=1, ge=1, le=10_000)
    limit: int = Field(default=50, ge=1, le=100)
class ListTagsArgs(RepositoryArgs):
    """Input schema for the list_tags tool.

    Attributes:
        page: Page number, 1-10000, default 1.
        limit: Page size, 1-100, default 50.
    """

    page: int = Field(default=1, ge=1, le=10_000)
    limit: int = Field(default=50, ge=1, le=100)
class ListReleasesArgs(RepositoryArgs):
    """Input schema for the list_releases tool.

    Attributes:
        page: Page number, 1-10000, default 1.
        limit: Page size, 1-100, default 25.
    """

    page: int = Field(default=1, ge=1, le=10_000)
    limit: int = Field(default=25, ge=1, le=100)
class CreateIssueArgs(RepositoryArgs):
    """Input schema for the create_issue tool.

    Attributes:
        title: Required, 1-256 characters.
        body: Up to 20000 characters; default empty string.
        labels: At most 20 label strings; defaults to an empty list.
        assignees: At most 20 strings; defaults to an empty list.
        milestone: Optional milestone id or title; default ``None``.
    """

    title: str = Field(..., min_length=1, max_length=256)
    body: str = Field(default="", max_length=20_000)
    labels: list[str] = Field(default_factory=list, max_length=20)
    assignees: list[str] = Field(default_factory=list, max_length=20)
    milestone: MilestoneRef | None = Field(
        default=None,
        description="Milestone id or title to assign the issue to",
    )
class UpdateIssueArgs(RepositoryArgs):
    """Input schema for the update_issue tool.

    Attributes:
        issue_number: Required integer >= 1.
        title: Optional new title, 1-256 characters.
        body: Optional new body, up to 20000 characters.
        state: Optional ``open`` or ``closed``.
        milestone: Optional milestone id or title; ``0`` clears the milestone.
    """

    issue_number: int = Field(..., ge=1)
    title: str | None = Field(default=None, min_length=1, max_length=256)
    body: str | None = Field(default=None, max_length=20_000)
    state: Literal["open", "closed"] | None = Field(default=None)
    milestone: MilestoneRef | None = Field(
        default=None,
        description="Milestone id or title to assign; 0 clears the milestone",
    )

    @model_validator(mode="after")
    def require_change(self) -> UpdateIssueArgs:
        """Ensure the update carries at least one mutable field.

        Raises:
            ValueError: When title, body, state, and milestone are all ``None``.
        """
        # Compared against None (not truthiness) so milestone=0 counts as a change.
        candidates = (self.title, self.body, self.state, self.milestone)
        if all(value is None for value in candidates):
            raise ValueError("At least one of title, body, state, or milestone must be provided")
        return self
class CreateIssueCommentArgs(RepositoryArgs):
    """Input schema for the create_issue_comment tool.

    Attributes:
        issue_number: Required integer >= 1.
        body: Required comment text, 1-10000 characters.
    """

    issue_number: int = Field(..., ge=1)
    body: str = Field(..., min_length=1, max_length=10_000)
class CreatePrCommentArgs(RepositoryArgs):
    """Input schema for the create_pr_comment tool.

    Attributes:
        pull_number: Required integer >= 1.
        body: Required comment text, 1-10000 characters.
    """

    pull_number: int = Field(..., ge=1)
    body: str = Field(..., min_length=1, max_length=10_000)
class AddLabelsArgs(RepositoryArgs):
    """Input schema for the add_labels tool.

    Attributes:
        issue_number: Required integer >= 1.
        labels: Required list of 1-20 label strings.
    """

    issue_number: int = Field(..., ge=1)
    labels: list[str] = Field(..., min_length=1, max_length=20)
class AssignIssueArgs(RepositoryArgs):
    """Input schema for the assign_issue tool.

    Attributes:
        issue_number: Required integer >= 1.
        assignees: Required list of 1-20 strings.
    """

    issue_number: int = Field(..., ge=1)
    assignees: list[str] = Field(..., min_length=1, max_length=20)
class CreateLabelArgs(RepositoryArgs):
    """Input schema for the create_label tool.

    Attributes:
        name: Required, 1-50 characters.
        color: Required six-digit hex color, with or without a leading ``#``.
        description: Up to 1000 characters; default empty string.
        exclusive: Boolean flag, default ``False``.
    """

    name: str = Field(..., min_length=1, max_length=50)
    # Gitea requires a hex color; accept it with or without a leading '#'.
    color: str = Field(..., pattern=r"^#?[0-9A-Fa-f]{6}$")
    description: str = Field(default="", max_length=1000)
    exclusive: bool = Field(default=False)
class UpdateLabelArgs(RepositoryArgs):
    """Input schema for the update_label tool (label located by current name).

    Attributes:
        name: Required current label name, 1-50 characters.
        new_name: Optional replacement name, 1-50 characters.
        color: Optional six-digit hex color, with or without a leading ``#``.
        description: Optional description, up to 1000 characters.
    """

    name: str = Field(..., min_length=1, max_length=50)
    new_name: str | None = Field(default=None, min_length=1, max_length=50)
    color: str | None = Field(default=None, pattern=r"^#?[0-9A-Fa-f]{6}$")
    description: str | None = Field(default=None, max_length=1000)

    @model_validator(mode="after")
    def require_change(self) -> UpdateLabelArgs:
        """Ensure the update carries at least one mutable field.

        Raises:
            ValueError: When new_name, color, and description are all ``None``.
        """
        candidates = (self.new_name, self.color, self.description)
        if all(value is None for value in candidates):
            raise ValueError("At least one of new_name, color, or description must be provided")
        return self
class RemoveLabelsArgs(RepositoryArgs):
    """Input schema for the remove_labels tool.

    Attributes:
        issue_number: Required integer >= 1.
        labels: Required list of 1-20 label strings.
    """

    issue_number: int = Field(..., ge=1)
    labels: list[str] = Field(..., min_length=1, max_length=20)
class CreatePullRequestArgs(RepositoryArgs):
    """Input schema for the create_pull_request tool.

    Attributes:
        title: Required, 1-256 characters.
        head: Required validated git ref, 1-200 characters.
        base: Required validated git ref, 1-200 characters.
        body: Up to 20000 characters; default empty string.
    """

    title: str = Field(..., min_length=1, max_length=256)
    head: GitRef = Field(..., min_length=1, max_length=200)
    base: GitRef = Field(..., min_length=1, max_length=200)
    body: str = Field(default="", max_length=20_000)
class CreateReleaseArgs(RepositoryArgs):
    """Input schema for the create_release tool.

    Attributes:
        tag_name: Required validated git ref, 1-200 characters.
        name: Up to 256 characters; default empty string.
        body: Up to 20000 characters; default empty string.
        draft: Boolean flag, default ``False``.
        prerelease: Boolean flag, default ``False``.
        target: Optional plain string, 1-200 characters (not ``GitRef``-validated).
    """

    tag_name: GitRef = Field(..., min_length=1, max_length=200)
    name: str = Field(default="", max_length=256)
    body: str = Field(default="", max_length=20_000)
    draft: bool = Field(default=False)
    prerelease: bool = Field(default=False)
    target: str | None = Field(default=None, min_length=1, max_length=200)
class EditReleaseArgs(RepositoryArgs):
    """Input schema for the edit_release tool.

    Attributes:
        release_id: Required integer >= 1.
        name: Optional new name, up to 256 characters.
        body: Optional new body, up to 20000 characters.
        draft: Optional boolean.
        prerelease: Optional boolean.
    """

    release_id: int = Field(..., ge=1)
    name: str | None = Field(default=None, max_length=256)
    body: str | None = Field(default=None, max_length=20_000)
    draft: bool | None = Field(default=None)
    prerelease: bool | None = Field(default=None)

    @model_validator(mode="after")
    def require_change(self) -> EditReleaseArgs:
        """Ensure the update carries at least one mutable field.

        Raises:
            ValueError: When name, body, draft, and prerelease are all ``None``.
        """
        # Compared against None (not truthiness) so draft=False counts as a change.
        candidates = (self.name, self.body, self.draft, self.prerelease)
        if all(value is None for value in candidates):
            raise ValueError("At least one of name, body, draft, or prerelease must be provided")
        return self
class CreateBranchArgs(RepositoryArgs):
    """Input schema for the create_branch tool.

    Attributes:
        new_branch_name: Required validated git ref, 1-200 characters.
        old_branch_name: Optional plain string, 1-200 characters
            (not ``GitRef``-validated).
    """

    new_branch_name: GitRef = Field(..., min_length=1, max_length=200)
    old_branch_name: str | None = Field(default=None, min_length=1, max_length=200)
class CreateMilestoneArgs(RepositoryArgs):
    """Input schema for the create_milestone tool.

    Attributes:
        title: Required, 1-256 characters.
        description: Up to 10000 characters; default empty string.
        due_on: Optional string of at most 64 characters; its format is not
            validated here.
    """

    title: str = Field(..., min_length=1, max_length=256)
    description: str = Field(default="", max_length=10_000)
    due_on: str | None = Field(default=None, max_length=64)
class EditIssueCommentArgs(RepositoryArgs):
    """Input schema for the edit_issue_comment tool.

    Attributes:
        comment_id: Required integer >= 1.
        body: Required replacement text, 1-10000 characters.
    """

    comment_id: int = Field(..., ge=1)
    body: str = Field(..., min_length=1, max_length=10_000)
class ListPullRequestFilesArgs(RepositoryArgs):
    """Input schema for the list_pull_request_files tool.

    Attributes:
        pull_number: Required integer >= 1.
        page: Page number, 1-10000, default 1.
        limit: Page size, 1-100, default 50.
    """

    pull_number: int = Field(..., ge=1)
    page: int = Field(default=1, ge=1, le=10_000)
    limit: int = Field(default=50, ge=1, le=100)
class ListPullRequestCommitsArgs(RepositoryArgs):
    """Input schema for the list_pull_request_commits tool.

    Attributes:
        pull_number: Required integer >= 1.
        page: Page number, 1-10000, default 1.
        limit: Page size, 1-100, default 50.
    """

    pull_number: int = Field(..., ge=1)
    page: int = Field(default=1, ge=1, le=10_000)
    limit: int = Field(default=50, ge=1, le=100)
class ListIssueCommentsArgs(RepositoryArgs):
    """Input schema for the list_issue_comments tool.

    Attributes:
        issue_number: Required integer >= 1.
        page: Page number, 1-10000, default 1.
        limit: Page size, 1-100, default 50.
    """

    issue_number: int = Field(..., ge=1)
    page: int = Field(default=1, ge=1, le=10_000)
    limit: int = Field(default=50, ge=1, le=100)
class ListBranchesArgs(RepositoryArgs):
    """Input schema for the list_branches tool.

    Attributes:
        page: Page number, 1-10000, default 1.
        limit: Page size, 1-100, default 50.
    """

    page: int = Field(default=1, ge=1, le=10_000)
    limit: int = Field(default=50, ge=1, le=100)
class GetBranchArgs(RepositoryArgs):
    """Input schema for the get_branch tool.

    Attributes:
        branch: Required validated git ref, 1-200 characters.
    """

    branch: GitRef = Field(..., min_length=1, max_length=200)
class GetReleaseArgs(RepositoryArgs):
    """Input schema for the get_release tool.

    Attributes:
        release_id: Required integer >= 1.
    """

    release_id: int = Field(..., ge=1)
class LatestReleaseArgs(RepositoryArgs):
    """Input schema for the get_latest_release tool; only owner/repo are needed."""
class ListMilestonesArgs(RepositoryArgs):
    """Input schema for the list_milestones tool.

    Attributes:
        state: One of ``open``/``closed``/``all``; default ``open``.
        page: Page number, 1-10000, default 1.
        limit: Page size, 1-100, default 50.
    """

    state: Literal["open", "closed", "all"] = Field(default="open")
    page: int = Field(default=1, ge=1, le=10_000)
    limit: int = Field(default=50, ge=1, le=100)
class CommitStatusArgs(RepositoryArgs):
    """Input schema for the get_commit_status tool.

    Attributes:
        sha: Required, validated like a git ref, 1-64 characters.
    """

    sha: GitRef = Field(..., min_length=1, max_length=64)
class ListOrgRepositoriesArgs(StrictBaseModel):
    """Input schema for the list_org_repositories tool.

    Attributes:
        org: Required; must match ``_REPO_PART_PATTERN``.
        page: Page number, 1-10000, default 1.
        limit: Page size, 1-100, default 50.
    """

    org: str = Field(..., pattern=_REPO_PART_PATTERN)
    page: int = Field(default=1, ge=1, le=10_000)
    limit: int = Field(default=50, ge=1, le=100)
class ListOrganizationsArgs(StrictBaseModel):
    """Input schema for the list_organizations tool.

    Attributes:
        page: Page number, 1-10000, default 1.
        limit: Page size, 1-100, default 50.
    """

    page: int = Field(default=1, ge=1, le=10_000)
    limit: int = Field(default=50, ge=1, le=100)
class RepoLanguagesArgs(RepositoryArgs):
    """Input schema for the get_repo_languages tool; only owner/repo are needed."""
class RepoTopicsArgs(RepositoryArgs):
    """Input schema for the list_repo_topics tool; only owner/repo are needed."""
# --- Raw API dispatch (gitea_request escape hatch) -------------------------
# HTTP methods the generic dispatch tool accepts. Everything outside GET/HEAD is
# treated as a write so the policy/write-mode gate applies.
RAW_API_METHODS = ("GET", "HEAD", "POST", "PUT", "PATCH", "DELETE")
# Subset of RAW_API_METHODS that raw_method_is_write() reports as mutating.
_RAW_WRITE_METHODS = frozenset({"POST", "PUT", "PATCH", "DELETE"})
# Path segments/subpaths blocked for *every* method unless explicitly overridden
# via RAW_API_ALLOW_SENSITIVE. A GET on these already leaks credentials or
# privileged configuration, so they are denied independently of policy.yaml.
# Segments are compared for exact equality against each path segment; subpaths
# are matched as substrings of the joined relative path (see raw_is_sensitive).
_RAW_SENSITIVE_SEGMENTS = frozenset({"admin", "tokens", "secrets", "hooks", "keys", "gpg_keys"})
_RAW_SENSITIVE_SUBPATHS = ("applications/oauth2", "actions/runners/registration-token")
# Endpoints under /repos/ that are not scoped to a single repository.
# A value here in the owner position makes _raw_repo_segments() return None.
_RAW_CROSS_REPO_OWNERS = frozenset({"search", "issues"})
# Resources whose trailing segments form a file path target for policy checks.
_RAW_FILE_RESOURCES = frozenset({"contents", "raw", "media"})
# Known top-level segments of the Gitea ``/api/v1`` surface. A raw request whose
# first path segment is not in this set is rejected (fail closed): we never pass
# an unrecognized path straight through to Gitea.
KNOWN_API_PREFIXES = frozenset(
    {
        "activitypub",
        "admin",
        "gitignore",
        "issues",
        "label",
        "licenses",
        "markdown",
        "markup",
        "miscellaneous",
        "nodeinfo",
        "notifications",
        "org",
        "orgs",
        "packages",
        "repos",
        "repositories",
        "settings",
        "signing-key.gpg",
        "teams",
        "topics",
        "user",
        "users",
        "version",
    }
)
# Override table: provably side-effect-free POSTs that may be treated as reads so
# they do not needlessly require WRITE_MODE. This table may ONLY ever DOWNGRADE a
# write to a read for endpoints that render content and mutate nothing — never
# the reverse. Keyed by the final path segment of the endpoint.
# NOTE(review): a leaf-only key also matches non-render endpoints whose last
# segment happens to be one of these words (e.g. a POST to
# /repos/{owner}/{repo}/contents/<dir>/raw) — confirm every consumer of this
# table also constrains the full path before treating the request as a read.
_RAW_READ_ONLY_POST_LEAVES = frozenset({"markdown", "markup", "raw"})
def raw_is_known_api_path(endpoint: str) -> bool:
    """Tell whether the endpoint starts with a recognized Gitea API prefix.

    Args:
        endpoint: Endpoint path, with or without the ``/api/v1`` prefix.

    Returns:
        True when the first segment after ``/api/v1`` is in
        ``KNOWN_API_PREFIXES``; False otherwise (including an empty path).
    """
    top_segment = raw_top_segment(endpoint)
    return top_segment in KNOWN_API_PREFIXES
# Exact relative paths (segments after ``/api/v1``) of the render-only POST
# endpoints that may be classified as reads. Matching the FULL path rather than
# only the last segment is deliberate: a leaf-only match would also downgrade
# mutating calls whose final segment merely happens to be one of these words,
# e.g. ``POST /repos/{owner}/{repo}/contents/docs/raw`` (which creates a file).
_RAW_READ_ONLY_POST_PATHS = frozenset(
    {
        ("markdown",),
        ("markdown", "raw"),
        ("markup",),
    }
)


def raw_request_is_write(method: str, endpoint: str) -> bool:
    """Classify a raw request as read or write from its method and path.

    ``GET``/``HEAD`` are reads; every other method is a write — except for the
    small, explicit set of render-only POST endpoints (``/markdown``,
    ``/markdown/raw``, ``/markup``), which are reads. The override only applies
    when the whole relative path equals one of those endpoints, so a mutating
    call cannot be reclassified as a read by ending its path in a matching word.

    Args:
        method: HTTP method; compared case-insensitively.
        endpoint: Endpoint path, with or without the ``/api/v1`` prefix.

    Returns:
        False for reads, True for anything that must pass the write-mode gate.
    """
    upper = method.upper()
    if upper in {"GET", "HEAD"}:
        return False
    if upper == "POST":
        rel = tuple(_raw_relative_segments(endpoint))
        if rel in _RAW_READ_ONLY_POST_PATHS:
            return False
    # Fail closed: anything not positively identified as a read is a write.
    return True
def normalize_raw_endpoint(path: str) -> str:
    """Normalize a raw API path into an ``/api/v1``-prefixed endpoint.

    Accepts a bare path (``/repos/o/r``), an already-prefixed path
    (``/api/v1/repos/o/r``), or a full URL (the scheme/host and any query string
    are stripped — the separate ``query`` argument carries query parameters).
    Empty and ``.`` segments are dropped and backslashes are treated as ``/``.

    Args:
        path: Caller-supplied path or URL.

    Returns:
        The endpoint as ``/api/v1/<segments>``, or ``/api/v1`` when no segments
        remain. Segments are returned as supplied (not percent-decoded).

    Raises:
        ValueError: When the path contains a ``..`` traversal segment, either
            literally or percent-encoded (e.g. ``%2e%2e``, ``..%2f..``).
    """
    # Local import keeps this block self-contained; the module only imports urlsplit.
    from urllib.parse import unquote

    candidate = path.strip()
    split = urlsplit(candidate)
    # When a full URL is supplied, keep only its path component.
    raw_path = split.path if (split.scheme or split.netloc) else candidate
    # Drop any query/fragment a caller may have inlined into the path string.
    raw_path = raw_path.split("?", 1)[0].split("#", 1)[0]
    raw_path = raw_path.replace("\\", "/")
    segments = [seg for seg in raw_path.split("/") if seg and seg != "."]
    for seg in segments:
        # Security decision: the server percent-decodes the path before routing,
        # so the traversal check must look at the decoded form of each segment,
        # including separators that were hidden as %2f / %5c.
        decoded = unquote(seg).replace("\\", "/")
        if ".." in decoded.split("/"):
            raise ValueError("path must not contain '..' traversal segments")
    rel_segments = segments[2:] if segments[:2] == ["api", "v1"] else segments
    if not rel_segments:
        return "/api/v1"
    return "/api/v1/" + "/".join(rel_segments)
def _raw_relative_segments(endpoint: str) -> list[str]:
    """Split an endpoint into its non-empty segments, minus a leading ``api/v1``.

    Args:
        endpoint: Endpoint path, with or without the ``/api/v1`` prefix.

    Returns:
        The remaining path segments; empty when nothing follows the prefix.
    """
    parts = list(filter(None, endpoint.split("/")))
    if parts[:2] == ["api", "v1"]:
        return parts[2:]
    return parts
def raw_relative_segments(endpoint: str) -> list[str]:
    """Public accessor for the endpoint segments that follow ``/api/v1``.

    Args:
        endpoint: Endpoint path, with or without the ``/api/v1`` prefix.

    Returns:
        The same list the private ``_raw_relative_segments`` helper produces.
    """
    relative = _raw_relative_segments(endpoint)
    return relative
def raw_top_segment(endpoint: str) -> str:
    """Give the first segment after ``/api/v1``, used for coarse policy grouping.

    Args:
        endpoint: Endpoint path, with or without the ``/api/v1`` prefix.

    Returns:
        The leading relative segment, or ``""`` when the path has none.
    """
    relative = _raw_relative_segments(endpoint)
    if not relative:
        return ""
    return relative[0]
def raw_method_is_write(method: str) -> bool:
    """Tell whether an HTTP method is one of the mutating methods.

    Args:
        method: HTTP method; compared case-insensitively.

    Returns:
        True when the upper-cased method is in ``_RAW_WRITE_METHODS``.
    """
    normalized = method.upper()
    return normalized in _RAW_WRITE_METHODS
def raw_is_sensitive(endpoint: str) -> bool:
    """Return whether an endpoint touches an admin/credential surface.

    The relative path is percent-decoded before matching, so an encoded
    spelling of a blocked segment (``%74okens`` for ``tokens``) or an encoded
    separator (``user%2Ftokens``) is classified the same as its plain form.

    Args:
        endpoint: Endpoint path, with or without the ``/api/v1`` prefix.

    Returns:
        True when any decoded segment is in ``_RAW_SENSITIVE_SEGMENTS`` or the
        decoded relative path contains one of ``_RAW_SENSITIVE_SUBPATHS``.
    """
    # Local import keeps this block self-contained; the module only imports urlsplit.
    from urllib.parse import unquote

    # Security decision: match on what the server will route, not on the raw
    # spelling. Decode first, then re-split so encoded slashes become separators.
    decoded_path = unquote("/".join(_raw_relative_segments(endpoint)))
    rel = [seg for seg in decoded_path.split("/") if seg]
    if any(seg in _RAW_SENSITIVE_SEGMENTS for seg in rel):
        return True
    joined = "/".join(rel)
    return any(sub in joined for sub in _RAW_SENSITIVE_SUBPATHS)
def _raw_repo_segments(endpoint: str) -> list[str] | None:
    """Return ``[owner, repo, *rest]`` for a single-repository endpoint, else None.

    Owner and repo are percent-decoded before validation so that an encoded
    spelling (``/repos/%6fwner/repo``) resolves to the same repository the
    server will route to, instead of making the request look repo-less.

    Args:
        endpoint: Endpoint path, with or without the ``/api/v1`` prefix.

    Returns:
        The decoded owner and repo followed by the remaining segments as
        supplied, or None when the path is not under ``/repos/<owner>/<repo>``,
        targets a cross-repository endpoint, or has an invalid owner/repo name.
    """
    # Local import keeps this block self-contained; the module only imports urlsplit.
    from urllib.parse import unquote

    rel = _raw_relative_segments(endpoint)
    if len(rel) < 3 or rel[0] != "repos":
        return None
    owner, repo = unquote(rel[1]), unquote(rel[2])
    if owner in _RAW_CROSS_REPO_OWNERS:
        return None
    # fullmatch (not match): with match, the pattern's trailing ``$`` would also
    # accept a name that ends in a newline.
    if not (re.fullmatch(_REPO_PART_PATTERN, owner) and re.fullmatch(_REPO_PART_PATTERN, repo)):
        return None
    return [owner, repo, *rel[3:]]
def parse_raw_repository(endpoint: str) -> str | None:
    """Derive ``owner/repo`` from a repository-scoped endpoint.

    Args:
        endpoint: Endpoint path, with or without the ``/api/v1`` prefix.

    Returns:
        ``"owner/repo"``, or None when ``_raw_repo_segments`` does not recognize
        the endpoint as scoped to a single repository.
    """
    parts = _raw_repo_segments(endpoint)
    if parts is not None:
        owner, repo = parts[0], parts[1]
        return f"{owner}/{repo}"
    return None
def parse_raw_target_path(endpoint: str) -> str | None:
    """Parse a file-path target from ``contents``/``raw``/``media`` endpoints.

    Both the resource segment and the file path are percent-decoded, so path
    policy rules are evaluated against the real file name rather than an
    encoded spelling of it (``%2Eenv`` -> ``.env``).

    Args:
        endpoint: Endpoint path, with or without the ``/api/v1`` prefix.

    Returns:
        The decoded file path that follows the resource segment, or None when
        the endpoint is not repository-scoped, is not one of the file
        resources, or carries no path after the resource.
    """
    # Local import keeps this block self-contained; the module only imports urlsplit.
    from urllib.parse import unquote

    repo_segments = _raw_repo_segments(endpoint)
    if repo_segments is None or len(repo_segments) < 4:
        return None
    # Security decision: decode before comparing so "%63ontents" cannot dodge
    # the file-resource check and skip the path policy altogether.
    if unquote(repo_segments[2]) not in _RAW_FILE_RESOURCES:
        return None
    file_path = unquote("/".join(repo_segments[3:]))
    return file_path or None
class RawApiRequestArgs(StrictBaseModel):
    """Input schema for the generic ``gitea_request`` escape-hatch tool.

    Attributes:
        method: Required HTTP method; one of GET/HEAD/POST/PUT/PATCH/DELETE
            (string input is stripped and upper-cased before validation).
        path: Required Gitea REST path, 1-2048 characters, free of ``..``
            traversal segments.
        query: Optional mapping of query-string parameters.
        body: Optional mapping used as the JSON request body.
    """

    method: Literal["GET", "HEAD", "POST", "PUT", "PATCH", "DELETE"] = Field(
        ...,
        description="HTTP method",
    )
    path: str = Field(..., min_length=1, max_length=2048, description="Gitea REST path")
    query: dict[str, Any] | None = Field(
        default=None,
        description="Optional query-string parameters",
    )
    body: dict[str, Any] | None = Field(default=None, description="Optional JSON request body")

    @field_validator("method", mode="before")
    @classmethod
    def _normalize_method(cls, value: object) -> object:
        """Strip and upper-case string methods so ``'get'`` passes the Literal check."""
        # Non-strings are passed through for the Literal validation to reject.
        if not isinstance(value, str):
            return value
        return value.strip().upper()

    @model_validator(mode="after")
    def _validate_path(self) -> RawApiRequestArgs:
        """Fail validation early when the path cannot be normalized.

        The normalized endpoint is computed only for its side effect of raising
        ``ValueError``; the stored ``path`` is left as supplied.
        """
        normalize_raw_endpoint(self.path)
        return self
def extract_repository(arguments: dict[str, object]) -> str | None:
    """Derive `owner/repo` from a raw (unvalidated) argument mapping.

    Args:
        arguments: Raw tool arguments.

    Returns:
        `owner/repo` taken from the ``owner``/``repo`` arguments when both are
        non-empty strings; otherwise the repository parsed from a raw-dispatch
        ``path`` (when ``path`` and ``method`` are both strings); otherwise None.
    """
    owner, repo = arguments.get("owner"), arguments.get("repo")
    if isinstance(owner, str) and isinstance(repo, str) and owner and repo:
        return f"{owner}/{repo}"

    # Raw API dispatch: derive the repository from the request path so the central
    # policy gate and the service-PAT per-user permission check evaluate the real
    # target instead of treating every raw call as repo-less.
    raw_path = arguments.get("path")
    is_raw_call = isinstance(raw_path, str) and isinstance(arguments.get("method"), str)
    if not is_raw_call:
        return None
    try:
        return parse_raw_repository(normalize_raw_endpoint(raw_path))
    except ValueError:
        # An un-normalizable path yields no repository.
        return None
def extract_target_path(arguments: dict[str, object]) -> str | None:
    """Derive the optional target file path used for policy path checks.

    Args:
        arguments: Raw tool arguments.

    Returns:
        The ``filepath`` argument when it is a non-empty string; otherwise the
        file path embedded in a raw-dispatch ``path`` (when ``path`` and
        ``method`` are both strings); otherwise None.
    """
    explicit = arguments.get("filepath")
    if isinstance(explicit, str) and explicit:
        return explicit

    # Raw API dispatch: expose the file path embedded in contents/raw/media
    # endpoints so repository path allow/deny rules still apply to raw calls.
    raw_path = arguments.get("path")
    is_raw_call = isinstance(raw_path, str) and isinstance(arguments.get("method"), str)
    if not is_raw_call:
        return None
    try:
        return parse_raw_target_path(normalize_raw_endpoint(raw_path))
    except ValueError:
        # An un-normalizable path yields no target path.
        return None