"""Pydantic argument models for MCP tools.""" from __future__ import annotations import re from typing import Annotated, Any, Literal from urllib.parse import urlsplit from pydantic import ( AfterValidator, BaseModel, BeforeValidator, ConfigDict, Field, field_validator, model_validator, ) _REPO_PART_PATTERN = r"^[A-Za-z0-9._-]{1,100}$" def _validate_git_ref(value: str) -> str: """Validate a git ref-like value (ref/sha/base/head) against traversal. Refs that legitimately contain ``/`` (e.g. ``feature/foo``, ``release/1.0``) are preserved; only traversal and unsafe URL-path characters are rejected. Args: value: Candidate ref, sha, base, or head value. Returns: The unchanged value when it is safe. Raises: ValueError: When the value could escape the intended repository path. """ # Security decision: block path traversal and absolute references. if ".." in value.split("/"): raise ValueError("ref must not contain '..' path segments") if value.startswith("/"): raise ValueError("ref must not start with '/'") if "\\" in value: raise ValueError("ref must not contain backslashes") if "\x00" in value: raise ValueError("ref must not contain null bytes") if any(ord(char) < 0x20 or ord(char) == 0x7F for char in value): raise ValueError("ref must not contain control characters") if any(char.isspace() for char in value): raise ValueError("ref must not contain whitespace") if "?" in value or "#" in value: raise ValueError("ref must not contain '?' or '#'") return value GitRef = Annotated[str, AfterValidator(_validate_git_ref)] def _validate_milestone(value: object) -> int | str: """Validate a milestone reference supplied as a numeric id or a title. An integer is treated as a milestone id (``0`` clears the milestone on update); a string is treated as a milestone title to resolve. Runs as a ``BeforeValidator`` so ``bool`` (a subclass of ``int`` that Pydantic would otherwise coerce to ``1``/``0``) is rejected on the raw input. 
""" if isinstance(value, bool): raise ValueError("milestone must be a milestone id or title") if isinstance(value, int): if value < 0: raise ValueError("milestone id must be >= 0") return value if isinstance(value, str): title = value.strip() if not title: raise ValueError("milestone title must not be empty") if len(title) > 256: raise ValueError("milestone title must not exceed 256 characters") return title raise ValueError("milestone must be a milestone id or title") MilestoneRef = Annotated[int | str, BeforeValidator(_validate_milestone)] class StrictBaseModel(BaseModel): """Strict model base that rejects unexpected fields.""" model_config = ConfigDict(extra="forbid") class ListRepositoriesArgs(StrictBaseModel): """Arguments for list_repositories tool.""" class RepositoryArgs(StrictBaseModel): """Common repository locator arguments.""" owner: str = Field(..., pattern=_REPO_PART_PATTERN) repo: str = Field(..., pattern=_REPO_PART_PATTERN) class FileTreeArgs(RepositoryArgs): """Arguments for get_file_tree.""" ref: GitRef = Field(default="main", min_length=1, max_length=200) recursive: bool = Field(default=False) class FileContentsArgs(RepositoryArgs): """Arguments for get_file_contents.""" filepath: str = Field(..., min_length=1, max_length=1024) ref: GitRef = Field(default="main", min_length=1, max_length=200) @model_validator(mode="after") def validate_filepath(self) -> FileContentsArgs: """Validate path safety constraints.""" normalized = self.filepath.replace("\\", "/") # Security decision: block traversal and absolute paths. if normalized.startswith("/") or ".." 
in normalized.split("/"): raise ValueError("filepath must be a relative path without traversal") if "\x00" in normalized: raise ValueError("filepath cannot contain null bytes") return self class SearchCodeArgs(RepositoryArgs): """Arguments for search_code.""" query: str = Field(..., min_length=1, max_length=256) ref: GitRef = Field(default="main", min_length=1, max_length=200) page: int = Field(default=1, ge=1, le=10_000) limit: int = Field(default=25, ge=1, le=100) class ListCommitsArgs(RepositoryArgs): """Arguments for list_commits.""" ref: GitRef = Field(default="main", min_length=1, max_length=200) page: int = Field(default=1, ge=1, le=10_000) limit: int = Field(default=25, ge=1, le=100) class CommitDiffArgs(RepositoryArgs): """Arguments for get_commit_diff.""" sha: GitRef = Field(..., min_length=7, max_length=64) class CompareRefsArgs(RepositoryArgs): """Arguments for compare_refs.""" base: GitRef = Field(..., min_length=1, max_length=200) head: GitRef = Field(..., min_length=1, max_length=200) class ListIssuesArgs(RepositoryArgs): """Arguments for list_issues.""" state: Literal["open", "closed", "all"] = Field(default="open") page: int = Field(default=1, ge=1, le=10_000) limit: int = Field(default=25, ge=1, le=100) labels: list[str] = Field(default_factory=list, max_length=20) class IssueArgs(RepositoryArgs): """Arguments for get_issue.""" issue_number: int = Field(..., ge=1) class ListPullRequestsArgs(RepositoryArgs): """Arguments for list_pull_requests.""" state: Literal["open", "closed", "all"] = Field(default="open") page: int = Field(default=1, ge=1, le=10_000) limit: int = Field(default=25, ge=1, le=100) class PullRequestArgs(RepositoryArgs): """Arguments for get_pull_request.""" pull_number: int = Field(..., ge=1) class ListLabelsArgs(RepositoryArgs): """Arguments for list_labels.""" page: int = Field(default=1, ge=1, le=10_000) limit: int = Field(default=50, ge=1, le=100) class ListTagsArgs(RepositoryArgs): """Arguments for list_tags.""" page: int = 
Field(default=1, ge=1, le=10_000) limit: int = Field(default=50, ge=1, le=100) class ListReleasesArgs(RepositoryArgs): """Arguments for list_releases.""" page: int = Field(default=1, ge=1, le=10_000) limit: int = Field(default=25, ge=1, le=100) class CreateIssueArgs(RepositoryArgs): """Arguments for create_issue.""" title: str = Field(..., min_length=1, max_length=256) body: str = Field(default="", max_length=20_000) labels: list[str] = Field(default_factory=list, max_length=20) assignees: list[str] = Field(default_factory=list, max_length=20) milestone: MilestoneRef | None = Field( default=None, description="Milestone id or title to assign the issue to" ) class UpdateIssueArgs(RepositoryArgs): """Arguments for update_issue.""" issue_number: int = Field(..., ge=1) title: str | None = Field(default=None, min_length=1, max_length=256) body: str | None = Field(default=None, max_length=20_000) state: Literal["open", "closed"] | None = Field(default=None) milestone: MilestoneRef | None = Field( default=None, description="Milestone id or title to assign; 0 clears the milestone" ) @model_validator(mode="after") def require_change(self) -> UpdateIssueArgs: """Require at least one mutable field in update payload.""" if ( self.title is None and self.body is None and self.state is None and self.milestone is None ): raise ValueError("At least one of title, body, state, or milestone must be provided") return self class CreateIssueCommentArgs(RepositoryArgs): """Arguments for create_issue_comment.""" issue_number: int = Field(..., ge=1) body: str = Field(..., min_length=1, max_length=10_000) class CreatePrCommentArgs(RepositoryArgs): """Arguments for create_pr_comment.""" pull_number: int = Field(..., ge=1) body: str = Field(..., min_length=1, max_length=10_000) class AddLabelsArgs(RepositoryArgs): """Arguments for add_labels.""" issue_number: int = Field(..., ge=1) labels: list[str] = Field(..., min_length=1, max_length=20) class AssignIssueArgs(RepositoryArgs): """Arguments 
for assign_issue.""" issue_number: int = Field(..., ge=1) assignees: list[str] = Field(..., min_length=1, max_length=20) class CreateLabelArgs(RepositoryArgs): """Arguments for create_label.""" name: str = Field(..., min_length=1, max_length=50) # Gitea requires a hex color; accept it with or without a leading '#'. color: str = Field(..., pattern=r"^#?[0-9A-Fa-f]{6}$") description: str = Field(default="", max_length=1000) exclusive: bool = Field(default=False) class UpdateLabelArgs(RepositoryArgs): """Arguments for update_label (located by current name).""" name: str = Field(..., min_length=1, max_length=50) new_name: str | None = Field(default=None, min_length=1, max_length=50) color: str | None = Field(default=None, pattern=r"^#?[0-9A-Fa-f]{6}$") description: str | None = Field(default=None, max_length=1000) @model_validator(mode="after") def require_change(self) -> UpdateLabelArgs: """Require at least one mutable field in the update payload.""" if self.new_name is None and self.color is None and self.description is None: raise ValueError("At least one of new_name, color, or description must be provided") return self class RemoveLabelsArgs(RepositoryArgs): """Arguments for remove_labels.""" issue_number: int = Field(..., ge=1) labels: list[str] = Field(..., min_length=1, max_length=20) class CreatePullRequestArgs(RepositoryArgs): """Arguments for create_pull_request.""" title: str = Field(..., min_length=1, max_length=256) head: GitRef = Field(..., min_length=1, max_length=200) base: GitRef = Field(..., min_length=1, max_length=200) body: str = Field(default="", max_length=20_000) class CreateReleaseArgs(RepositoryArgs): """Arguments for create_release.""" tag_name: GitRef = Field(..., min_length=1, max_length=200) name: str = Field(default="", max_length=256) body: str = Field(default="", max_length=20_000) draft: bool = Field(default=False) prerelease: bool = Field(default=False) target: str | None = Field(default=None, min_length=1, max_length=200) class 
EditReleaseArgs(RepositoryArgs): """Arguments for edit_release.""" release_id: int = Field(..., ge=1) name: str | None = Field(default=None, max_length=256) body: str | None = Field(default=None, max_length=20_000) draft: bool | None = Field(default=None) prerelease: bool | None = Field(default=None) @model_validator(mode="after") def require_change(self) -> EditReleaseArgs: """Require at least one mutable field in the update payload.""" if ( self.name is None and self.body is None and self.draft is None and self.prerelease is None ): raise ValueError("At least one of name, body, draft, or prerelease must be provided") return self class CreateBranchArgs(RepositoryArgs): """Arguments for create_branch.""" new_branch_name: GitRef = Field(..., min_length=1, max_length=200) old_branch_name: str | None = Field(default=None, min_length=1, max_length=200) class CreateMilestoneArgs(RepositoryArgs): """Arguments for create_milestone.""" title: str = Field(..., min_length=1, max_length=256) description: str = Field(default="", max_length=10_000) due_on: str | None = Field(default=None, max_length=64) class EditIssueCommentArgs(RepositoryArgs): """Arguments for edit_issue_comment.""" comment_id: int = Field(..., ge=1) body: str = Field(..., min_length=1, max_length=10_000) class ListPullRequestFilesArgs(RepositoryArgs): """Arguments for list_pull_request_files.""" pull_number: int = Field(..., ge=1) page: int = Field(default=1, ge=1, le=10_000) limit: int = Field(default=50, ge=1, le=100) class ListPullRequestCommitsArgs(RepositoryArgs): """Arguments for list_pull_request_commits.""" pull_number: int = Field(..., ge=1) page: int = Field(default=1, ge=1, le=10_000) limit: int = Field(default=50, ge=1, le=100) class ListIssueCommentsArgs(RepositoryArgs): """Arguments for list_issue_comments.""" issue_number: int = Field(..., ge=1) page: int = Field(default=1, ge=1, le=10_000) limit: int = Field(default=50, ge=1, le=100) class ListBranchesArgs(RepositoryArgs): """Arguments for 
list_branches.""" page: int = Field(default=1, ge=1, le=10_000) limit: int = Field(default=50, ge=1, le=100) class GetBranchArgs(RepositoryArgs): """Arguments for get_branch.""" branch: GitRef = Field(..., min_length=1, max_length=200) class GetReleaseArgs(RepositoryArgs): """Arguments for get_release.""" release_id: int = Field(..., ge=1) class LatestReleaseArgs(RepositoryArgs): """Arguments for get_latest_release.""" class ListMilestonesArgs(RepositoryArgs): """Arguments for list_milestones.""" state: Literal["open", "closed", "all"] = Field(default="open") page: int = Field(default=1, ge=1, le=10_000) limit: int = Field(default=50, ge=1, le=100) class CommitStatusArgs(RepositoryArgs): """Arguments for get_commit_status.""" sha: GitRef = Field(..., min_length=1, max_length=64) class ListOrgRepositoriesArgs(StrictBaseModel): """Arguments for list_org_repositories.""" org: str = Field(..., pattern=_REPO_PART_PATTERN) page: int = Field(default=1, ge=1, le=10_000) limit: int = Field(default=50, ge=1, le=100) class ListOrganizationsArgs(StrictBaseModel): """Arguments for list_organizations.""" page: int = Field(default=1, ge=1, le=10_000) limit: int = Field(default=50, ge=1, le=100) class RepoLanguagesArgs(RepositoryArgs): """Arguments for get_repo_languages.""" class RepoTopicsArgs(RepositoryArgs): """Arguments for list_repo_topics.""" # --- Raw API dispatch (gitea_request escape hatch) ------------------------- # HTTP methods the generic dispatch tool accepts. Everything outside GET/HEAD is # treated as a write so the policy/write-mode gate applies. RAW_API_METHODS = ("GET", "HEAD", "POST", "PUT", "PATCH", "DELETE") _RAW_WRITE_METHODS = frozenset({"POST", "PUT", "PATCH", "DELETE"}) # Path segments/subpaths blocked for *every* method unless explicitly overridden # via RAW_API_ALLOW_SENSITIVE. A GET on these already leaks credentials or # privileged configuration, so they are denied independently of policy.yaml. 
_RAW_SENSITIVE_SEGMENTS = frozenset({"admin", "tokens", "secrets", "hooks", "keys", "gpg_keys"})
_RAW_SENSITIVE_SUBPATHS = ("applications/oauth2", "actions/runners/registration-token")

# Endpoints under /repos/ that are not scoped to a single repository.
_RAW_CROSS_REPO_OWNERS = frozenset({"search", "issues"})

# Resources whose trailing segments form a file path target for policy checks.
_RAW_FILE_RESOURCES = frozenset({"contents", "raw", "media"})

# Known top-level segments of the Gitea ``/api/v1`` surface. A raw request whose
# first path segment is not in this set is rejected (fail closed): we never pass
# an unrecognized path straight through to Gitea.
KNOWN_API_PREFIXES = frozenset(
    {
        "activitypub",
        "admin",
        "gitignore",
        "issues",
        "label",
        "licenses",
        "markdown",
        "markup",
        "miscellaneous",
        "nodeinfo",
        "notifications",
        "org",
        "orgs",
        "packages",
        "repos",
        "repositories",
        "settings",
        "signing-key.gpg",
        "teams",
        "topics",
        "user",
        "users",
        "version",
    }
)

# Override table: provably side-effect-free POSTs that may be treated as reads so
# they do not needlessly require WRITE_MODE. This table may ONLY ever DOWNGRADE a
# write to a read for endpoints that render content and mutate nothing — never
# the reverse. Each entry is the COMPLETE path after ``/api/v1`` as a tuple of
# segments. Matching on the final segment alone is unsafe: mutating endpoints can
# end in the same word (e.g. ``POST /repos/{o}/{r}/contents/docs/raw`` creates a
# file named ``raw``).
_RAW_READ_ONLY_POST_PATHS = frozenset({("markdown",), ("markdown", "raw"), ("markup",)})

# Final segments of the render-only endpoints above. Not consulted for
# classification (see _RAW_READ_ONLY_POST_PATHS); kept so existing references to
# the name keep working.
_RAW_READ_ONLY_POST_LEAVES = frozenset({"markdown", "markup", "raw"})


def raw_is_known_api_path(endpoint: str) -> bool:
    """Return whether the endpoint's top segment is a known Gitea API prefix."""
    return raw_top_segment(endpoint) in KNOWN_API_PREFIXES


