Files
AegisGitea-MCP/src/aegis_gitea_mcp/tools/arguments.py
T
Latte e08ba42697
docker / test (pull_request) Successful in 29s
docker / lint (pull_request) Successful in 35s
lint / lint (pull_request) Successful in 35s
test / test (pull_request) Successful in 35s
docker / docker-test (pull_request) Successful in 8s
docker / docker-publish (pull_request) Has been skipped
test / test (push) Successful in 23s
lint / lint (push) Successful in 23s
feat: assign issues to milestones on create/update (#22)
Add a `milestone` argument to `create_issue` and `update_issue` accepting
either a numeric milestone id or a title (resolved case-insensitively against
open and closed milestones, with a clear error for unknown titles). On
`update_issue`, `milestone: 0` clears the milestone. A BeforeValidator rejects
booleans so they are not silently coerced to an id.

Gitea Projects (Kanban boards) were investigated for #22 and are intentionally
left unsupported: Gitea 1.26.2 exposes no project endpoints in its REST API.
Documented this in api-reference.md and refreshed the (stale) write-mode tool
list to cover all 16 write tools.

Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com>
2026-06-22 17:36:01 +02:00

471 lines
15 KiB
Python

"""Pydantic argument models for MCP tools."""
from __future__ import annotations
from typing import Annotated, Literal
from pydantic import (
AfterValidator,
BaseModel,
BeforeValidator,
ConfigDict,
Field,
model_validator,
)
_REPO_PART_PATTERN = r"^[A-Za-z0-9._-]{1,100}$"
def _validate_git_ref(value: str) -> str:
"""Validate a git ref-like value (ref/sha/base/head) against traversal.
Refs that legitimately contain ``/`` (e.g. ``feature/foo``, ``release/1.0``)
are preserved; only traversal and unsafe URL-path characters are rejected.
Args:
value: Candidate ref, sha, base, or head value.
Returns:
The unchanged value when it is safe.
Raises:
ValueError: When the value could escape the intended repository path.
"""
# Security decision: block path traversal and absolute references.
if ".." in value.split("/"):
raise ValueError("ref must not contain '..' path segments")
if value.startswith("/"):
raise ValueError("ref must not start with '/'")
if "\\" in value:
raise ValueError("ref must not contain backslashes")
if "\x00" in value:
raise ValueError("ref must not contain null bytes")
if any(ord(char) < 0x20 or ord(char) == 0x7F for char in value):
raise ValueError("ref must not contain control characters")
if any(char.isspace() for char in value):
raise ValueError("ref must not contain whitespace")
if "?" in value or "#" in value:
raise ValueError("ref must not contain '?' or '#'")
return value
# Reusable ``str`` type for ref-like tool arguments (ref, sha, base, head,
# branch, tag name): once Pydantic has validated the value as ``str``,
# ``_validate_git_ref`` rejects traversal segments and unsafe characters.
GitRef = Annotated[str, AfterValidator(_validate_git_ref)]
def _validate_milestone(value: object) -> int | str:
"""Validate a milestone reference supplied as a numeric id or a title.
An integer is treated as a milestone id (``0`` clears the milestone on
update); a string is treated as a milestone title to resolve. Runs as a
``BeforeValidator`` so ``bool`` (a subclass of ``int`` that Pydantic would
otherwise coerce to ``1``/``0``) is rejected on the raw input.
"""
if isinstance(value, bool):
raise ValueError("milestone must be a milestone id or title")
if isinstance(value, int):
if value < 0:
raise ValueError("milestone id must be >= 0")
return value
if isinstance(value, str):
title = value.strip()
if not title:
raise ValueError("milestone title must not be empty")
if len(title) > 256:
raise ValueError("milestone title must not exceed 256 characters")
return title
raise ValueError("milestone must be a milestone id or title")
# Milestone reference: an int id or a str title. ``_validate_milestone`` runs
# as a BeforeValidator on the raw input so booleans are rejected before
# Pydantic could coerce them to 1/0.
MilestoneRef = Annotated[int | str, BeforeValidator(_validate_milestone)]
class StrictBaseModel(BaseModel):
    """Strict model base that rejects unexpected fields."""

    # extra="forbid": validation fails when the input contains keys that are
    # not declared fields, instead of Pydantic's default of ignoring them.
    model_config = ConfigDict(extra="forbid")
class ListRepositoriesArgs(StrictBaseModel):
    """Arguments for list_repositories tool."""

    # Declares no fields; via StrictBaseModel any supplied key is rejected.
class RepositoryArgs(StrictBaseModel):
    """Common repository locator arguments."""

    # Both segments are required and must match _REPO_PART_PATTERN.
    owner: str = Field(pattern=_REPO_PART_PATTERN)
    repo: str = Field(pattern=_REPO_PART_PATTERN)
class FileTreeArgs(RepositoryArgs):
    """Arguments for get_file_tree."""

    # Ref to read the tree from; defaults to "main".
    ref: GitRef = Field("main", max_length=200, min_length=1)
    recursive: bool = False
class FileContentsArgs(RepositoryArgs):
    """Arguments for get_file_contents."""

    filepath: str = Field(min_length=1, max_length=1024)
    ref: GitRef = Field("main", max_length=200, min_length=1)

    @model_validator(mode="after")
    def validate_filepath(self) -> FileContentsArgs:
        """Reject absolute paths, ``..`` segments, and NUL bytes in ``filepath``.

        Backslashes are treated as separators only for this check; the stored
        ``filepath`` value is returned unchanged.
        """
        as_posix = self.filepath.replace("\\", "/")
        # Security decision: refuse anything absolute or able to climb out of
        # the repository root.
        is_absolute = as_posix.startswith("/")
        climbs_up = ".." in as_posix.split("/")
        if is_absolute or climbs_up:
            raise ValueError("filepath must be a relative path without traversal")
        if "\x00" in as_posix:
            raise ValueError("filepath cannot contain null bytes")
        # NOTE(review): unlike GitRef, '?' and '#' are accepted here; confirm
        # the caller URL-escapes filepath before placing it in a request path.
        return self
class SearchCodeArgs(RepositoryArgs):
    """Arguments for search_code."""

    query: str = Field(min_length=1, max_length=256)
    ref: GitRef = Field("main", max_length=200, min_length=1)
    # Pagination: 1-based page index and page size.
    page: int = Field(1, le=10_000, ge=1)
    limit: int = Field(25, le=100, ge=1)
class ListCommitsArgs(RepositoryArgs):
    """Arguments for list_commits."""

    ref: GitRef = Field("main", max_length=200, min_length=1)
    # Pagination: 1-based page index and page size.
    page: int = Field(1, le=10_000, ge=1)
    limit: int = Field(25, le=100, ge=1)
class CommitDiffArgs(RepositoryArgs):
    """Arguments for get_commit_diff."""

    # Commit id, 7-64 characters, with GitRef safety checks.
    sha: GitRef = Field(max_length=64, min_length=7)
class CompareRefsArgs(RepositoryArgs):
    """Arguments for compare_refs."""

    # Both ends of the comparison are required and validated as GitRef.
    base: GitRef = Field(max_length=200, min_length=1)
    head: GitRef = Field(max_length=200, min_length=1)
class ListIssuesArgs(RepositoryArgs):
    """Arguments for list_issues."""

    state: Literal["open", "closed", "all"] = "open"
    # Pagination: 1-based page index and page size.
    page: int = Field(1, le=10_000, ge=1)
    limit: int = Field(25, le=100, ge=1)
    # Optional label filter, at most 20 entries.
    labels: list[str] = Field(max_length=20, default_factory=list)
class IssueArgs(RepositoryArgs):
    """Arguments for get_issue."""

    # Positive issue number (required).
    issue_number: int = Field(ge=1)
class ListPullRequestsArgs(RepositoryArgs):
    """Arguments for list_pull_requests."""

    state: Literal["open", "closed", "all"] = "open"
    # Pagination: 1-based page index and page size.
    page: int = Field(1, le=10_000, ge=1)
    limit: int = Field(25, le=100, ge=1)
class PullRequestArgs(RepositoryArgs):
    """Arguments for get_pull_request."""

    # Positive pull request number (required).
    pull_number: int = Field(ge=1)
class ListLabelsArgs(RepositoryArgs):
    """Arguments for list_labels."""

    # Pagination: 1-based page index and page size (default 50).
    page: int = Field(1, le=10_000, ge=1)
    limit: int = Field(50, le=100, ge=1)
class ListTagsArgs(RepositoryArgs):
    """Arguments for list_tags."""

    # Pagination: 1-based page index and page size (default 50).
    page: int = Field(1, le=10_000, ge=1)
    limit: int = Field(50, le=100, ge=1)
class ListReleasesArgs(RepositoryArgs):
    """Arguments for list_releases."""

    # Pagination: 1-based page index and page size (default 25).
    page: int = Field(1, le=10_000, ge=1)
    limit: int = Field(25, le=100, ge=1)
class CreateIssueArgs(RepositoryArgs):
    """Arguments for create_issue."""

    title: str = Field(min_length=1, max_length=256)
    body: str = Field("", max_length=20_000)
    labels: list[str] = Field(max_length=20, default_factory=list)
    assignees: list[str] = Field(max_length=20, default_factory=list)
    # Milestone id (int) or title (str), validated by MilestoneRef.
    milestone: MilestoneRef | None = Field(
        None,
        description="Milestone id or title to assign the issue to",
    )
class UpdateIssueArgs(RepositoryArgs):
    """Arguments for update_issue."""

    issue_number: int = Field(ge=1)
    # Mutable fields; None means "not supplied".
    title: str | None = Field(None, max_length=256, min_length=1)
    body: str | None = Field(None, max_length=20_000)
    state: Literal["open", "closed"] | None = None
    milestone: MilestoneRef | None = Field(
        None,
        description="Milestone id or title to assign; 0 clears the milestone",
    )

    @model_validator(mode="after")
    def require_change(self) -> UpdateIssueArgs:
        """Fail unless at least one of title, body, state, or milestone is set.

        ``milestone=0`` counts as a change because it is not None.
        """
        supplied = (self.title, self.body, self.state, self.milestone)
        if all(item is None for item in supplied):
            raise ValueError("At least one of title, body, state, or milestone must be provided")
        return self
class CreateIssueCommentArgs(RepositoryArgs):
    """Arguments for create_issue_comment."""

    issue_number: int = Field(ge=1)
    # Non-empty comment text, up to 10k characters.
    body: str = Field(max_length=10_000, min_length=1)
class CreatePrCommentArgs(RepositoryArgs):
    """Arguments for create_pr_comment."""

    pull_number: int = Field(ge=1)
    # Non-empty comment text, up to 10k characters.
    body: str = Field(max_length=10_000, min_length=1)
class AddLabelsArgs(RepositoryArgs):
    """Arguments for add_labels."""

    issue_number: int = Field(ge=1)
    # Between 1 and 20 labels are required.
    labels: list[str] = Field(max_length=20, min_length=1)
class AssignIssueArgs(RepositoryArgs):
    """Arguments for assign_issue."""

    issue_number: int = Field(ge=1)
    # Between 1 and 20 assignees are required.
    assignees: list[str] = Field(max_length=20, min_length=1)
class CreateLabelArgs(RepositoryArgs):
    """Arguments for create_label."""

    name: str = Field(max_length=50, min_length=1)
    # Gitea requires a hex color; accept it with or without a leading '#'.
    color: str = Field(pattern=r"^#?[0-9A-Fa-f]{6}$")
    description: str = Field("", max_length=1000)
    exclusive: bool = False
class UpdateLabelArgs(RepositoryArgs):
    """Arguments for update_label (located by current name)."""

    # Current label name used to locate the label.
    name: str = Field(max_length=50, min_length=1)
    # Optional replacements; None means "not supplied".
    new_name: str | None = Field(None, max_length=50, min_length=1)
    color: str | None = Field(None, pattern=r"^#?[0-9A-Fa-f]{6}$")
    description: str | None = Field(None, max_length=1000)

    @model_validator(mode="after")
    def require_change(self) -> UpdateLabelArgs:
        """Fail unless new_name, color, or description is supplied."""
        pending = (self.new_name, self.color, self.description)
        if all(item is None for item in pending):
            raise ValueError("At least one of new_name, color, or description must be provided")
        return self
class RemoveLabelsArgs(RepositoryArgs):
    """Arguments for remove_labels."""

    issue_number: int = Field(ge=1)
    # Between 1 and 20 labels are required.
    labels: list[str] = Field(max_length=20, min_length=1)
class CreatePullRequestArgs(RepositoryArgs):
    """Arguments for create_pull_request."""

    title: str = Field(max_length=256, min_length=1)
    # Source and target refs, both validated as GitRef.
    head: GitRef = Field(max_length=200, min_length=1)
    base: GitRef = Field(max_length=200, min_length=1)
    body: str = Field("", max_length=20_000)
class CreateReleaseArgs(RepositoryArgs):
    """Arguments for create_release."""

    tag_name: GitRef = Field(..., min_length=1, max_length=200)
    name: str = Field(default="", max_length=256)
    body: str = Field(default="", max_length=20_000)
    draft: bool = Field(default=False)
    prerelease: bool = Field(default=False)
    # Optional ref-like value (presumably the branch/commit the tag is cut
    # from). Validated as a GitRef, like tag_name, instead of a bare str, so
    # traversal segments and unsafe characters are rejected here too.
    target: GitRef | None = Field(default=None, min_length=1, max_length=200)
class EditReleaseArgs(RepositoryArgs):
    """Arguments for edit_release."""

    release_id: int = Field(ge=1)
    # Optional replacements; None means "not supplied".
    name: str | None = Field(None, max_length=256)
    body: str | None = Field(None, max_length=20_000)
    draft: bool | None = None
    prerelease: bool | None = None

    @model_validator(mode="after")
    def require_change(self) -> EditReleaseArgs:
        """Fail unless name, body, draft, or prerelease is supplied.

        ``False`` for draft/prerelease counts as supplied (it is not None).
        """
        candidates = (self.name, self.body, self.draft, self.prerelease)
        if not any(item is not None for item in candidates):
            raise ValueError("At least one of name, body, draft, or prerelease must be provided")
        return self
class CreateBranchArgs(RepositoryArgs):
    """Arguments for create_branch."""

    new_branch_name: GitRef = Field(..., min_length=1, max_length=200)
    # The source branch is a ref just like new_branch_name, so it gets the
    # same GitRef checks ('..' segments, leading '/', backslashes, control
    # characters, whitespace, '?'/'#') instead of a bare str. None is allowed.
    old_branch_name: GitRef | None = Field(default=None, min_length=1, max_length=200)
class CreateMilestoneArgs(RepositoryArgs):
    """Arguments for create_milestone."""

    title: str = Field(max_length=256, min_length=1)
    description: str = Field("", max_length=10_000)
    # Free-form string; only its length is checked here, not its format.
    due_on: str | None = Field(None, max_length=64)
class EditIssueCommentArgs(RepositoryArgs):
    """Arguments for edit_issue_comment."""

    comment_id: int = Field(ge=1)
    # Replacement text: non-empty, up to 10k characters.
    body: str = Field(max_length=10_000, min_length=1)
class ListPullRequestFilesArgs(RepositoryArgs):
    """Arguments for list_pull_request_files."""

    pull_number: int = Field(ge=1)
    # Pagination: 1-based page index and page size (default 50).
    page: int = Field(1, le=10_000, ge=1)
    limit: int = Field(50, le=100, ge=1)
class ListPullRequestCommitsArgs(RepositoryArgs):
    """Arguments for list_pull_request_commits."""

    pull_number: int = Field(ge=1)
    # Pagination: 1-based page index and page size (default 50).
    page: int = Field(1, le=10_000, ge=1)
    limit: int = Field(50, le=100, ge=1)
class ListIssueCommentsArgs(RepositoryArgs):
    """Arguments for list_issue_comments."""

    issue_number: int = Field(ge=1)
    # Pagination: 1-based page index and page size (default 50).
    page: int = Field(1, le=10_000, ge=1)
    limit: int = Field(50, le=100, ge=1)
class ListBranchesArgs(RepositoryArgs):
    """Arguments for list_branches."""

    # Pagination: 1-based page index and page size (default 50).
    page: int = Field(1, le=10_000, ge=1)
    limit: int = Field(50, le=100, ge=1)
class GetBranchArgs(RepositoryArgs):
    """Arguments for get_branch."""

    # Branch name, required and validated as GitRef.
    branch: GitRef = Field(max_length=200, min_length=1)
class GetReleaseArgs(RepositoryArgs):
    """Arguments for get_release."""

    # Positive release id (required).
    release_id: int = Field(ge=1)
class LatestReleaseArgs(RepositoryArgs):
    """Arguments for get_latest_release."""

    # Only the inherited owner/repo fields; no additional arguments.
class ListMilestonesArgs(RepositoryArgs):
    """Arguments for list_milestones."""

    state: Literal["open", "closed", "all"] = "open"
    # Pagination: 1-based page index and page size (default 50).
    page: int = Field(1, le=10_000, ge=1)
    limit: int = Field(50, le=100, ge=1)
class CommitStatusArgs(RepositoryArgs):
    """Arguments for get_commit_status."""

    # NOTE(review): unlike CommitDiffArgs.sha the minimum length is 1 here,
    # presumably so a branch or tag name can be passed; confirm.
    sha: GitRef = Field(max_length=64, min_length=1)
class ListOrgRepositoriesArgs(StrictBaseModel):
    """Arguments for list_org_repositories."""

    # Organization name, same character rules as owner/repo.
    org: str = Field(pattern=_REPO_PART_PATTERN)
    # Pagination: 1-based page index and page size (default 50).
    page: int = Field(1, le=10_000, ge=1)
    limit: int = Field(50, le=100, ge=1)
class ListOrganizationsArgs(StrictBaseModel):
    """Arguments for list_organizations."""

    # Pagination: 1-based page index and page size (default 50).
    page: int = Field(1, le=10_000, ge=1)
    limit: int = Field(50, le=100, ge=1)
class RepoLanguagesArgs(RepositoryArgs):
    """Arguments for get_repo_languages."""

    # Only the inherited owner/repo fields; no additional arguments.
class RepoTopicsArgs(RepositoryArgs):
    """Arguments for list_repo_topics."""

    # Only the inherited owner/repo fields; no additional arguments.
def extract_repository(arguments: dict[str, object]) -> str | None:
    """Build an ``owner/repo`` string from a raw tool-argument mapping.

    Args:
        arguments: Raw tool arguments.

    Returns:
        ``"owner/repo"`` when both ``owner`` and ``repo`` are non-empty
        strings; otherwise None.
    """
    pair = (arguments.get("owner"), arguments.get("repo"))
    if not all(isinstance(part, str) and part for part in pair):
        return None
    owner, repo = pair
    return f"{owner}/{repo}"
def extract_target_path(arguments: dict[str, object]) -> str | None:
    """Return the raw ``filepath`` argument for policy path checks.

    The value is returned exactly as supplied (no normalization).

    Args:
        arguments: Raw tool arguments.

    Returns:
        The ``filepath`` value when it is a non-empty string; otherwise None.
    """
    candidate = arguments.get("filepath")
    return candidate if isinstance(candidate, str) and candidate else None