Merge pull request 'fix: prevent path traversal via Gitea ref/sha/base/head parameters' (#18) from fix/gitea-ref-path-traversal into main
docker / test (push) Successful in 30s
docker / lint (push) Successful in 37s
test / test (push) Successful in 33s
docker / docker-test (push) Successful in 13s
lint / lint (push) Successful in 37s
docker / docker-publish (push) Successful in 6s

Reviewed-on: #18
This commit was merged in pull request #18.
This commit is contained in:
2026-06-14 14:01:58 +00:00
3 changed files with 179 additions and 30 deletions
+36 -20
View File
@@ -3,6 +3,7 @@
from __future__ import annotations from __future__ import annotations
from typing import Any from typing import Any
from urllib.parse import quote
from httpx import AsyncClient, Response from httpx import AsyncClient, Response
@@ -175,6 +176,8 @@ class GiteaClient:
async def get_repository(self, owner: str, repo: str) -> dict[str, Any]: async def get_repository(self, owner: str, repo: str) -> dict[str, Any]:
"""Get repository metadata.""" """Get repository metadata."""
repo_id = f"{owner}/{repo}" repo_id = f"{owner}/{repo}"
enc_owner = quote(owner, safe="")
enc_repo = quote(repo, safe="")
correlation_id = self.audit.log_tool_invocation( correlation_id = self.audit.log_tool_invocation(
tool_name="get_repository", tool_name="get_repository",
repository=repo_id, repository=repo_id,
@@ -183,7 +186,7 @@ class GiteaClient:
try: try:
result = await self._request( result = await self._request(
"GET", "GET",
f"/api/v1/repos/{owner}/{repo}", f"/api/v1/repos/{enc_owner}/{enc_repo}",
correlation_id=correlation_id, correlation_id=correlation_id,
) )
self.audit.log_tool_invocation( self.audit.log_tool_invocation(
@@ -212,6 +215,9 @@ class GiteaClient:
) -> dict[str, Any]: ) -> dict[str, Any]:
"""Get file contents from a repository.""" """Get file contents from a repository."""
repo_id = f"{owner}/{repo}" repo_id = f"{owner}/{repo}"
enc_owner = quote(owner, safe="")
enc_repo = quote(repo, safe="")
enc_filepath = quote(filepath, safe="/")
correlation_id = self.audit.log_tool_invocation( correlation_id = self.audit.log_tool_invocation(
tool_name="get_file_contents", tool_name="get_file_contents",
repository=repo_id, repository=repo_id,
@@ -222,7 +228,7 @@ class GiteaClient:
try: try:
result = await self._request( result = await self._request(
"GET", "GET",
f"/api/v1/repos/{owner}/{repo}/contents/{filepath}", f"/api/v1/repos/{enc_owner}/{enc_repo}/contents/{enc_filepath}",
params={"ref": ref}, params={"ref": ref},
correlation_id=correlation_id, correlation_id=correlation_id,
) )
@@ -278,6 +284,9 @@ class GiteaClient:
) -> dict[str, Any]: ) -> dict[str, Any]:
"""Get repository tree at given ref.""" """Get repository tree at given ref."""
repo_id = f"{owner}/{repo}" repo_id = f"{owner}/{repo}"
enc_owner = quote(owner, safe="")
enc_repo = quote(repo, safe="")
enc_ref = quote(ref, safe="/")
correlation_id = self.audit.log_tool_invocation( correlation_id = self.audit.log_tool_invocation(
tool_name="get_tree", tool_name="get_tree",
repository=repo_id, repository=repo_id,
@@ -287,7 +296,7 @@ class GiteaClient:
try: try:
result = await self._request( result = await self._request(
"GET", "GET",
f"/api/v1/repos/{owner}/{repo}/git/trees/{ref}", f"/api/v1/repos/{enc_owner}/{enc_repo}/git/trees/{enc_ref}",
params={"recursive": str(recursive).lower()}, params={"recursive": str(recursive).lower()},
correlation_id=correlation_id, correlation_id=correlation_id,
) )
@@ -334,7 +343,7 @@ class GiteaClient:
try: try:
result = await self._request( result = await self._request(
"GET", "GET",
f"/api/v1/repos/{owner}/{repo}/search", f"/api/v1/repos/{quote(owner, safe='')}/{quote(repo, safe='')}/search",
params={"q": query, "page": page, "limit": limit, "ref": ref}, params={"q": query, "page": page, "limit": limit, "ref": ref},
correlation_id=correlation_id, correlation_id=correlation_id,
) )
@@ -367,7 +376,7 @@ class GiteaClient:
"""List commits for a repository ref.""" """List commits for a repository ref."""
result = await self._request( result = await self._request(
"GET", "GET",
f"/api/v1/repos/{owner}/{repo}/commits", f"/api/v1/repos/{quote(owner, safe='')}/{quote(repo, safe='')}/commits",
params={"sha": ref, "page": page, "limit": limit}, params={"sha": ref, "page": page, "limit": limit},
correlation_id=str( correlation_id=str(
self.audit.log_tool_invocation(tool_name="list_commits", result_status="pending") self.audit.log_tool_invocation(tool_name="list_commits", result_status="pending")
@@ -377,9 +386,12 @@ class GiteaClient:
async def get_commit_diff(self, owner: str, repo: str, sha: str) -> dict[str, Any]: async def get_commit_diff(self, owner: str, repo: str, sha: str) -> dict[str, Any]:
"""Get detailed commit including changed files and patch metadata.""" """Get detailed commit including changed files and patch metadata."""
enc_owner = quote(owner, safe="")
enc_repo = quote(repo, safe="")
enc_sha = quote(sha, safe="/")
result = await self._request( result = await self._request(
"GET", "GET",
f"/api/v1/repos/{owner}/{repo}/git/commits/{sha}", f"/api/v1/repos/{enc_owner}/{enc_repo}/git/commits/{enc_sha}",
correlation_id=str( correlation_id=str(
self.audit.log_tool_invocation(tool_name="get_commit_diff", result_status="pending") self.audit.log_tool_invocation(tool_name="get_commit_diff", result_status="pending")
), ),
@@ -388,9 +400,13 @@ class GiteaClient:
async def compare_refs(self, owner: str, repo: str, base: str, head: str) -> dict[str, Any]: async def compare_refs(self, owner: str, repo: str, base: str, head: str) -> dict[str, Any]:
"""Compare two refs and return commit/file deltas.""" """Compare two refs and return commit/file deltas."""
enc_owner = quote(owner, safe="")
enc_repo = quote(repo, safe="")
enc_base = quote(base, safe="/")
enc_head = quote(head, safe="/")
result = await self._request( result = await self._request(
"GET", "GET",
f"/api/v1/repos/{owner}/{repo}/compare/{base}...{head}", f"/api/v1/repos/{enc_owner}/{enc_repo}/compare/{enc_base}...{enc_head}",
correlation_id=str( correlation_id=str(
self.audit.log_tool_invocation(tool_name="compare_refs", result_status="pending") self.audit.log_tool_invocation(tool_name="compare_refs", result_status="pending")
), ),
@@ -414,7 +430,7 @@ class GiteaClient:
result = await self._request( result = await self._request(
"GET", "GET",
f"/api/v1/repos/{owner}/{repo}/issues", f"/api/v1/repos/{quote(owner, safe='')}/{quote(repo, safe='')}/issues",
params=params, params=params,
correlation_id=str( correlation_id=str(
self.audit.log_tool_invocation(tool_name="list_issues", result_status="pending") self.audit.log_tool_invocation(tool_name="list_issues", result_status="pending")
@@ -426,7 +442,7 @@ class GiteaClient:
"""Get issue details.""" """Get issue details."""
result = await self._request( result = await self._request(
"GET", "GET",
f"/api/v1/repos/{owner}/{repo}/issues/{index}", f"/api/v1/repos/{quote(owner, safe='')}/{quote(repo, safe='')}/issues/{index}",
correlation_id=str( correlation_id=str(
self.audit.log_tool_invocation(tool_name="get_issue", result_status="pending") self.audit.log_tool_invocation(tool_name="get_issue", result_status="pending")
), ),
@@ -445,7 +461,7 @@ class GiteaClient:
"""List pull requests for repository.""" """List pull requests for repository."""
result = await self._request( result = await self._request(
"GET", "GET",
f"/api/v1/repos/{owner}/{repo}/pulls", f"/api/v1/repos/{quote(owner, safe='')}/{quote(repo, safe='')}/pulls",
params={"state": state, "page": page, "limit": limit}, params={"state": state, "page": page, "limit": limit},
correlation_id=str( correlation_id=str(
self.audit.log_tool_invocation( self.audit.log_tool_invocation(
@@ -459,7 +475,7 @@ class GiteaClient:
"""Get a single pull request.""" """Get a single pull request."""
result = await self._request( result = await self._request(
"GET", "GET",
f"/api/v1/repos/{owner}/{repo}/pulls/{index}", f"/api/v1/repos/{quote(owner, safe='')}/{quote(repo, safe='')}/pulls/{index}",
correlation_id=str( correlation_id=str(
self.audit.log_tool_invocation( self.audit.log_tool_invocation(
tool_name="get_pull_request", result_status="pending" tool_name="get_pull_request", result_status="pending"
@@ -474,7 +490,7 @@ class GiteaClient:
"""List repository labels.""" """List repository labels."""
result = await self._request( result = await self._request(
"GET", "GET",
f"/api/v1/repos/{owner}/{repo}/labels", f"/api/v1/repos/{quote(owner, safe='')}/{quote(repo, safe='')}/labels",
params={"page": page, "limit": limit}, params={"page": page, "limit": limit},
correlation_id=str( correlation_id=str(
self.audit.log_tool_invocation(tool_name="list_labels", result_status="pending") self.audit.log_tool_invocation(tool_name="list_labels", result_status="pending")
@@ -488,7 +504,7 @@ class GiteaClient:
"""List repository tags.""" """List repository tags."""
result = await self._request( result = await self._request(
"GET", "GET",
f"/api/v1/repos/{owner}/{repo}/tags", f"/api/v1/repos/{quote(owner, safe='')}/{quote(repo, safe='')}/tags",
params={"page": page, "limit": limit}, params={"page": page, "limit": limit},
correlation_id=str( correlation_id=str(
self.audit.log_tool_invocation(tool_name="list_tags", result_status="pending") self.audit.log_tool_invocation(tool_name="list_tags", result_status="pending")
@@ -507,7 +523,7 @@ class GiteaClient:
"""List repository releases.""" """List repository releases."""
result = await self._request( result = await self._request(
"GET", "GET",
f"/api/v1/repos/{owner}/{repo}/releases", f"/api/v1/repos/{quote(owner, safe='')}/{quote(repo, safe='')}/releases",
params={"page": page, "limit": limit}, params={"page": page, "limit": limit},
correlation_id=str( correlation_id=str(
self.audit.log_tool_invocation(tool_name="list_releases", result_status="pending") self.audit.log_tool_invocation(tool_name="list_releases", result_status="pending")
@@ -533,7 +549,7 @@ class GiteaClient:
payload["assignees"] = assignees payload["assignees"] = assignees
result = await self._request( result = await self._request(
"POST", "POST",
f"/api/v1/repos/{owner}/{repo}/issues", f"/api/v1/repos/{quote(owner, safe='')}/{quote(repo, safe='')}/issues",
json_body=payload, json_body=payload,
correlation_id=str( correlation_id=str(
self.audit.log_tool_invocation(tool_name="create_issue", result_status="pending") self.audit.log_tool_invocation(tool_name="create_issue", result_status="pending")
@@ -561,7 +577,7 @@ class GiteaClient:
payload["state"] = state payload["state"] = state
result = await self._request( result = await self._request(
"PATCH", "PATCH",
f"/api/v1/repos/{owner}/{repo}/issues/{index}", f"/api/v1/repos/{quote(owner, safe='')}/{quote(repo, safe='')}/issues/{index}",
json_body=payload, json_body=payload,
correlation_id=str( correlation_id=str(
self.audit.log_tool_invocation(tool_name="update_issue", result_status="pending") self.audit.log_tool_invocation(tool_name="update_issue", result_status="pending")
@@ -575,7 +591,7 @@ class GiteaClient:
"""Create a comment on issue (and PR discussion if issue index refers to PR).""" """Create a comment on issue (and PR discussion if issue index refers to PR)."""
result = await self._request( result = await self._request(
"POST", "POST",
f"/api/v1/repos/{owner}/{repo}/issues/{index}/comments", f"/api/v1/repos/{quote(owner, safe='')}/{quote(repo, safe='')}/issues/{index}/comments",
json_body={"body": body}, json_body={"body": body},
correlation_id=str( correlation_id=str(
self.audit.log_tool_invocation( self.audit.log_tool_invocation(
@@ -591,7 +607,7 @@ class GiteaClient:
"""Create PR discussion comment.""" """Create PR discussion comment."""
result = await self._request( result = await self._request(
"POST", "POST",
f"/api/v1/repos/{owner}/{repo}/issues/{index}/comments", f"/api/v1/repos/{quote(owner, safe='')}/{quote(repo, safe='')}/issues/{index}/comments",
json_body={"body": body}, json_body={"body": body},
correlation_id=str( correlation_id=str(
self.audit.log_tool_invocation( self.audit.log_tool_invocation(
@@ -611,7 +627,7 @@ class GiteaClient:
"""Add labels to issue/PR.""" """Add labels to issue/PR."""
result = await self._request( result = await self._request(
"POST", "POST",
f"/api/v1/repos/{owner}/{repo}/issues/{index}/labels", f"/api/v1/repos/{quote(owner, safe='')}/{quote(repo, safe='')}/issues/{index}/labels",
json_body={"labels": labels}, json_body={"labels": labels},
correlation_id=str( correlation_id=str(
self.audit.log_tool_invocation(tool_name="add_labels", result_status="pending") self.audit.log_tool_invocation(tool_name="add_labels", result_status="pending")
@@ -629,7 +645,7 @@ class GiteaClient:
"""Assign users to issue/PR.""" """Assign users to issue/PR."""
result = await self._request( result = await self._request(
"POST", "POST",
f"/api/v1/repos/{owner}/{repo}/issues/{index}/assignees", f"/api/v1/repos/{quote(owner, safe='')}/{quote(repo, safe='')}/issues/{index}/assignees",
json_body={"assignees": assignees}, json_body={"assignees": assignees},
correlation_id=str( correlation_id=str(
self.audit.log_tool_invocation(tool_name="assign_issue", result_status="pending") self.audit.log_tool_invocation(tool_name="assign_issue", result_status="pending")
+45 -9
View File
@@ -2,13 +2,49 @@
from __future__ import annotations from __future__ import annotations
from typing import Literal from typing import Annotated, Literal
from pydantic import BaseModel, ConfigDict, Field, model_validator from pydantic import AfterValidator, BaseModel, ConfigDict, Field, model_validator
_REPO_PART_PATTERN = r"^[A-Za-z0-9._-]{1,100}$" _REPO_PART_PATTERN = r"^[A-Za-z0-9._-]{1,100}$"
def _validate_git_ref(value: str) -> str:
"""Validate a git ref-like value (ref/sha/base/head) against traversal.
Refs that legitimately contain ``/`` (e.g. ``feature/foo``, ``release/1.0``)
are preserved; only traversal and unsafe URL-path characters are rejected.
Args:
value: Candidate ref, sha, base, or head value.
Returns:
The unchanged value when it is safe.
Raises:
ValueError: When the value could escape the intended repository path.
"""
# Security decision: block path traversal and absolute references.
if ".." in value.split("/"):
raise ValueError("ref must not contain '..' path segments")
if value.startswith("/"):
raise ValueError("ref must not start with '/'")
if "\\" in value:
raise ValueError("ref must not contain backslashes")
if "\x00" in value:
raise ValueError("ref must not contain null bytes")
if any(ord(char) < 0x20 or ord(char) == 0x7F for char in value):
raise ValueError("ref must not contain control characters")
if any(char.isspace() for char in value):
raise ValueError("ref must not contain whitespace")
if "?" in value or "#" in value:
raise ValueError("ref must not contain '?' or '#'")
return value
# Reusable string type for ref/sha/base/head fields: pydantic calls
# _validate_git_ref on the value via AfterValidator.
GitRef = Annotated[str, AfterValidator(_validate_git_ref)]
class StrictBaseModel(BaseModel): class StrictBaseModel(BaseModel):
"""Strict model base that rejects unexpected fields.""" """Strict model base that rejects unexpected fields."""
@@ -29,7 +65,7 @@ class RepositoryArgs(StrictBaseModel):
class FileTreeArgs(RepositoryArgs): class FileTreeArgs(RepositoryArgs):
"""Arguments for get_file_tree.""" """Arguments for get_file_tree."""
ref: str = Field(default="main", min_length=1, max_length=200) ref: GitRef = Field(default="main", min_length=1, max_length=200)
recursive: bool = Field(default=False) recursive: bool = Field(default=False)
@@ -37,7 +73,7 @@ class FileContentsArgs(RepositoryArgs):
"""Arguments for get_file_contents.""" """Arguments for get_file_contents."""
filepath: str = Field(..., min_length=1, max_length=1024) filepath: str = Field(..., min_length=1, max_length=1024)
ref: str = Field(default="main", min_length=1, max_length=200) ref: GitRef = Field(default="main", min_length=1, max_length=200)
@model_validator(mode="after") @model_validator(mode="after")
def validate_filepath(self) -> FileContentsArgs: def validate_filepath(self) -> FileContentsArgs:
@@ -55,7 +91,7 @@ class SearchCodeArgs(RepositoryArgs):
"""Arguments for search_code.""" """Arguments for search_code."""
query: str = Field(..., min_length=1, max_length=256) query: str = Field(..., min_length=1, max_length=256)
ref: str = Field(default="main", min_length=1, max_length=200) ref: GitRef = Field(default="main", min_length=1, max_length=200)
page: int = Field(default=1, ge=1, le=10_000) page: int = Field(default=1, ge=1, le=10_000)
limit: int = Field(default=25, ge=1, le=100) limit: int = Field(default=25, ge=1, le=100)
@@ -63,7 +99,7 @@ class SearchCodeArgs(RepositoryArgs):
class ListCommitsArgs(RepositoryArgs): class ListCommitsArgs(RepositoryArgs):
"""Arguments for list_commits.""" """Arguments for list_commits."""
ref: str = Field(default="main", min_length=1, max_length=200) ref: GitRef = Field(default="main", min_length=1, max_length=200)
page: int = Field(default=1, ge=1, le=10_000) page: int = Field(default=1, ge=1, le=10_000)
limit: int = Field(default=25, ge=1, le=100) limit: int = Field(default=25, ge=1, le=100)
@@ -71,14 +107,14 @@ class ListCommitsArgs(RepositoryArgs):
class CommitDiffArgs(RepositoryArgs): class CommitDiffArgs(RepositoryArgs):
"""Arguments for get_commit_diff.""" """Arguments for get_commit_diff."""
sha: str = Field(..., min_length=7, max_length=64) sha: GitRef = Field(..., min_length=7, max_length=64)
class CompareRefsArgs(RepositoryArgs): class CompareRefsArgs(RepositoryArgs):
"""Arguments for compare_refs.""" """Arguments for compare_refs."""
base: str = Field(..., min_length=1, max_length=200) base: GitRef = Field(..., min_length=1, max_length=200)
head: str = Field(..., min_length=1, max_length=200) head: GitRef = Field(..., min_length=1, max_length=200)
class ListIssuesArgs(RepositoryArgs): class ListIssuesArgs(RepositoryArgs):
+98 -1
View File
@@ -5,7 +5,8 @@ from __future__ import annotations
from unittest.mock import AsyncMock, patch from unittest.mock import AsyncMock, patch
import pytest import pytest
from httpx import Request, Response from httpx import Client, Request, Response
from pydantic import ValidationError
from aegis_gitea_mcp.config import reset_settings from aegis_gitea_mcp.config import reset_settings
from aegis_gitea_mcp.gitea_client import ( from aegis_gitea_mcp.gitea_client import (
@@ -15,6 +16,11 @@ from aegis_gitea_mcp.gitea_client import (
GiteaError, GiteaError,
GiteaNotFoundError, GiteaNotFoundError,
) )
from aegis_gitea_mcp.tools.arguments import (
CommitDiffArgs,
CompareRefsArgs,
FileTreeArgs,
)
@pytest.fixture(autouse=True) @pytest.fixture(autouse=True)
@@ -166,3 +172,94 @@ async def test_get_file_contents_blocks_oversized_payload(monkeypatch: pytest.Mo
with pytest.raises(GiteaError, match="exceeds limit"): with pytest.raises(GiteaError, match="exceeds limit"):
await client.get_file_contents("acme", "demo", "big.bin") await client.get_file_contents("acme", "demo", "big.bin")
_MALICIOUS_REFS = [
"../../../x/y",
"..",
"/etc/passwd",
"a\x00b",
"a?b",
"a#b",
]
@pytest.mark.parametrize("value", _MALICIOUS_REFS)
def test_file_tree_args_reject_traversal_ref(value: str) -> None:
    """Layer 1: an unsafe ``ref`` makes FileTreeArgs construction fail."""
    unsafe_args = {"owner": "o", "repo": "r", "ref": value}
    with pytest.raises(ValidationError):
        FileTreeArgs(**unsafe_args)
@pytest.mark.parametrize("value", _MALICIOUS_REFS)
def test_commit_diff_args_reject_traversal_sha(value: str) -> None:
    """Layer 1: an unsafe ``sha`` makes CommitDiffArgs construction fail."""
    unsafe_args = {"owner": "o", "repo": "r", "sha": value}
    # Values shorter than the 7-char min_length (e.g. "..") are rejected too,
    # whether by the length bound or by the ref check.
    with pytest.raises(ValidationError):
        CommitDiffArgs(**unsafe_args)
@pytest.mark.parametrize("value", _MALICIOUS_REFS)
def test_compare_refs_args_reject_traversal_base(value: str) -> None:
    """Layer 1: an unsafe ``base`` makes CompareRefsArgs construction fail."""
    unsafe_args = {"owner": "o", "repo": "r", "base": value, "head": "main"}
    with pytest.raises(ValidationError):
        CompareRefsArgs(**unsafe_args)
@pytest.mark.parametrize("value", _MALICIOUS_REFS)
def test_compare_refs_args_reject_traversal_head(value: str) -> None:
    """Layer 1: an unsafe ``head`` makes CompareRefsArgs construction fail."""
    unsafe_args = {"owner": "o", "repo": "r", "base": "main", "head": value}
    with pytest.raises(ValidationError):
        CompareRefsArgs(**unsafe_args)
def test_git_refs_allow_slash_containing_refs() -> None:
    """Slash-separated refs pass validation and are stored unchanged."""
    tree_args = FileTreeArgs(owner="o", repo="r", ref="feature/foo")
    compare_args = CompareRefsArgs(owner="o", repo="r", base="release/1.0", head="main")

    assert tree_args.ref == "feature/foo"
    assert (compare_args.base, compare_args.head) == ("release/1.0", "main")
def test_git_ref_length_bounds_still_enforced() -> None:
    """Length limits declared on the fields keep working next to the ref validator."""
    out_of_bounds = [
        (FileTreeArgs, {"ref": ""}),  # empty ref
        (FileTreeArgs, {"ref": "a" * 201}),  # 201-character ref
        (CommitDiffArgs, {"sha": "abc"}),  # sha below 7-char minimum
    ]
    for model, field_args in out_of_bounds:
        with pytest.raises(ValidationError):
            model(owner="o", repo="r", **field_args)
def test_layer2_encoding_confines_unsafe_chars_to_path_segment() -> None:
    """Layer 2 defense-in-depth: percent-encoding keeps a hostile sha in one segment.

    Each sha below contains ``?``, ``#``, a space, or backslashes. The endpoint
    is built with ``quote(..., safe='/')`` for the sha, and the resulting request
    URL is checked to have no query string, no fragment, and no additional path
    component after ``/git/commits/`` under the declared ``owner/repo``.

    ``..`` sequences are not covered here: ``quote`` leaves ``.`` and ``/``
    literal so refs like ``feature/foo`` survive, which makes traversal a
    concern for Layer 1 validation.
    """
    from urllib.parse import quote

    owner, repo = "alice", "repoA"
    prefix = f"/api/v1/repos/{owner}/{repo}/git/commits/"
    encoded_repo_root = f"/api/v1/repos/{quote(owner, safe='')}/{quote(repo, safe='')}"
    hostile_shas = ("abc?injected=1", "abc#frag", "abc def", "abc\\..\\evil")

    with Client(base_url="https://gitea.example.com") as http_client:
        for malicious_sha in hostile_shas:
            endpoint = f"{encoded_repo_root}/git/commits/{quote(malicious_sha, safe='/')}"
            url = http_client.build_request("GET", endpoint).url
            # Stays scoped to declared repo, no query/fragment broke out.
            assert url.path.startswith(prefix)
            assert url.query == b""
            assert url.fragment == ""
            # Exactly one ref segment after the prefix (no path splitting).
            assert "/" not in url.path[len(prefix) :]