def raw_request_is_write(method: str, endpoint: str) -> bool:
    """Classify a raw request as read or write from its method and path.

    ``GET``/``HEAD`` are reads; every other method is a write — except for the
    small, explicit override table of render-only POSTs (``/markdown``,
    ``/markdown/raw``, ``/markup``), which are reads. The override matches the
    *entire* relative path, so a mutating POST whose path merely ends in one of
    those words (e.g. creating a file named ``raw`` through ``contents``) is
    still a write and cannot slip past the write-mode gate.

    Args:
        method: HTTP method in any letter case.
        endpoint: Endpoint path, with or without the ``/api/v1`` prefix.

    Returns:
        True when the request must be treated as a write.
    """
    upper = method.upper()
    if upper in {"GET", "HEAD"}:
        return False
    if upper == "POST":
        rel = tuple(_raw_relative_segments(endpoint))
        if rel in _RAW_READ_ONLY_POST_PATHS:
            return False
    return True


def normalize_raw_endpoint(path: str) -> str:
    """Normalize a raw API path into an ``/api/v1``-prefixed endpoint.

    Accepts a bare path (``/repos/o/r``), an already-prefixed path
    (``/api/v1/repos/o/r``), or a full URL (the scheme/host and any query string
    are stripped — the separate ``query`` argument carries query parameters).

    Args:
        path: Caller-supplied (untrusted) path or URL.

    Returns:
        The endpoint as ``/api/v1/<segments>``, or ``/api/v1`` when no segments
        remain.

    Raises:
        ValueError: When the path contains a ``..`` traversal segment, a
            percent-encoded dot segment, or an ASCII control character.
    """
    candidate = path.strip()
    split = urlsplit(candidate)
    # When a full URL is supplied, keep only its path component.
    raw_path = split.path if (split.scheme or split.netloc) else candidate
    # Drop any query/fragment a caller may have inlined into the path string.
    raw_path = raw_path.split("?", 1)[0].split("#", 1)[0]
    raw_path = raw_path.replace("\\", "/")
    segments = [seg for seg in raw_path.split("/") if seg and seg != "."]
    if any(seg == ".." for seg in segments):
        raise ValueError("path must not contain '..' traversal segments")
    # Percent-encoded dots ("%2e%2e", ".%2E", "%2e") can be turned back into dot
    # segments by URL normalizers (RFC 3986 sections 6.2.2.2 and 5.2.4) and would
    # then escape the literal check above. No real endpoint uses them.
    if any(seg.lower().replace("%2e", ".") in {".", ".."} for seg in segments):
        raise ValueError("path must not contain percent-encoded dot segments")
    # CR/LF, NUL and other control characters never belong in a request path.
    if any(ord(char) < 0x20 or ord(char) == 0x7F for seg in segments for char in seg):
        raise ValueError("path must not contain control characters")
    rel_segments = segments[2:] if segments[:2] == ["api", "v1"] else segments
    if not rel_segments:
        return "/api/v1"
    return "/api/v1/" + "/".join(rel_segments)


