"""Pydantic argument models for MCP tools.""" from __future__ import annotations from typing import Annotated, Literal from pydantic import ( AfterValidator, BaseModel, BeforeValidator, ConfigDict, Field, model_validator, ) _REPO_PART_PATTERN = r"^[A-Za-z0-9._-]{1,100}$" def _validate_git_ref(value: str) -> str: """Validate a git ref-like value (ref/sha/base/head) against traversal. Refs that legitimately contain ``/`` (e.g. ``feature/foo``, ``release/1.0``) are preserved; only traversal and unsafe URL-path characters are rejected. Args: value: Candidate ref, sha, base, or head value. Returns: The unchanged value when it is safe. Raises: ValueError: When the value could escape the intended repository path. """ # Security decision: block path traversal and absolute references. if ".." in value.split("/"): raise ValueError("ref must not contain '..' path segments") if value.startswith("/"): raise ValueError("ref must not start with '/'") if "\\" in value: raise ValueError("ref must not contain backslashes") if "\x00" in value: raise ValueError("ref must not contain null bytes") if any(ord(char) < 0x20 or ord(char) == 0x7F for char in value): raise ValueError("ref must not contain control characters") if any(char.isspace() for char in value): raise ValueError("ref must not contain whitespace") if "?" in value or "#" in value: raise ValueError("ref must not contain '?' or '#'") return value GitRef = Annotated[str, AfterValidator(_validate_git_ref)] def _validate_milestone(value: object) -> int | str: """Validate a milestone reference supplied as a numeric id or a title. An integer is treated as a milestone id (``0`` clears the milestone on update); a string is treated as a milestone title to resolve. Runs as a ``BeforeValidator`` so ``bool`` (a subclass of ``int`` that Pydantic would otherwise coerce to ``1``/``0``) is rejected on the raw input. """ if isinstance(value, bool): raise ValueError("milestone must be a milestone id or title") if isinstance(value, int): if value < 0: raise ValueError("milestone id must be >= 0") return value if isinstance(value, str): title = value.strip() if not title: raise ValueError("milestone title must not be empty") if len(title) > 256: raise ValueError("milestone title must not exceed 256 characters") return title raise ValueError("milestone must be a milestone id or title") MilestoneRef = Annotated[int | str, BeforeValidator(_validate_milestone)] class StrictBaseModel(BaseModel): """Strict model base that rejects unexpected fields.""" model_config = ConfigDict(extra="forbid") class ListRepositoriesArgs(StrictBaseModel): """Arguments for list_repositories tool.""" class RepositoryArgs(StrictBaseModel): """Common repository locator arguments.""" owner: str = Field(..., pattern=_REPO_PART_PATTERN) repo: str = Field(..., pattern=_REPO_PART_PATTERN) class FileTreeArgs(RepositoryArgs): """Arguments for get_file_tree.""" ref: GitRef = Field(default="main", min_length=1, max_length=200) recursive: bool = Field(default=False) class FileContentsArgs(RepositoryArgs): """Arguments for get_file_contents.""" filepath: str = Field(..., min_length=1, max_length=1024) ref: GitRef = Field(default="main", min_length=1, max_length=200) @model_validator(mode="after") def validate_filepath(self) -> FileContentsArgs: """Validate path safety constraints.""" normalized = self.filepath.replace("\\", "/") # Security decision: block traversal and absolute paths. if normalized.startswith("/") or ".." in normalized.split("/"): raise ValueError("filepath must be a relative path without traversal") if "\x00" in normalized: raise ValueError("filepath cannot contain null bytes") return self class SearchCodeArgs(RepositoryArgs): """Arguments for search_code.""" query: str = Field(..., min_length=1, max_length=256) ref: GitRef = Field(default="main", min_length=1, max_length=200) page: int = Field(default=1, ge=1, le=10_000) limit: int = Field(default=25, ge=1, le=100) class ListCommitsArgs(RepositoryArgs): """Arguments for list_commits.""" ref: GitRef = Field(default="main", min_length=1, max_length=200) page: int = Field(default=1, ge=1, le=10_000) limit: int = Field(default=25, ge=1, le=100) class CommitDiffArgs(RepositoryArgs): """Arguments for get_commit_diff.""" sha: GitRef = Field(..., min_length=7, max_length=64) class CompareRefsArgs(RepositoryArgs): """Arguments for compare_refs.""" base: GitRef = Field(..., min_length=1, max_length=200) head: GitRef = Field(..., min_length=1, max_length=200) class ListIssuesArgs(RepositoryArgs): """Arguments for list_issues.""" state: Literal["open", "closed", "all"] = Field(default="open") page: int = Field(default=1, ge=1, le=10_000) limit: int = Field(default=25, ge=1, le=100) labels: list[str] = Field(default_factory=list, max_length=20) class IssueArgs(RepositoryArgs): """Arguments for get_issue.""" issue_number: int = Field(..., ge=1) class ListPullRequestsArgs(RepositoryArgs): """Arguments for list_pull_requests.""" state: Literal["open", "closed", "all"] = Field(default="open") page: int = Field(default=1, ge=1, le=10_000) limit: int = Field(default=25, ge=1, le=100) class PullRequestArgs(RepositoryArgs): """Arguments for get_pull_request.""" pull_number: int = Field(..., ge=1) class ListLabelsArgs(RepositoryArgs): """Arguments for list_labels.""" page: int = Field(default=1, ge=1, le=10_000) limit: int = Field(default=50, ge=1, le=100) class ListTagsArgs(RepositoryArgs): """Arguments for list_tags.""" page: int = Field(default=1, ge=1, le=10_000) limit: int = Field(default=50, ge=1, le=100) class ListReleasesArgs(RepositoryArgs): """Arguments for list_releases.""" page: int = Field(default=1, ge=1, le=10_000) limit: int = Field(default=25, ge=1, le=100) class CreateIssueArgs(RepositoryArgs): """Arguments for create_issue.""" title: str = Field(..., min_length=1, max_length=256) body: str = Field(default="", max_length=20_000) labels: list[str] = Field(default_factory=list, max_length=20) assignees: list[str] = Field(default_factory=list, max_length=20) milestone: MilestoneRef | None = Field( default=None, description="Milestone id or title to assign the issue to" ) class UpdateIssueArgs(RepositoryArgs): """Arguments for update_issue.""" issue_number: int = Field(..., ge=1) title: str | None = Field(default=None, min_length=1, max_length=256) body: str | None = Field(default=None, max_length=20_000) state: Literal["open", "closed"] | None = Field(default=None) milestone: MilestoneRef | None = Field( default=None, description="Milestone id or title to assign; 0 clears the milestone" ) @model_validator(mode="after") def require_change(self) -> UpdateIssueArgs: """Require at least one mutable field in update payload.""" if ( self.title is None and self.body is None and self.state is None and self.milestone is None ): raise ValueError("At least one of title, body, state, or milestone must be provided") return self class CreateIssueCommentArgs(RepositoryArgs): """Arguments for create_issue_comment.""" issue_number: int = Field(..., ge=1) body: str = Field(..., min_length=1, max_length=10_000) class CreatePrCommentArgs(RepositoryArgs): """Arguments for create_pr_comment.""" pull_number: int = Field(..., ge=1) body: str = Field(..., min_length=1, max_length=10_000) class AddLabelsArgs(RepositoryArgs): """Arguments for add_labels.""" issue_number: int = Field(..., ge=1) labels: list[str] = Field(..., min_length=1, max_length=20) class AssignIssueArgs(RepositoryArgs): """Arguments for assign_issue.""" issue_number: int = Field(..., ge=1) assignees: list[str] = Field(..., min_length=1, max_length=20) class CreateLabelArgs(RepositoryArgs): """Arguments for create_label.""" name: str = Field(..., min_length=1, max_length=50) # Gitea requires a hex color; accept it with or without a leading '#'. color: str = Field(..., pattern=r"^#?[0-9A-Fa-f]{6}$") description: str = Field(default="", max_length=1000) exclusive: bool = Field(default=False) class UpdateLabelArgs(RepositoryArgs): """Arguments for update_label (located by current name).""" name: str = Field(..., min_length=1, max_length=50) new_name: str | None = Field(default=None, min_length=1, max_length=50) color: str | None = Field(default=None, pattern=r"^#?[0-9A-Fa-f]{6}$") description: str | None = Field(default=None, max_length=1000) @model_validator(mode="after") def require_change(self) -> UpdateLabelArgs: """Require at least one mutable field in the update payload.""" if self.new_name is None and self.color is None and self.description is None: raise ValueError("At least one of new_name, color, or description must be provided") return self class RemoveLabelsArgs(RepositoryArgs): """Arguments for remove_labels.""" issue_number: int = Field(..., ge=1) labels: list[str] = Field(..., min_length=1, max_length=20) class CreatePullRequestArgs(RepositoryArgs): """Arguments for create_pull_request.""" title: str = Field(..., min_length=1, max_length=256) head: GitRef = Field(..., min_length=1, max_length=200) base: GitRef = Field(..., min_length=1, max_length=200) body: str = Field(default="", max_length=20_000) class CreateReleaseArgs(RepositoryArgs): """Arguments for create_release.""" tag_name: GitRef = Field(..., min_length=1, max_length=200) name: str = Field(default="", max_length=256) body: str = Field(default="", max_length=20_000) draft: bool = Field(default=False) prerelease: bool = Field(default=False) target: str | None = Field(default=None, min_length=1, max_length=200) class EditReleaseArgs(RepositoryArgs): """Arguments for edit_release.""" release_id: int = Field(..., ge=1) name: str | None = Field(default=None, max_length=256) body: str | None = Field(default=None, max_length=20_000) draft: bool | None = Field(default=None) prerelease: bool | None = Field(default=None) @model_validator(mode="after") def require_change(self) -> EditReleaseArgs: """Require at least one mutable field in the update payload.""" if ( self.name is None and self.body is None and self.draft is None and self.prerelease is None ): raise ValueError("At least one of name, body, draft, or prerelease must be provided") return self class CreateBranchArgs(RepositoryArgs): """Arguments for create_branch.""" new_branch_name: GitRef = Field(..., min_length=1, max_length=200) old_branch_name: str | None = Field(default=None, min_length=1, max_length=200) class CreateMilestoneArgs(RepositoryArgs): """Arguments for create_milestone.""" title: str = Field(..., min_length=1, max_length=256) description: str = Field(default="", max_length=10_000) due_on: str | None = Field(default=None, max_length=64) class EditIssueCommentArgs(RepositoryArgs): """Arguments for edit_issue_comment.""" comment_id: int = Field(..., ge=1) body: str = Field(..., min_length=1, max_length=10_000) class ListPullRequestFilesArgs(RepositoryArgs): """Arguments for list_pull_request_files.""" pull_number: int = Field(..., ge=1) page: int = Field(default=1, ge=1, le=10_000) limit: int = Field(default=50, ge=1, le=100) class ListPullRequestCommitsArgs(RepositoryArgs): """Arguments for list_pull_request_commits.""" pull_number: int = Field(..., ge=1) page: int = Field(default=1, ge=1, le=10_000) limit: int = Field(default=50, ge=1, le=100) class ListIssueCommentsArgs(RepositoryArgs): """Arguments for list_issue_comments.""" issue_number: int = Field(..., ge=1) page: int = Field(default=1, ge=1, le=10_000) limit: int = Field(default=50, ge=1, le=100) class ListBranchesArgs(RepositoryArgs): """Arguments for list_branches.""" page: int = Field(default=1, ge=1, le=10_000) limit: int = Field(default=50, ge=1, le=100) class GetBranchArgs(RepositoryArgs): """Arguments for get_branch.""" branch: GitRef = Field(..., min_length=1, max_length=200) class GetReleaseArgs(RepositoryArgs): """Arguments for get_release.""" release_id: int = Field(..., ge=1) class LatestReleaseArgs(RepositoryArgs): """Arguments for get_latest_release.""" class ListMilestonesArgs(RepositoryArgs): """Arguments for list_milestones.""" state: Literal["open", "closed", "all"] = Field(default="open") page: int = Field(default=1, ge=1, le=10_000) limit: int = Field(default=50, ge=1, le=100) class CommitStatusArgs(RepositoryArgs): """Arguments for get_commit_status.""" sha: GitRef = Field(..., min_length=1, max_length=64) class ListOrgRepositoriesArgs(StrictBaseModel): """Arguments for list_org_repositories.""" org: str = Field(..., pattern=_REPO_PART_PATTERN) page: int = Field(default=1, ge=1, le=10_000) limit: int = Field(default=50, ge=1, le=100) class ListOrganizationsArgs(StrictBaseModel): """Arguments for list_organizations.""" page: int = Field(default=1, ge=1, le=10_000) limit: int = Field(default=50, ge=1, le=100) class RepoLanguagesArgs(RepositoryArgs): """Arguments for get_repo_languages.""" class RepoTopicsArgs(RepositoryArgs): """Arguments for list_repo_topics.""" def extract_repository(arguments: dict[str, object]) -> str | None: """Extract `owner/repo` from raw argument mapping. Args: arguments: Raw tool arguments. Returns: `owner/repo` or None when arguments are incomplete. """ owner = arguments.get("owner") repo = arguments.get("repo") if isinstance(owner, str) and isinstance(repo, str) and owner and repo: return f"{owner}/{repo}" return None def extract_target_path(arguments: dict[str, object]) -> str | None: """Extract optional target path argument for policy path checks.""" filepath = arguments.get("filepath") if isinstance(filepath, str) and filepath: return filepath return None