Merge pull request 'fix: prevent path traversal via Gitea ref/sha/base/head parameters' (#18) from fix/gitea-ref-path-traversal into main
Reviewed-on: #18
This commit was merged in pull request #18.
This commit is contained in:
@@ -3,6 +3,7 @@
|
|||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
from typing import Any
|
from typing import Any
|
||||||
|
from urllib.parse import quote
|
||||||
|
|
||||||
from httpx import AsyncClient, Response
|
from httpx import AsyncClient, Response
|
||||||
|
|
||||||
@@ -175,6 +176,8 @@ class GiteaClient:
|
|||||||
async def get_repository(self, owner: str, repo: str) -> dict[str, Any]:
|
async def get_repository(self, owner: str, repo: str) -> dict[str, Any]:
|
||||||
"""Get repository metadata."""
|
"""Get repository metadata."""
|
||||||
repo_id = f"{owner}/{repo}"
|
repo_id = f"{owner}/{repo}"
|
||||||
|
enc_owner = quote(owner, safe="")
|
||||||
|
enc_repo = quote(repo, safe="")
|
||||||
correlation_id = self.audit.log_tool_invocation(
|
correlation_id = self.audit.log_tool_invocation(
|
||||||
tool_name="get_repository",
|
tool_name="get_repository",
|
||||||
repository=repo_id,
|
repository=repo_id,
|
||||||
@@ -183,7 +186,7 @@ class GiteaClient:
|
|||||||
try:
|
try:
|
||||||
result = await self._request(
|
result = await self._request(
|
||||||
"GET",
|
"GET",
|
||||||
f"/api/v1/repos/{owner}/{repo}",
|
f"/api/v1/repos/{enc_owner}/{enc_repo}",
|
||||||
correlation_id=correlation_id,
|
correlation_id=correlation_id,
|
||||||
)
|
)
|
||||||
self.audit.log_tool_invocation(
|
self.audit.log_tool_invocation(
|
||||||
@@ -212,6 +215,9 @@ class GiteaClient:
|
|||||||
) -> dict[str, Any]:
|
) -> dict[str, Any]:
|
||||||
"""Get file contents from a repository."""
|
"""Get file contents from a repository."""
|
||||||
repo_id = f"{owner}/{repo}"
|
repo_id = f"{owner}/{repo}"
|
||||||
|
enc_owner = quote(owner, safe="")
|
||||||
|
enc_repo = quote(repo, safe="")
|
||||||
|
enc_filepath = quote(filepath, safe="/")
|
||||||
correlation_id = self.audit.log_tool_invocation(
|
correlation_id = self.audit.log_tool_invocation(
|
||||||
tool_name="get_file_contents",
|
tool_name="get_file_contents",
|
||||||
repository=repo_id,
|
repository=repo_id,
|
||||||
@@ -222,7 +228,7 @@ class GiteaClient:
|
|||||||
try:
|
try:
|
||||||
result = await self._request(
|
result = await self._request(
|
||||||
"GET",
|
"GET",
|
||||||
f"/api/v1/repos/{owner}/{repo}/contents/{filepath}",
|
f"/api/v1/repos/{enc_owner}/{enc_repo}/contents/{enc_filepath}",
|
||||||
params={"ref": ref},
|
params={"ref": ref},
|
||||||
correlation_id=correlation_id,
|
correlation_id=correlation_id,
|
||||||
)
|
)
|
||||||
@@ -278,6 +284,9 @@ class GiteaClient:
|
|||||||
) -> dict[str, Any]:
|
) -> dict[str, Any]:
|
||||||
"""Get repository tree at given ref."""
|
"""Get repository tree at given ref."""
|
||||||
repo_id = f"{owner}/{repo}"
|
repo_id = f"{owner}/{repo}"
|
||||||
|
enc_owner = quote(owner, safe="")
|
||||||
|
enc_repo = quote(repo, safe="")
|
||||||
|
enc_ref = quote(ref, safe="/")
|
||||||
correlation_id = self.audit.log_tool_invocation(
|
correlation_id = self.audit.log_tool_invocation(
|
||||||
tool_name="get_tree",
|
tool_name="get_tree",
|
||||||
repository=repo_id,
|
repository=repo_id,
|
||||||
@@ -287,7 +296,7 @@ class GiteaClient:
|
|||||||
try:
|
try:
|
||||||
result = await self._request(
|
result = await self._request(
|
||||||
"GET",
|
"GET",
|
||||||
f"/api/v1/repos/{owner}/{repo}/git/trees/{ref}",
|
f"/api/v1/repos/{enc_owner}/{enc_repo}/git/trees/{enc_ref}",
|
||||||
params={"recursive": str(recursive).lower()},
|
params={"recursive": str(recursive).lower()},
|
||||||
correlation_id=correlation_id,
|
correlation_id=correlation_id,
|
||||||
)
|
)
|
||||||
@@ -334,7 +343,7 @@ class GiteaClient:
|
|||||||
try:
|
try:
|
||||||
result = await self._request(
|
result = await self._request(
|
||||||
"GET",
|
"GET",
|
||||||
f"/api/v1/repos/{owner}/{repo}/search",
|
f"/api/v1/repos/{quote(owner, safe='')}/{quote(repo, safe='')}/search",
|
||||||
params={"q": query, "page": page, "limit": limit, "ref": ref},
|
params={"q": query, "page": page, "limit": limit, "ref": ref},
|
||||||
correlation_id=correlation_id,
|
correlation_id=correlation_id,
|
||||||
)
|
)
|
||||||
@@ -367,7 +376,7 @@ class GiteaClient:
|
|||||||
"""List commits for a repository ref."""
|
"""List commits for a repository ref."""
|
||||||
result = await self._request(
|
result = await self._request(
|
||||||
"GET",
|
"GET",
|
||||||
f"/api/v1/repos/{owner}/{repo}/commits",
|
f"/api/v1/repos/{quote(owner, safe='')}/{quote(repo, safe='')}/commits",
|
||||||
params={"sha": ref, "page": page, "limit": limit},
|
params={"sha": ref, "page": page, "limit": limit},
|
||||||
correlation_id=str(
|
correlation_id=str(
|
||||||
self.audit.log_tool_invocation(tool_name="list_commits", result_status="pending")
|
self.audit.log_tool_invocation(tool_name="list_commits", result_status="pending")
|
||||||
@@ -377,9 +386,12 @@ class GiteaClient:
|
|||||||
|
|
||||||
async def get_commit_diff(self, owner: str, repo: str, sha: str) -> dict[str, Any]:
|
async def get_commit_diff(self, owner: str, repo: str, sha: str) -> dict[str, Any]:
|
||||||
"""Get detailed commit including changed files and patch metadata."""
|
"""Get detailed commit including changed files and patch metadata."""
|
||||||
|
enc_owner = quote(owner, safe="")
|
||||||
|
enc_repo = quote(repo, safe="")
|
||||||
|
enc_sha = quote(sha, safe="/")
|
||||||
result = await self._request(
|
result = await self._request(
|
||||||
"GET",
|
"GET",
|
||||||
f"/api/v1/repos/{owner}/{repo}/git/commits/{sha}",
|
f"/api/v1/repos/{enc_owner}/{enc_repo}/git/commits/{enc_sha}",
|
||||||
correlation_id=str(
|
correlation_id=str(
|
||||||
self.audit.log_tool_invocation(tool_name="get_commit_diff", result_status="pending")
|
self.audit.log_tool_invocation(tool_name="get_commit_diff", result_status="pending")
|
||||||
),
|
),
|
||||||
@@ -388,9 +400,13 @@ class GiteaClient:
|
|||||||
|
|
||||||
async def compare_refs(self, owner: str, repo: str, base: str, head: str) -> dict[str, Any]:
|
async def compare_refs(self, owner: str, repo: str, base: str, head: str) -> dict[str, Any]:
|
||||||
"""Compare two refs and return commit/file deltas."""
|
"""Compare two refs and return commit/file deltas."""
|
||||||
|
enc_owner = quote(owner, safe="")
|
||||||
|
enc_repo = quote(repo, safe="")
|
||||||
|
enc_base = quote(base, safe="/")
|
||||||
|
enc_head = quote(head, safe="/")
|
||||||
result = await self._request(
|
result = await self._request(
|
||||||
"GET",
|
"GET",
|
||||||
f"/api/v1/repos/{owner}/{repo}/compare/{base}...{head}",
|
f"/api/v1/repos/{enc_owner}/{enc_repo}/compare/{enc_base}...{enc_head}",
|
||||||
correlation_id=str(
|
correlation_id=str(
|
||||||
self.audit.log_tool_invocation(tool_name="compare_refs", result_status="pending")
|
self.audit.log_tool_invocation(tool_name="compare_refs", result_status="pending")
|
||||||
),
|
),
|
||||||
@@ -414,7 +430,7 @@ class GiteaClient:
|
|||||||
|
|
||||||
result = await self._request(
|
result = await self._request(
|
||||||
"GET",
|
"GET",
|
||||||
f"/api/v1/repos/{owner}/{repo}/issues",
|
f"/api/v1/repos/{quote(owner, safe='')}/{quote(repo, safe='')}/issues",
|
||||||
params=params,
|
params=params,
|
||||||
correlation_id=str(
|
correlation_id=str(
|
||||||
self.audit.log_tool_invocation(tool_name="list_issues", result_status="pending")
|
self.audit.log_tool_invocation(tool_name="list_issues", result_status="pending")
|
||||||
@@ -426,7 +442,7 @@ class GiteaClient:
|
|||||||
"""Get issue details."""
|
"""Get issue details."""
|
||||||
result = await self._request(
|
result = await self._request(
|
||||||
"GET",
|
"GET",
|
||||||
f"/api/v1/repos/{owner}/{repo}/issues/{index}",
|
f"/api/v1/repos/{quote(owner, safe='')}/{quote(repo, safe='')}/issues/{index}",
|
||||||
correlation_id=str(
|
correlation_id=str(
|
||||||
self.audit.log_tool_invocation(tool_name="get_issue", result_status="pending")
|
self.audit.log_tool_invocation(tool_name="get_issue", result_status="pending")
|
||||||
),
|
),
|
||||||
@@ -445,7 +461,7 @@ class GiteaClient:
|
|||||||
"""List pull requests for repository."""
|
"""List pull requests for repository."""
|
||||||
result = await self._request(
|
result = await self._request(
|
||||||
"GET",
|
"GET",
|
||||||
f"/api/v1/repos/{owner}/{repo}/pulls",
|
f"/api/v1/repos/{quote(owner, safe='')}/{quote(repo, safe='')}/pulls",
|
||||||
params={"state": state, "page": page, "limit": limit},
|
params={"state": state, "page": page, "limit": limit},
|
||||||
correlation_id=str(
|
correlation_id=str(
|
||||||
self.audit.log_tool_invocation(
|
self.audit.log_tool_invocation(
|
||||||
@@ -459,7 +475,7 @@ class GiteaClient:
|
|||||||
"""Get a single pull request."""
|
"""Get a single pull request."""
|
||||||
result = await self._request(
|
result = await self._request(
|
||||||
"GET",
|
"GET",
|
||||||
f"/api/v1/repos/{owner}/{repo}/pulls/{index}",
|
f"/api/v1/repos/{quote(owner, safe='')}/{quote(repo, safe='')}/pulls/{index}",
|
||||||
correlation_id=str(
|
correlation_id=str(
|
||||||
self.audit.log_tool_invocation(
|
self.audit.log_tool_invocation(
|
||||||
tool_name="get_pull_request", result_status="pending"
|
tool_name="get_pull_request", result_status="pending"
|
||||||
@@ -474,7 +490,7 @@ class GiteaClient:
|
|||||||
"""List repository labels."""
|
"""List repository labels."""
|
||||||
result = await self._request(
|
result = await self._request(
|
||||||
"GET",
|
"GET",
|
||||||
f"/api/v1/repos/{owner}/{repo}/labels",
|
f"/api/v1/repos/{quote(owner, safe='')}/{quote(repo, safe='')}/labels",
|
||||||
params={"page": page, "limit": limit},
|
params={"page": page, "limit": limit},
|
||||||
correlation_id=str(
|
correlation_id=str(
|
||||||
self.audit.log_tool_invocation(tool_name="list_labels", result_status="pending")
|
self.audit.log_tool_invocation(tool_name="list_labels", result_status="pending")
|
||||||
@@ -488,7 +504,7 @@ class GiteaClient:
|
|||||||
"""List repository tags."""
|
"""List repository tags."""
|
||||||
result = await self._request(
|
result = await self._request(
|
||||||
"GET",
|
"GET",
|
||||||
f"/api/v1/repos/{owner}/{repo}/tags",
|
f"/api/v1/repos/{quote(owner, safe='')}/{quote(repo, safe='')}/tags",
|
||||||
params={"page": page, "limit": limit},
|
params={"page": page, "limit": limit},
|
||||||
correlation_id=str(
|
correlation_id=str(
|
||||||
self.audit.log_tool_invocation(tool_name="list_tags", result_status="pending")
|
self.audit.log_tool_invocation(tool_name="list_tags", result_status="pending")
|
||||||
@@ -507,7 +523,7 @@ class GiteaClient:
|
|||||||
"""List repository releases."""
|
"""List repository releases."""
|
||||||
result = await self._request(
|
result = await self._request(
|
||||||
"GET",
|
"GET",
|
||||||
f"/api/v1/repos/{owner}/{repo}/releases",
|
f"/api/v1/repos/{quote(owner, safe='')}/{quote(repo, safe='')}/releases",
|
||||||
params={"page": page, "limit": limit},
|
params={"page": page, "limit": limit},
|
||||||
correlation_id=str(
|
correlation_id=str(
|
||||||
self.audit.log_tool_invocation(tool_name="list_releases", result_status="pending")
|
self.audit.log_tool_invocation(tool_name="list_releases", result_status="pending")
|
||||||
@@ -533,7 +549,7 @@ class GiteaClient:
|
|||||||
payload["assignees"] = assignees
|
payload["assignees"] = assignees
|
||||||
result = await self._request(
|
result = await self._request(
|
||||||
"POST",
|
"POST",
|
||||||
f"/api/v1/repos/{owner}/{repo}/issues",
|
f"/api/v1/repos/{quote(owner, safe='')}/{quote(repo, safe='')}/issues",
|
||||||
json_body=payload,
|
json_body=payload,
|
||||||
correlation_id=str(
|
correlation_id=str(
|
||||||
self.audit.log_tool_invocation(tool_name="create_issue", result_status="pending")
|
self.audit.log_tool_invocation(tool_name="create_issue", result_status="pending")
|
||||||
@@ -561,7 +577,7 @@ class GiteaClient:
|
|||||||
payload["state"] = state
|
payload["state"] = state
|
||||||
result = await self._request(
|
result = await self._request(
|
||||||
"PATCH",
|
"PATCH",
|
||||||
f"/api/v1/repos/{owner}/{repo}/issues/{index}",
|
f"/api/v1/repos/{quote(owner, safe='')}/{quote(repo, safe='')}/issues/{index}",
|
||||||
json_body=payload,
|
json_body=payload,
|
||||||
correlation_id=str(
|
correlation_id=str(
|
||||||
self.audit.log_tool_invocation(tool_name="update_issue", result_status="pending")
|
self.audit.log_tool_invocation(tool_name="update_issue", result_status="pending")
|
||||||
@@ -575,7 +591,7 @@ class GiteaClient:
|
|||||||
"""Create a comment on issue (and PR discussion if issue index refers to PR)."""
|
"""Create a comment on issue (and PR discussion if issue index refers to PR)."""
|
||||||
result = await self._request(
|
result = await self._request(
|
||||||
"POST",
|
"POST",
|
||||||
f"/api/v1/repos/{owner}/{repo}/issues/{index}/comments",
|
f"/api/v1/repos/{quote(owner, safe='')}/{quote(repo, safe='')}/issues/{index}/comments",
|
||||||
json_body={"body": body},
|
json_body={"body": body},
|
||||||
correlation_id=str(
|
correlation_id=str(
|
||||||
self.audit.log_tool_invocation(
|
self.audit.log_tool_invocation(
|
||||||
@@ -591,7 +607,7 @@ class GiteaClient:
|
|||||||
"""Create PR discussion comment."""
|
"""Create PR discussion comment."""
|
||||||
result = await self._request(
|
result = await self._request(
|
||||||
"POST",
|
"POST",
|
||||||
f"/api/v1/repos/{owner}/{repo}/issues/{index}/comments",
|
f"/api/v1/repos/{quote(owner, safe='')}/{quote(repo, safe='')}/issues/{index}/comments",
|
||||||
json_body={"body": body},
|
json_body={"body": body},
|
||||||
correlation_id=str(
|
correlation_id=str(
|
||||||
self.audit.log_tool_invocation(
|
self.audit.log_tool_invocation(
|
||||||
@@ -611,7 +627,7 @@ class GiteaClient:
|
|||||||
"""Add labels to issue/PR."""
|
"""Add labels to issue/PR."""
|
||||||
result = await self._request(
|
result = await self._request(
|
||||||
"POST",
|
"POST",
|
||||||
f"/api/v1/repos/{owner}/{repo}/issues/{index}/labels",
|
f"/api/v1/repos/{quote(owner, safe='')}/{quote(repo, safe='')}/issues/{index}/labels",
|
||||||
json_body={"labels": labels},
|
json_body={"labels": labels},
|
||||||
correlation_id=str(
|
correlation_id=str(
|
||||||
self.audit.log_tool_invocation(tool_name="add_labels", result_status="pending")
|
self.audit.log_tool_invocation(tool_name="add_labels", result_status="pending")
|
||||||
@@ -629,7 +645,7 @@ class GiteaClient:
|
|||||||
"""Assign users to issue/PR."""
|
"""Assign users to issue/PR."""
|
||||||
result = await self._request(
|
result = await self._request(
|
||||||
"POST",
|
"POST",
|
||||||
f"/api/v1/repos/{owner}/{repo}/issues/{index}/assignees",
|
f"/api/v1/repos/{quote(owner, safe='')}/{quote(repo, safe='')}/issues/{index}/assignees",
|
||||||
json_body={"assignees": assignees},
|
json_body={"assignees": assignees},
|
||||||
correlation_id=str(
|
correlation_id=str(
|
||||||
self.audit.log_tool_invocation(tool_name="assign_issue", result_status="pending")
|
self.audit.log_tool_invocation(tool_name="assign_issue", result_status="pending")
|
||||||
|
|||||||
@@ -2,13 +2,49 @@
|
|||||||
|
|
||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
from typing import Literal
|
from typing import Annotated, Literal
|
||||||
|
|
||||||
from pydantic import BaseModel, ConfigDict, Field, model_validator
|
from pydantic import AfterValidator, BaseModel, ConfigDict, Field, model_validator
|
||||||
|
|
||||||
_REPO_PART_PATTERN = r"^[A-Za-z0-9._-]{1,100}$"
|
_REPO_PART_PATTERN = r"^[A-Za-z0-9._-]{1,100}$"
|
||||||
|
|
||||||
|
|
||||||
|
def _validate_git_ref(value: str) -> str:
|
||||||
|
"""Validate a git ref-like value (ref/sha/base/head) against traversal.
|
||||||
|
|
||||||
|
Refs that legitimately contain ``/`` (e.g. ``feature/foo``, ``release/1.0``)
|
||||||
|
are preserved; only traversal and unsafe URL-path characters are rejected.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
value: Candidate ref, sha, base, or head value.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
The unchanged value when it is safe.
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
ValueError: When the value could escape the intended repository path.
|
||||||
|
"""
|
||||||
|
# Security decision: block path traversal and absolute references.
|
||||||
|
if ".." in value.split("/"):
|
||||||
|
raise ValueError("ref must not contain '..' path segments")
|
||||||
|
if value.startswith("/"):
|
||||||
|
raise ValueError("ref must not start with '/'")
|
||||||
|
if "\\" in value:
|
||||||
|
raise ValueError("ref must not contain backslashes")
|
||||||
|
if "\x00" in value:
|
||||||
|
raise ValueError("ref must not contain null bytes")
|
||||||
|
if any(ord(char) < 0x20 or ord(char) == 0x7F for char in value):
|
||||||
|
raise ValueError("ref must not contain control characters")
|
||||||
|
if any(char.isspace() for char in value):
|
||||||
|
raise ValueError("ref must not contain whitespace")
|
||||||
|
if "?" in value or "#" in value:
|
||||||
|
raise ValueError("ref must not contain '?' or '#'")
|
||||||
|
return value
|
||||||
|
|
||||||
|
|
||||||
|
GitRef = Annotated[str, AfterValidator(_validate_git_ref)]
|
||||||
|
|
||||||
|
|
||||||
class StrictBaseModel(BaseModel):
|
class StrictBaseModel(BaseModel):
|
||||||
"""Strict model base that rejects unexpected fields."""
|
"""Strict model base that rejects unexpected fields."""
|
||||||
|
|
||||||
@@ -29,7 +65,7 @@ class RepositoryArgs(StrictBaseModel):
|
|||||||
class FileTreeArgs(RepositoryArgs):
|
class FileTreeArgs(RepositoryArgs):
|
||||||
"""Arguments for get_file_tree."""
|
"""Arguments for get_file_tree."""
|
||||||
|
|
||||||
ref: str = Field(default="main", min_length=1, max_length=200)
|
ref: GitRef = Field(default="main", min_length=1, max_length=200)
|
||||||
recursive: bool = Field(default=False)
|
recursive: bool = Field(default=False)
|
||||||
|
|
||||||
|
|
||||||
@@ -37,7 +73,7 @@ class FileContentsArgs(RepositoryArgs):
|
|||||||
"""Arguments for get_file_contents."""
|
"""Arguments for get_file_contents."""
|
||||||
|
|
||||||
filepath: str = Field(..., min_length=1, max_length=1024)
|
filepath: str = Field(..., min_length=1, max_length=1024)
|
||||||
ref: str = Field(default="main", min_length=1, max_length=200)
|
ref: GitRef = Field(default="main", min_length=1, max_length=200)
|
||||||
|
|
||||||
@model_validator(mode="after")
|
@model_validator(mode="after")
|
||||||
def validate_filepath(self) -> FileContentsArgs:
|
def validate_filepath(self) -> FileContentsArgs:
|
||||||
@@ -55,7 +91,7 @@ class SearchCodeArgs(RepositoryArgs):
|
|||||||
"""Arguments for search_code."""
|
"""Arguments for search_code."""
|
||||||
|
|
||||||
query: str = Field(..., min_length=1, max_length=256)
|
query: str = Field(..., min_length=1, max_length=256)
|
||||||
ref: str = Field(default="main", min_length=1, max_length=200)
|
ref: GitRef = Field(default="main", min_length=1, max_length=200)
|
||||||
page: int = Field(default=1, ge=1, le=10_000)
|
page: int = Field(default=1, ge=1, le=10_000)
|
||||||
limit: int = Field(default=25, ge=1, le=100)
|
limit: int = Field(default=25, ge=1, le=100)
|
||||||
|
|
||||||
@@ -63,7 +99,7 @@ class SearchCodeArgs(RepositoryArgs):
|
|||||||
class ListCommitsArgs(RepositoryArgs):
|
class ListCommitsArgs(RepositoryArgs):
|
||||||
"""Arguments for list_commits."""
|
"""Arguments for list_commits."""
|
||||||
|
|
||||||
ref: str = Field(default="main", min_length=1, max_length=200)
|
ref: GitRef = Field(default="main", min_length=1, max_length=200)
|
||||||
page: int = Field(default=1, ge=1, le=10_000)
|
page: int = Field(default=1, ge=1, le=10_000)
|
||||||
limit: int = Field(default=25, ge=1, le=100)
|
limit: int = Field(default=25, ge=1, le=100)
|
||||||
|
|
||||||
@@ -71,14 +107,14 @@ class ListCommitsArgs(RepositoryArgs):
|
|||||||
class CommitDiffArgs(RepositoryArgs):
|
class CommitDiffArgs(RepositoryArgs):
|
||||||
"""Arguments for get_commit_diff."""
|
"""Arguments for get_commit_diff."""
|
||||||
|
|
||||||
sha: str = Field(..., min_length=7, max_length=64)
|
sha: GitRef = Field(..., min_length=7, max_length=64)
|
||||||
|
|
||||||
|
|
||||||
class CompareRefsArgs(RepositoryArgs):
|
class CompareRefsArgs(RepositoryArgs):
|
||||||
"""Arguments for compare_refs."""
|
"""Arguments for compare_refs."""
|
||||||
|
|
||||||
base: str = Field(..., min_length=1, max_length=200)
|
base: GitRef = Field(..., min_length=1, max_length=200)
|
||||||
head: str = Field(..., min_length=1, max_length=200)
|
head: GitRef = Field(..., min_length=1, max_length=200)
|
||||||
|
|
||||||
|
|
||||||
class ListIssuesArgs(RepositoryArgs):
|
class ListIssuesArgs(RepositoryArgs):
|
||||||
|
|||||||
@@ -5,7 +5,8 @@ from __future__ import annotations
|
|||||||
from unittest.mock import AsyncMock, patch
|
from unittest.mock import AsyncMock, patch
|
||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
from httpx import Request, Response
|
from httpx import Client, Request, Response
|
||||||
|
from pydantic import ValidationError
|
||||||
|
|
||||||
from aegis_gitea_mcp.config import reset_settings
|
from aegis_gitea_mcp.config import reset_settings
|
||||||
from aegis_gitea_mcp.gitea_client import (
|
from aegis_gitea_mcp.gitea_client import (
|
||||||
@@ -15,6 +16,11 @@ from aegis_gitea_mcp.gitea_client import (
|
|||||||
GiteaError,
|
GiteaError,
|
||||||
GiteaNotFoundError,
|
GiteaNotFoundError,
|
||||||
)
|
)
|
||||||
|
from aegis_gitea_mcp.tools.arguments import (
|
||||||
|
CommitDiffArgs,
|
||||||
|
CompareRefsArgs,
|
||||||
|
FileTreeArgs,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture(autouse=True)
|
@pytest.fixture(autouse=True)
|
||||||
@@ -166,3 +172,94 @@ async def test_get_file_contents_blocks_oversized_payload(monkeypatch: pytest.Mo
|
|||||||
|
|
||||||
with pytest.raises(GiteaError, match="exceeds limit"):
|
with pytest.raises(GiteaError, match="exceeds limit"):
|
||||||
await client.get_file_contents("acme", "demo", "big.bin")
|
await client.get_file_contents("acme", "demo", "big.bin")
|
||||||
|
|
||||||
|
|
||||||
|
_MALICIOUS_REFS = [
|
||||||
|
"../../../x/y",
|
||||||
|
"..",
|
||||||
|
"/etc/passwd",
|
||||||
|
"a\x00b",
|
||||||
|
"a?b",
|
||||||
|
"a#b",
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.parametrize("value", _MALICIOUS_REFS)
|
||||||
|
def test_file_tree_args_reject_traversal_ref(value: str) -> None:
|
||||||
|
"""Layer 1: FileTreeArgs.ref rejects traversal/unsafe values."""
|
||||||
|
with pytest.raises(ValidationError):
|
||||||
|
FileTreeArgs(owner="o", repo="r", ref=value)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.parametrize("value", _MALICIOUS_REFS)
|
||||||
|
def test_commit_diff_args_reject_traversal_sha(value: str) -> None:
|
||||||
|
"""Layer 1: CommitDiffArgs.sha rejects traversal/unsafe values."""
|
||||||
|
# ".." is shorter than the 7-char min_length; still rejected (length or ref check).
|
||||||
|
with pytest.raises(ValidationError):
|
||||||
|
CommitDiffArgs(owner="o", repo="r", sha=value)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.parametrize("value", _MALICIOUS_REFS)
|
||||||
|
def test_compare_refs_args_reject_traversal_base(value: str) -> None:
|
||||||
|
"""Layer 1: CompareRefsArgs.base rejects traversal/unsafe values."""
|
||||||
|
with pytest.raises(ValidationError):
|
||||||
|
CompareRefsArgs(owner="o", repo="r", base=value, head="main")
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.parametrize("value", _MALICIOUS_REFS)
|
||||||
|
def test_compare_refs_args_reject_traversal_head(value: str) -> None:
|
||||||
|
"""Layer 1: CompareRefsArgs.head rejects traversal/unsafe values."""
|
||||||
|
with pytest.raises(ValidationError):
|
||||||
|
CompareRefsArgs(owner="o", repo="r", base="main", head=value)
|
||||||
|
|
||||||
|
|
||||||
|
def test_git_refs_allow_slash_containing_refs() -> None:
|
||||||
|
"""Legitimate refs that contain '/' validate successfully."""
|
||||||
|
tree = FileTreeArgs(owner="o", repo="r", ref="feature/foo")
|
||||||
|
assert tree.ref == "feature/foo"
|
||||||
|
|
||||||
|
compare = CompareRefsArgs(owner="o", repo="r", base="release/1.0", head="main")
|
||||||
|
assert compare.base == "release/1.0"
|
||||||
|
assert compare.head == "main"
|
||||||
|
|
||||||
|
|
||||||
|
def test_git_ref_length_bounds_still_enforced() -> None:
|
||||||
|
"""Field min/max length still applies alongside the AfterValidator."""
|
||||||
|
with pytest.raises(ValidationError):
|
||||||
|
FileTreeArgs(owner="o", repo="r", ref="")
|
||||||
|
with pytest.raises(ValidationError):
|
||||||
|
FileTreeArgs(owner="o", repo="r", ref="a" * 201)
|
||||||
|
with pytest.raises(ValidationError):
|
||||||
|
# sha below 7-char minimum
|
||||||
|
CommitDiffArgs(owner="o", repo="r", sha="abc")
|
||||||
|
|
||||||
|
|
||||||
|
def test_layer2_encoding_confines_unsafe_chars_to_path_segment() -> None:
|
||||||
|
"""Layer 2 defense-in-depth: quote(..., safe='/') confines unsafe chars.
|
||||||
|
|
||||||
|
A sha containing ``?``, ``#``, whitespace, or a backslash that hypothetically
|
||||||
|
bypassed Layer 1 is percent-encoded so it cannot split off a query string,
|
||||||
|
fragment, or extra path component -- it stays inside the single ref segment
|
||||||
|
under the declared ``owner/repo``. (Note: ``..`` collapse is closed by Layer 1
|
||||||
|
validation, since ``quote`` intentionally leaves ``.`` and ``/`` literal to
|
||||||
|
preserve refs like ``feature/foo``.)
|
||||||
|
"""
|
||||||
|
from urllib.parse import quote
|
||||||
|
|
||||||
|
owner, repo = "alice", "repoA"
|
||||||
|
prefix = f"/api/v1/repos/{owner}/{repo}/git/commits/"
|
||||||
|
|
||||||
|
for malicious_sha in ["abc?injected=1", "abc#frag", "abc def", "abc\\..\\evil"]:
|
||||||
|
endpoint = (
|
||||||
|
f"/api/v1/repos/{quote(owner, safe='')}/{quote(repo, safe='')}"
|
||||||
|
f"/git/commits/{quote(malicious_sha, safe='/')}"
|
||||||
|
)
|
||||||
|
with Client(base_url="https://gitea.example.com") as http_client:
|
||||||
|
request = http_client.build_request("GET", endpoint)
|
||||||
|
|
||||||
|
# Stays scoped to declared repo, no query/fragment broke out.
|
||||||
|
assert request.url.path.startswith(prefix)
|
||||||
|
assert request.url.query == b""
|
||||||
|
assert request.url.fragment == ""
|
||||||
|
# Exactly one ref segment after the prefix (no path splitting).
|
||||||
|
assert "/" not in request.url.path[len(prefix) :]
|
||||||
|
|||||||
Reference in New Issue
Block a user