def _raw_relative_segments(endpoint: str) -> list[str]:
    """Return the endpoint segments after the ``/api/v1`` prefix."""
    segments = [seg for seg in endpoint.split("/") if seg]
    return segments[2:] if segments[:2] == ["api", "v1"] else segments


def raw_relative_segments(endpoint: str) -> list[str]:
    """Return the endpoint path segments after the ``/api/v1`` prefix (public)."""
    return _raw_relative_segments(endpoint)


def raw_top_segment(endpoint: str) -> str:
    """Return the first path segment after ``/api/v1`` for coarse policy grouping."""
    rel = _raw_relative_segments(endpoint)
    return rel[0] if rel else ""


def raw_method_is_write(method: str) -> bool:
    """Return whether an HTTP method mutates state."""
    return method.upper() in _RAW_WRITE_METHODS


def raw_is_sensitive(endpoint: str) -> bool:
    """Return whether an endpoint touches an admin/credential surface.

    Segments are checked both as written and percent-decoded, so an encoded
    spelling (e.g. ``%68ooks`` or ``applications%2Foauth2``) cannot slip past
    the deny list. The check only ever adds matches (fail closed).
    """
    from urllib.parse import unquote  # local import keeps this helper self-contained

    rel = _raw_relative_segments(endpoint)
    for segments in (rel, [unquote(seg) for seg in rel]):
        if any(seg in _RAW_SENSITIVE_SEGMENTS for seg in segments):
            return True
        joined = "/".join(segments)
        if any(sub in joined for sub in _RAW_SENSITIVE_SUBPATHS):
            return True
    return False


