"""Pydantic argument models for MCP tools.""" from __future__ import annotations from typing import Annotated, Literal from pydantic import AfterValidator, BaseModel, ConfigDict, Field, model_validator _REPO_PART_PATTERN = r"^[A-Za-z0-9._-]{1,100}$" def _validate_git_ref(value: str) -> str: """Validate a git ref-like value (ref/sha/base/head) against traversal. Refs that legitimately contain ``/`` (e.g. ``feature/foo``, ``release/1.0``) are preserved; only traversal and unsafe URL-path characters are rejected. Args: value: Candidate ref, sha, base, or head value. Returns: The unchanged value when it is safe. Raises: ValueError: When the value could escape the intended repository path. """ # Security decision: block path traversal and absolute references. if ".." in value.split("/"): raise ValueError("ref must not contain '..' path segments") if value.startswith("/"): raise ValueError("ref must not start with '/'") if "\\" in value: raise ValueError("ref must not contain backslashes") if "\x00" in value: raise ValueError("ref must not contain null bytes") if any(ord(char) < 0x20 or ord(char) == 0x7F for char in value): raise ValueError("ref must not contain control characters") if any(char.isspace() for char in value): raise ValueError("ref must not contain whitespace") if "?" in value or "#" in value: raise ValueError("ref must not contain '?' or '#'") return value GitRef = Annotated[str, AfterValidator(_validate_git_ref)] class StrictBaseModel(BaseModel): """Strict model base that rejects unexpected fields.""" model_config = ConfigDict(extra="forbid") class ListRepositoriesArgs(StrictBaseModel): """Arguments for list_repositories tool.""" class RepositoryArgs(StrictBaseModel): """Common repository locator arguments.""" owner: str = Field(..., pattern=_REPO_PART_PATTERN) repo: str = Field(..., pattern=_REPO_PART_PATTERN) class FileTreeArgs(RepositoryArgs): """Arguments for get_file_tree.""" ref: GitRef = Field(default="main", min_length=1, max_length=200) recursive: bool = Field(default=False) class FileContentsArgs(RepositoryArgs): """Arguments for get_file_contents.""" filepath: str = Field(..., min_length=1, max_length=1024) ref: GitRef = Field(default="main", min_length=1, max_length=200) @model_validator(mode="after") def validate_filepath(self) -> FileContentsArgs: """Validate path safety constraints.""" normalized = self.filepath.replace("\\", "/") # Security decision: block traversal and absolute paths. if normalized.startswith("/") or ".." in normalized.split("/"): raise ValueError("filepath must be a relative path without traversal") if "\x00" in normalized: raise ValueError("filepath cannot contain null bytes") return self class SearchCodeArgs(RepositoryArgs): """Arguments for search_code.""" query: str = Field(..., min_length=1, max_length=256) ref: GitRef = Field(default="main", min_length=1, max_length=200) page: int = Field(default=1, ge=1, le=10_000) limit: int = Field(default=25, ge=1, le=100) class ListCommitsArgs(RepositoryArgs): """Arguments for list_commits.""" ref: GitRef = Field(default="main", min_length=1, max_length=200) page: int = Field(default=1, ge=1, le=10_000) limit: int = Field(default=25, ge=1, le=100) class CommitDiffArgs(RepositoryArgs): """Arguments for get_commit_diff.""" sha: GitRef = Field(..., min_length=7, max_length=64) class CompareRefsArgs(RepositoryArgs): """Arguments for compare_refs.""" base: GitRef = Field(..., min_length=1, max_length=200) head: GitRef = Field(..., min_length=1, max_length=200) class ListIssuesArgs(RepositoryArgs): """Arguments for list_issues.""" state: Literal["open", "closed", "all"] = Field(default="open") page: int = Field(default=1, ge=1, le=10_000) limit: int = Field(default=25, ge=1, le=100) labels: list[str] = Field(default_factory=list, max_length=20) class IssueArgs(RepositoryArgs): """Arguments for get_issue.""" issue_number: int = Field(..., ge=1) class ListPullRequestsArgs(RepositoryArgs): """Arguments for list_pull_requests.""" state: Literal["open", "closed", "all"] = Field(default="open") page: int = Field(default=1, ge=1, le=10_000) limit: int = Field(default=25, ge=1, le=100) class PullRequestArgs(RepositoryArgs): """Arguments for get_pull_request.""" pull_number: int = Field(..., ge=1) class ListLabelsArgs(RepositoryArgs): """Arguments for list_labels.""" page: int = Field(default=1, ge=1, le=10_000) limit: int = Field(default=50, ge=1, le=100) class ListTagsArgs(RepositoryArgs): """Arguments for list_tags.""" page: int = Field(default=1, ge=1, le=10_000) limit: int = Field(default=50, ge=1, le=100) class ListReleasesArgs(RepositoryArgs): """Arguments for list_releases.""" page: int = Field(default=1, ge=1, le=10_000) limit: int = Field(default=25, ge=1, le=100) class CreateIssueArgs(RepositoryArgs): """Arguments for create_issue.""" title: str = Field(..., min_length=1, max_length=256) body: str = Field(default="", max_length=20_000) labels: list[str] = Field(default_factory=list, max_length=20) assignees: list[str] = Field(default_factory=list, max_length=20) class UpdateIssueArgs(RepositoryArgs): """Arguments for update_issue.""" issue_number: int = Field(..., ge=1) title: str | None = Field(default=None, min_length=1, max_length=256) body: str | None = Field(default=None, max_length=20_000) state: Literal["open", "closed"] | None = Field(default=None) @model_validator(mode="after") def require_change(self) -> UpdateIssueArgs: """Require at least one mutable field in update payload.""" if self.title is None and self.body is None and self.state is None: raise ValueError("At least one of title, body, or state must be provided") return self class CreateIssueCommentArgs(RepositoryArgs): """Arguments for create_issue_comment.""" issue_number: int = Field(..., ge=1) body: str = Field(..., min_length=1, max_length=10_000) class CreatePrCommentArgs(RepositoryArgs): """Arguments for create_pr_comment.""" pull_number: int = Field(..., ge=1) body: str = Field(..., min_length=1, max_length=10_000) class AddLabelsArgs(RepositoryArgs): """Arguments for add_labels.""" issue_number: int = Field(..., ge=1) labels: list[str] = Field(..., min_length=1, max_length=20) class AssignIssueArgs(RepositoryArgs): """Arguments for assign_issue.""" issue_number: int = Field(..., ge=1) assignees: list[str] = Field(..., min_length=1, max_length=20) class CreateLabelArgs(RepositoryArgs): """Arguments for create_label.""" name: str = Field(..., min_length=1, max_length=50) # Gitea requires a hex color; accept it with or without a leading '#'. color: str = Field(..., pattern=r"^#?[0-9A-Fa-f]{6}$") description: str = Field(default="", max_length=1000) exclusive: bool = Field(default=False) class UpdateLabelArgs(RepositoryArgs): """Arguments for update_label (located by current name).""" name: str = Field(..., min_length=1, max_length=50) new_name: str | None = Field(default=None, min_length=1, max_length=50) color: str | None = Field(default=None, pattern=r"^#?[0-9A-Fa-f]{6}$") description: str | None = Field(default=None, max_length=1000) @model_validator(mode="after") def require_change(self) -> UpdateLabelArgs: """Require at least one mutable field in the update payload.""" if self.new_name is None and self.color is None and self.description is None: raise ValueError("At least one of new_name, color, or description must be provided") return self class RemoveLabelsArgs(RepositoryArgs): """Arguments for remove_labels.""" issue_number: int = Field(..., ge=1) labels: list[str] = Field(..., min_length=1, max_length=20) class CreatePullRequestArgs(RepositoryArgs): """Arguments for create_pull_request.""" title: str = Field(..., min_length=1, max_length=256) head: GitRef = Field(..., min_length=1, max_length=200) base: GitRef = Field(..., min_length=1, max_length=200) body: str = Field(default="", max_length=20_000) class CreateReleaseArgs(RepositoryArgs): """Arguments for create_release.""" tag_name: GitRef = Field(..., min_length=1, max_length=200) name: str = Field(default="", max_length=256) body: str = Field(default="", max_length=20_000) draft: bool = Field(default=False) prerelease: bool = Field(default=False) target: str | None = Field(default=None, min_length=1, max_length=200) class EditReleaseArgs(RepositoryArgs): """Arguments for edit_release.""" release_id: int = Field(..., ge=1) name: str | None = Field(default=None, max_length=256) body: str | None = Field(default=None, max_length=20_000) draft: bool | None = Field(default=None) prerelease: bool | None = Field(default=None) @model_validator(mode="after") def require_change(self) -> EditReleaseArgs: """Require at least one mutable field in the update payload.""" if ( self.name is None and self.body is None and self.draft is None and self.prerelease is None ): raise ValueError("At least one of name, body, draft, or prerelease must be provided") return self class CreateBranchArgs(RepositoryArgs): """Arguments for create_branch.""" new_branch_name: GitRef = Field(..., min_length=1, max_length=200) old_branch_name: str | None = Field(default=None, min_length=1, max_length=200) class CreateMilestoneArgs(RepositoryArgs): """Arguments for create_milestone.""" title: str = Field(..., min_length=1, max_length=256) description: str = Field(default="", max_length=10_000) due_on: str | None = Field(default=None, max_length=64) class EditIssueCommentArgs(RepositoryArgs): """Arguments for edit_issue_comment.""" comment_id: int = Field(..., ge=1) body: str = Field(..., min_length=1, max_length=10_000) class ListPullRequestFilesArgs(RepositoryArgs): """Arguments for list_pull_request_files.""" pull_number: int = Field(..., ge=1) page: int = Field(default=1, ge=1, le=10_000) limit: int = Field(default=50, ge=1, le=100) class ListPullRequestCommitsArgs(RepositoryArgs): """Arguments for list_pull_request_commits.""" pull_number: int = Field(..., ge=1) page: int = Field(default=1, ge=1, le=10_000) limit: int = Field(default=50, ge=1, le=100) class ListIssueCommentsArgs(RepositoryArgs): """Arguments for list_issue_comments.""" issue_number: int = Field(..., ge=1) page: int = Field(default=1, ge=1, le=10_000) limit: int = Field(default=50, ge=1, le=100) class ListBranchesArgs(RepositoryArgs): """Arguments for list_branches.""" page: int = Field(default=1, ge=1, le=10_000) limit: int = Field(default=50, ge=1, le=100) class GetBranchArgs(RepositoryArgs): """Arguments for get_branch.""" branch: GitRef = Field(..., min_length=1, max_length=200) class GetReleaseArgs(RepositoryArgs): """Arguments for get_release.""" release_id: int = Field(..., ge=1) class LatestReleaseArgs(RepositoryArgs): """Arguments for get_latest_release.""" class ListMilestonesArgs(RepositoryArgs): """Arguments for list_milestones.""" state: Literal["open", "closed", "all"] = Field(default="open") page: int = Field(default=1, ge=1, le=10_000) limit: int = Field(default=50, ge=1, le=100) class CommitStatusArgs(RepositoryArgs): """Arguments for get_commit_status.""" sha: GitRef = Field(..., min_length=1, max_length=64) class ListOrgRepositoriesArgs(StrictBaseModel): """Arguments for list_org_repositories.""" org: str = Field(..., pattern=_REPO_PART_PATTERN) page: int = Field(default=1, ge=1, le=10_000) limit: int = Field(default=50, ge=1, le=100) class ListOrganizationsArgs(StrictBaseModel): """Arguments for list_organizations.""" page: int = Field(default=1, ge=1, le=10_000) limit: int = Field(default=50, ge=1, le=100) class RepoLanguagesArgs(RepositoryArgs): """Arguments for get_repo_languages.""" class RepoTopicsArgs(RepositoryArgs): """Arguments for list_repo_topics.""" def extract_repository(arguments: dict[str, object]) -> str | None: """Extract `owner/repo` from raw argument mapping. Args: arguments: Raw tool arguments. Returns: `owner/repo` or None when arguments are incomplete. """ owner = arguments.get("owner") repo = arguments.get("repo") if isinstance(owner, str) and isinstance(repo, str) and owner and repo: return f"{owner}/{repo}" return None def extract_target_path(arguments: dict[str, object]) -> str | None: """Extract optional target path argument for policy path checks.""" filepath = arguments.get("filepath") if isinstance(filepath, str) and filepath: return filepath return None