From 2d95e890352a59e9c3a93bd2cb59c865f453e73b Mon Sep 17 00:00:00 2001 From: latte Date: Sun, 14 Jun 2026 15:57:52 +0200 Subject: [PATCH] fix: prevent path traversal via Gitea ref/sha/base/head parameters The ref-like tool arguments (ref, sha, base, head) were only length-limited and were interpolated unencoded into Gitea API URL paths (get_tree, get_commit_diff, compare_refs). Because httpx collapses ".." path segments (RFC 3986), a crafted value such as "../../../../owner/repo/contents/secret" escaped the declared owner/repo prefix. In service-PAT mode this allowed a user authorized on one repository to read arbitrary repositories the service token could reach, and in OAuth mode it bypassed the policy engine's per-repository rules (which never see ref values). Two defense layers: - arguments.py: add _validate_git_ref / GitRef that rejects ".." path segments, leading "/", backslashes, null bytes, control chars, whitespace, and "?"/"#", while preserving legitimate slash refs (feature/foo, v1.2.3). This is what actually closes the traversal. - gitea_client.py: defense-in-depth urllib.parse.quote() on owner/repo (safe="") and ref/sha/base/head/filepath (safe="/") in every repo URL builder, mirroring the existing pattern in server.py. Tests: negative cases for traversal/unsafe chars across all four fields, positive cases for slash-containing refs, length-bound regression, and a URL-layer confinement check. Full suite green (176 passed), coverage 85.64%. 
Co-Authored-By: Claude Opus 4.8 --- src/aegis_gitea_mcp/gitea_client.py | 56 +++++++++------ src/aegis_gitea_mcp/tools/arguments.py | 54 +++++++++++--- tests/test_gitea_client.py | 99 +++++++++++++++++++++++++- 3 files changed, 179 insertions(+), 30 deletions(-) diff --git a/src/aegis_gitea_mcp/gitea_client.py b/src/aegis_gitea_mcp/gitea_client.py index 0536047..50677e6 100644 --- a/src/aegis_gitea_mcp/gitea_client.py +++ b/src/aegis_gitea_mcp/gitea_client.py @@ -3,6 +3,7 @@ from __future__ import annotations from typing import Any +from urllib.parse import quote from httpx import AsyncClient, Response @@ -175,6 +176,8 @@ class GiteaClient: async def get_repository(self, owner: str, repo: str) -> dict[str, Any]: """Get repository metadata.""" repo_id = f"{owner}/{repo}" + enc_owner = quote(owner, safe="") + enc_repo = quote(repo, safe="") correlation_id = self.audit.log_tool_invocation( tool_name="get_repository", repository=repo_id, @@ -183,7 +186,7 @@ class GiteaClient: try: result = await self._request( "GET", - f"/api/v1/repos/{owner}/{repo}", + f"/api/v1/repos/{enc_owner}/{enc_repo}", correlation_id=correlation_id, ) self.audit.log_tool_invocation( @@ -212,6 +215,9 @@ class GiteaClient: ) -> dict[str, Any]: """Get file contents from a repository.""" repo_id = f"{owner}/{repo}" + enc_owner = quote(owner, safe="") + enc_repo = quote(repo, safe="") + enc_filepath = quote(filepath, safe="/") correlation_id = self.audit.log_tool_invocation( tool_name="get_file_contents", repository=repo_id, @@ -222,7 +228,7 @@ class GiteaClient: try: result = await self._request( "GET", - f"/api/v1/repos/{owner}/{repo}/contents/{filepath}", + f"/api/v1/repos/{enc_owner}/{enc_repo}/contents/{enc_filepath}", params={"ref": ref}, correlation_id=correlation_id, ) @@ -278,6 +284,9 @@ class GiteaClient: ) -> dict[str, Any]: """Get repository tree at given ref.""" repo_id = f"{owner}/{repo}" + enc_owner = quote(owner, safe="") + enc_repo = quote(repo, safe="") + enc_ref = quote(ref, 
safe="/") correlation_id = self.audit.log_tool_invocation( tool_name="get_tree", repository=repo_id, @@ -287,7 +296,7 @@ class GiteaClient: try: result = await self._request( "GET", - f"/api/v1/repos/{owner}/{repo}/git/trees/{ref}", + f"/api/v1/repos/{enc_owner}/{enc_repo}/git/trees/{enc_ref}", params={"recursive": str(recursive).lower()}, correlation_id=correlation_id, ) @@ -334,7 +343,7 @@ class GiteaClient: try: result = await self._request( "GET", - f"/api/v1/repos/{owner}/{repo}/search", + f"/api/v1/repos/{quote(owner, safe='')}/{quote(repo, safe='')}/search", params={"q": query, "page": page, "limit": limit, "ref": ref}, correlation_id=correlation_id, ) @@ -367,7 +376,7 @@ class GiteaClient: """List commits for a repository ref.""" result = await self._request( "GET", - f"/api/v1/repos/{owner}/{repo}/commits", + f"/api/v1/repos/{quote(owner, safe='')}/{quote(repo, safe='')}/commits", params={"sha": ref, "page": page, "limit": limit}, correlation_id=str( self.audit.log_tool_invocation(tool_name="list_commits", result_status="pending") @@ -377,9 +386,12 @@ class GiteaClient: async def get_commit_diff(self, owner: str, repo: str, sha: str) -> dict[str, Any]: """Get detailed commit including changed files and patch metadata.""" + enc_owner = quote(owner, safe="") + enc_repo = quote(repo, safe="") + enc_sha = quote(sha, safe="/") result = await self._request( "GET", - f"/api/v1/repos/{owner}/{repo}/git/commits/{sha}", + f"/api/v1/repos/{enc_owner}/{enc_repo}/git/commits/{enc_sha}", correlation_id=str( self.audit.log_tool_invocation(tool_name="get_commit_diff", result_status="pending") ), @@ -388,9 +400,13 @@ class GiteaClient: async def compare_refs(self, owner: str, repo: str, base: str, head: str) -> dict[str, Any]: """Compare two refs and return commit/file deltas.""" + enc_owner = quote(owner, safe="") + enc_repo = quote(repo, safe="") + enc_base = quote(base, safe="/") + enc_head = quote(head, safe="/") result = await self._request( "GET", - 
f"/api/v1/repos/{owner}/{repo}/compare/{base}...{head}", + f"/api/v1/repos/{enc_owner}/{enc_repo}/compare/{enc_base}...{enc_head}", correlation_id=str( self.audit.log_tool_invocation(tool_name="compare_refs", result_status="pending") ), @@ -414,7 +430,7 @@ class GiteaClient: result = await self._request( "GET", - f"/api/v1/repos/{owner}/{repo}/issues", + f"/api/v1/repos/{quote(owner, safe='')}/{quote(repo, safe='')}/issues", params=params, correlation_id=str( self.audit.log_tool_invocation(tool_name="list_issues", result_status="pending") @@ -426,7 +442,7 @@ class GiteaClient: """Get issue details.""" result = await self._request( "GET", - f"/api/v1/repos/{owner}/{repo}/issues/{index}", + f"/api/v1/repos/{quote(owner, safe='')}/{quote(repo, safe='')}/issues/{index}", correlation_id=str( self.audit.log_tool_invocation(tool_name="get_issue", result_status="pending") ), @@ -445,7 +461,7 @@ class GiteaClient: """List pull requests for repository.""" result = await self._request( "GET", - f"/api/v1/repos/{owner}/{repo}/pulls", + f"/api/v1/repos/{quote(owner, safe='')}/{quote(repo, safe='')}/pulls", params={"state": state, "page": page, "limit": limit}, correlation_id=str( self.audit.log_tool_invocation( @@ -459,7 +475,7 @@ class GiteaClient: """Get a single pull request.""" result = await self._request( "GET", - f"/api/v1/repos/{owner}/{repo}/pulls/{index}", + f"/api/v1/repos/{quote(owner, safe='')}/{quote(repo, safe='')}/pulls/{index}", correlation_id=str( self.audit.log_tool_invocation( tool_name="get_pull_request", result_status="pending" @@ -474,7 +490,7 @@ class GiteaClient: """List repository labels.""" result = await self._request( "GET", - f"/api/v1/repos/{owner}/{repo}/labels", + f"/api/v1/repos/{quote(owner, safe='')}/{quote(repo, safe='')}/labels", params={"page": page, "limit": limit}, correlation_id=str( self.audit.log_tool_invocation(tool_name="list_labels", result_status="pending") @@ -488,7 +504,7 @@ class GiteaClient: """List repository tags.""" result 
= await self._request( "GET", - f"/api/v1/repos/{owner}/{repo}/tags", + f"/api/v1/repos/{quote(owner, safe='')}/{quote(repo, safe='')}/tags", params={"page": page, "limit": limit}, correlation_id=str( self.audit.log_tool_invocation(tool_name="list_tags", result_status="pending") @@ -507,7 +523,7 @@ class GiteaClient: """List repository releases.""" result = await self._request( "GET", - f"/api/v1/repos/{owner}/{repo}/releases", + f"/api/v1/repos/{quote(owner, safe='')}/{quote(repo, safe='')}/releases", params={"page": page, "limit": limit}, correlation_id=str( self.audit.log_tool_invocation(tool_name="list_releases", result_status="pending") @@ -533,7 +549,7 @@ class GiteaClient: payload["assignees"] = assignees result = await self._request( "POST", - f"/api/v1/repos/{owner}/{repo}/issues", + f"/api/v1/repos/{quote(owner, safe='')}/{quote(repo, safe='')}/issues", json_body=payload, correlation_id=str( self.audit.log_tool_invocation(tool_name="create_issue", result_status="pending") @@ -561,7 +577,7 @@ class GiteaClient: payload["state"] = state result = await self._request( "PATCH", - f"/api/v1/repos/{owner}/{repo}/issues/{index}", + f"/api/v1/repos/{quote(owner, safe='')}/{quote(repo, safe='')}/issues/{index}", json_body=payload, correlation_id=str( self.audit.log_tool_invocation(tool_name="update_issue", result_status="pending") @@ -575,7 +591,7 @@ class GiteaClient: """Create a comment on issue (and PR discussion if issue index refers to PR).""" result = await self._request( "POST", - f"/api/v1/repos/{owner}/{repo}/issues/{index}/comments", + f"/api/v1/repos/{quote(owner, safe='')}/{quote(repo, safe='')}/issues/{index}/comments", json_body={"body": body}, correlation_id=str( self.audit.log_tool_invocation( @@ -591,7 +607,7 @@ class GiteaClient: """Create PR discussion comment.""" result = await self._request( "POST", - f"/api/v1/repos/{owner}/{repo}/issues/{index}/comments", + f"/api/v1/repos/{quote(owner, safe='')}/{quote(repo, 
safe='')}/issues/{index}/comments", json_body={"body": body}, correlation_id=str( self.audit.log_tool_invocation( @@ -611,7 +627,7 @@ class GiteaClient: """Add labels to issue/PR.""" result = await self._request( "POST", - f"/api/v1/repos/{owner}/{repo}/issues/{index}/labels", + f"/api/v1/repos/{quote(owner, safe='')}/{quote(repo, safe='')}/issues/{index}/labels", json_body={"labels": labels}, correlation_id=str( self.audit.log_tool_invocation(tool_name="add_labels", result_status="pending") @@ -629,7 +645,7 @@ class GiteaClient: """Assign users to issue/PR.""" result = await self._request( "POST", - f"/api/v1/repos/{owner}/{repo}/issues/{index}/assignees", + f"/api/v1/repos/{quote(owner, safe='')}/{quote(repo, safe='')}/issues/{index}/assignees", json_body={"assignees": assignees}, correlation_id=str( self.audit.log_tool_invocation(tool_name="assign_issue", result_status="pending") diff --git a/src/aegis_gitea_mcp/tools/arguments.py b/src/aegis_gitea_mcp/tools/arguments.py index e03ffe6..ec573aa 100644 --- a/src/aegis_gitea_mcp/tools/arguments.py +++ b/src/aegis_gitea_mcp/tools/arguments.py @@ -2,13 +2,49 @@ from __future__ import annotations -from typing import Literal +from typing import Annotated, Literal -from pydantic import BaseModel, ConfigDict, Field, model_validator +from pydantic import AfterValidator, BaseModel, ConfigDict, Field, model_validator _REPO_PART_PATTERN = r"^[A-Za-z0-9._-]{1,100}$" +def _validate_git_ref(value: str) -> str: + """Validate a git ref-like value (ref/sha/base/head) against traversal. + + Refs that legitimately contain ``/`` (e.g. ``feature/foo``, ``release/1.0``) + are preserved; only traversal and unsafe URL-path characters are rejected. + + Args: + value: Candidate ref, sha, base, or head value. + + Returns: + The unchanged value when it is safe. + + Raises: + ValueError: When the value could escape the intended repository path. + """ + # Security decision: block path traversal and absolute references. + if ".." 
in value.split("/"): + raise ValueError("ref must not contain '..' path segments") + if value.startswith("/"): + raise ValueError("ref must not start with '/'") + if "\\" in value: + raise ValueError("ref must not contain backslashes") + if "\x00" in value: + raise ValueError("ref must not contain null bytes") + if any(ord(char) < 0x20 or ord(char) == 0x7F for char in value): + raise ValueError("ref must not contain control characters") + if any(char.isspace() for char in value): + raise ValueError("ref must not contain whitespace") + if "?" in value or "#" in value: + raise ValueError("ref must not contain '?' or '#'") + return value + + +GitRef = Annotated[str, AfterValidator(_validate_git_ref)] + + class StrictBaseModel(BaseModel): """Strict model base that rejects unexpected fields.""" @@ -29,7 +65,7 @@ class RepositoryArgs(StrictBaseModel): class FileTreeArgs(RepositoryArgs): """Arguments for get_file_tree.""" - ref: str = Field(default="main", min_length=1, max_length=200) + ref: GitRef = Field(default="main", min_length=1, max_length=200) recursive: bool = Field(default=False) @@ -37,7 +73,7 @@ class FileContentsArgs(RepositoryArgs): """Arguments for get_file_contents.""" filepath: str = Field(..., min_length=1, max_length=1024) - ref: str = Field(default="main", min_length=1, max_length=200) + ref: GitRef = Field(default="main", min_length=1, max_length=200) @model_validator(mode="after") def validate_filepath(self) -> FileContentsArgs: @@ -55,7 +91,7 @@ class SearchCodeArgs(RepositoryArgs): """Arguments for search_code.""" query: str = Field(..., min_length=1, max_length=256) - ref: str = Field(default="main", min_length=1, max_length=200) + ref: GitRef = Field(default="main", min_length=1, max_length=200) page: int = Field(default=1, ge=1, le=10_000) limit: int = Field(default=25, ge=1, le=100) @@ -63,7 +99,7 @@ class SearchCodeArgs(RepositoryArgs): class ListCommitsArgs(RepositoryArgs): """Arguments for list_commits.""" - ref: str = Field(default="main", 
min_length=1, max_length=200) + ref: GitRef = Field(default="main", min_length=1, max_length=200) page: int = Field(default=1, ge=1, le=10_000) limit: int = Field(default=25, ge=1, le=100) @@ -71,14 +107,14 @@ class ListCommitsArgs(RepositoryArgs): class CommitDiffArgs(RepositoryArgs): """Arguments for get_commit_diff.""" - sha: str = Field(..., min_length=7, max_length=64) + sha: GitRef = Field(..., min_length=7, max_length=64) class CompareRefsArgs(RepositoryArgs): """Arguments for compare_refs.""" - base: str = Field(..., min_length=1, max_length=200) - head: str = Field(..., min_length=1, max_length=200) + base: GitRef = Field(..., min_length=1, max_length=200) + head: GitRef = Field(..., min_length=1, max_length=200) class ListIssuesArgs(RepositoryArgs): diff --git a/tests/test_gitea_client.py b/tests/test_gitea_client.py index 53750fb..e2fddd9 100644 --- a/tests/test_gitea_client.py +++ b/tests/test_gitea_client.py @@ -5,7 +5,8 @@ from __future__ import annotations from unittest.mock import AsyncMock, patch import pytest -from httpx import Request, Response +from httpx import Client, Request, Response +from pydantic import ValidationError from aegis_gitea_mcp.config import reset_settings from aegis_gitea_mcp.gitea_client import ( @@ -15,6 +16,11 @@ from aegis_gitea_mcp.gitea_client import ( GiteaError, GiteaNotFoundError, ) +from aegis_gitea_mcp.tools.arguments import ( + CommitDiffArgs, + CompareRefsArgs, + FileTreeArgs, +) @pytest.fixture(autouse=True) @@ -166,3 +172,94 @@ async def test_get_file_contents_blocks_oversized_payload(monkeypatch: pytest.Mo with pytest.raises(GiteaError, match="exceeds limit"): await client.get_file_contents("acme", "demo", "big.bin") + + +_MALICIOUS_REFS = [ + "../../../x/y", + "..", + "/etc/passwd", + "a\x00b", + "a?b", + "a#b", +] + + +@pytest.mark.parametrize("value", _MALICIOUS_REFS) +def test_file_tree_args_reject_traversal_ref(value: str) -> None: + """Layer 1: FileTreeArgs.ref rejects traversal/unsafe values.""" + with 
pytest.raises(ValidationError): + FileTreeArgs(owner="o", repo="r", ref=value) + + +@pytest.mark.parametrize("value", _MALICIOUS_REFS) +def test_commit_diff_args_reject_traversal_sha(value: str) -> None: + """Layer 1: CommitDiffArgs.sha rejects traversal/unsafe values.""" + # Values under the 7-char min_length may fail on length alone rather than on the ref check. + with pytest.raises(ValidationError): + CommitDiffArgs(owner="o", repo="r", sha=value) + + +@pytest.mark.parametrize("value", _MALICIOUS_REFS) +def test_compare_refs_args_reject_traversal_base(value: str) -> None: + """Layer 1: CompareRefsArgs.base rejects traversal/unsafe values.""" + with pytest.raises(ValidationError): + CompareRefsArgs(owner="o", repo="r", base=value, head="main") + + +@pytest.mark.parametrize("value", _MALICIOUS_REFS) +def test_compare_refs_args_reject_traversal_head(value: str) -> None: + """Layer 1: CompareRefsArgs.head rejects traversal/unsafe values.""" + with pytest.raises(ValidationError): + CompareRefsArgs(owner="o", repo="r", base="main", head=value) + + +def test_git_refs_allow_slash_containing_refs() -> None: + """Legitimate refs that contain '/' validate successfully.""" + tree = FileTreeArgs(owner="o", repo="r", ref="feature/foo") + assert tree.ref == "feature/foo" + + compare = CompareRefsArgs(owner="o", repo="r", base="release/1.0", head="main") + assert compare.base == "release/1.0" + assert compare.head == "main" + + +def test_git_ref_length_bounds_still_enforced() -> None: + """Field min/max length still applies alongside the AfterValidator.""" + with pytest.raises(ValidationError): + FileTreeArgs(owner="o", repo="r", ref="") + with pytest.raises(ValidationError): + FileTreeArgs(owner="o", repo="r", ref="a" * 201) + with pytest.raises(ValidationError): + # sha below 7-char minimum + CommitDiffArgs(owner="o", repo="r", sha="abc") + + +def test_layer2_encoding_confines_unsafe_chars_to_path_segment() -> None: + """Layer 2 defense-in-depth: quote(..., safe='/') confines unsafe
chars. + + A sha containing ``?``, ``#``, whitespace, or a backslash that hypothetically + bypassed Layer 1 is percent-encoded so it cannot split off a query string, + fragment, or extra path component -- it stays inside the single ref segment + under the declared ``owner/repo``. (Note: ``..`` collapse is closed by Layer 1 + validation, since ``quote`` intentionally leaves ``.`` and ``/`` literal to + preserve refs like ``feature/foo``.) + """ + from urllib.parse import quote + + owner, repo = "alice", "repoA" + prefix = f"/api/v1/repos/{owner}/{repo}/git/commits/" + + for malicious_sha in ["abc?injected=1", "abc#frag", "abc def", "abc\\..\\evil"]: + endpoint = ( + f"/api/v1/repos/{quote(owner, safe='')}/{quote(repo, safe='')}" + f"/git/commits/{quote(malicious_sha, safe='/')}" + ) + with Client(base_url="https://gitea.example.com") as http_client: + request = http_client.build_request("GET", endpoint) + + # Stays scoped to declared repo, no query/fragment broke out. + assert request.url.path.startswith(prefix) + assert request.url.query == b"" + assert request.url.fragment == "" + # Exactly one ref segment after the prefix (no path splitting). + assert "/" not in request.url.path[len(prefix) :]