def _raw_repo_segments(endpoint: str) -> list[str] | None:
    """Return ``[owner, repo, *rest]`` for a single-repository endpoint, else None."""
    rel = _raw_relative_segments(endpoint)
    if len(rel) < 3 or rel[0] != "repos":
        return None
    owner, repo = rel[1], rel[2]
    if owner in _RAW_CROSS_REPO_OWNERS:
        return None
    # fullmatch, not match: in Python ``re`` the pattern's ``$`` also matches just
    # before a trailing newline, so ``re.match`` would accept e.g. "repo\n".
    if not (re.fullmatch(_REPO_PART_PATTERN, owner) and re.fullmatch(_REPO_PART_PATTERN, repo)):
        return None
    return [owner, repo, *rel[3:]]


def parse_raw_repository(endpoint: str) -> str | None:
    """Parse ``owner/repo`` from a repo-scoped endpoint; None for cross-repo paths."""
    repo_segments = _raw_repo_segments(endpoint)
    if repo_segments is None:
        return None
    return f"{repo_segments[0]}/{repo_segments[1]}"


def parse_raw_target_path(endpoint: str) -> str | None:
    """Parse a file-path target from ``contents``/``raw``/``media`` endpoints."""
    repo_segments = _raw_repo_segments(endpoint)
    if repo_segments is None or len(repo_segments) < 4:
        return None
    if repo_segments[2] not in _RAW_FILE_RESOURCES:
        return None
    file_path = "/".join(repo_segments[3:])
    return file_path or None


class RawApiRequestArgs(StrictBaseModel):
    """Arguments for the generic ``gitea_request`` escape-hatch tool."""

    method: Literal["GET", "HEAD", "POST", "PUT", "PATCH", "DELETE"] = Field(
        ..., description="HTTP method"
    )
    path: str = Field(..., min_length=1, max_length=2048, description="Gitea REST path")
    query: dict[str, Any] | None = Field(
        default=None, description="Optional query-string parameters"
    )
    body: dict[str, Any] | None = Field(default=None, description="Optional JSON request body")

    @field_validator("method", mode="before")
    @classmethod
    def _normalize_method(cls, value: object) -> object:
        """Uppercase the method before enum validation so 'get' is accepted."""
        if isinstance(value, str):
            return value.strip().upper()
        return value

    @model_validator(mode="after")
    def _validate_path(self) -> RawApiRequestArgs:
        """Reject unsafe paths up front so the handler sees a clean endpoint."""
        # Called for its ValueError side effect; the normalized value is not stored.
        normalize_raw_endpoint(self.path)
        return self


def extract_repository(arguments: dict[str, object]) -> str | None:
    """Extract `owner/repo` from raw argument mapping.

    Args:
        arguments: Raw tool arguments.

    Returns:
        `owner/repo`, or None when arguments are incomplete, the raw path is
        rejected by normalization, or the raw path is not repository-scoped.
    """
    owner = arguments.get("owner")
    repo = arguments.get("repo")
    if isinstance(owner, str) and isinstance(repo, str) and owner and repo:
        return f"{owner}/{repo}"
    # Raw API dispatch: derive the repository from the request path so the central
    # policy gate and the service-PAT per-user permission check evaluate the real
    # target instead of treating every raw call as repo-less.
    path = arguments.get("path")
    method = arguments.get("method")
    if isinstance(path, str) and isinstance(method, str):
        try:
            return parse_raw_repository(normalize_raw_endpoint(path))
        except ValueError:
            return None
    return None


def extract_target_path(arguments: dict[str, object]) -> str | None:
    """Extract optional target path argument for policy path checks."""
    filepath = arguments.get("filepath")
    if isinstance(filepath, str) and filepath:
        return filepath
    # Raw API dispatch: expose the file path embedded in contents/raw/media
    # endpoints so repository path allow/deny rules still apply to raw calls.
    path = arguments.get("path")
    method = arguments.get("method")
    if isinstance(path, str) and isinstance(method, str):
        try:
            return parse_raw_target_path(normalize_raw_endpoint(path))
        except ValueError:
            return None
    return None