e08ba42697
docker / test (pull_request) Successful in 29s
docker / lint (pull_request) Successful in 35s
lint / lint (pull_request) Successful in 35s
test / test (pull_request) Successful in 35s
docker / docker-test (pull_request) Successful in 8s
docker / docker-publish (pull_request) Has been skipped
test / test (push) Successful in 23s
lint / lint (push) Successful in 23s
Add a `milestone` argument to `create_issue` and `update_issue` accepting either a numeric milestone id or a title (resolved case-insensitively against open and closed milestones, with a clear error for unknown titles). On `update_issue`, `milestone: 0` clears the milestone. A BeforeValidator rejects booleans so they are not silently coerced to an id. Gitea Projects (Kanban boards) were investigated for #22 and are intentionally left unsupported: Gitea 1.26.2 exposes no project endpoints in its REST API. Documented this in api-reference.md and refreshed the (stale) write-mode tool list to cover all 16 write tools. Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com>
414 lines
17 KiB
Python
414 lines
17 KiB
Python
"""Unit tests for Gitea client request behavior."""
|
|
|
|
from __future__ import annotations

from collections.abc import Iterator
from unittest.mock import AsyncMock, patch

import pytest
from httpx import Client, Request, Response
from pydantic import ValidationError

from aegis_gitea_mcp.config import reset_settings
from aegis_gitea_mcp.gitea_client import (
    GiteaAuthenticationError,
    GiteaAuthorizationError,
    GiteaClient,
    GiteaError,
    GiteaNotFoundError,
)
from aegis_gitea_mcp.tools.arguments import (
    CommitDiffArgs,
    CompareRefsArgs,
    FileTreeArgs,
)
|
|
|
|
|
|
@pytest.fixture(autouse=True)
def gitea_env(monkeypatch: pytest.MonkeyPatch) -> Iterator[None]:
    """Provide a minimal environment for client initialization.

    ``reset_settings()`` runs before the variables are set and again after the
    test, so settings built from these overrides are not reused elsewhere
    (assumes ``reset_settings`` drops a cached settings object -- confirm).

    Yields:
        None; the fixture only prepares environment state.
    """
    reset_settings()
    monkeypatch.setenv("GITEA_URL", "https://gitea.example.com")
    monkeypatch.setenv("GITEA_TOKEN", "legacy-token")
    monkeypatch.setenv("MCP_API_KEYS", "a" * 64)
    monkeypatch.setenv("ENVIRONMENT", "test")
    yield
    # Teardown: monkeypatch restores the env vars itself; also reset settings.
    reset_settings()
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_client_context_uses_bearer_header() -> None:
    """Entering the client builds an AsyncClient with a bearer header; exiting closes it."""
    with patch("aegis_gitea_mcp.gitea_client.AsyncClient") as async_client_cls:
        http_instance = AsyncMock()
        async_client_cls.return_value = http_instance

        async with GiteaClient(token="user-oauth-token"):
            pass

    # Only the keyword arguments of the constructor call are inspected.
    sent_headers = async_client_cls.call_args.kwargs["headers"]
    assert sent_headers["Authorization"] == "Bearer user-oauth-token"
    http_instance.aclose.assert_awaited_once()
|
|
|
|
|
|
def test_client_requires_non_empty_token() -> None:
    """A whitespace-only token is rejected when the client is constructed."""
    whitespace_token = " "
    with pytest.raises(ValueError, match="non-empty"):
        GiteaClient(token=whitespace_token)
|
|
|
|
|
|
def test_handle_response_maps_error_codes() -> None:
    """Error statuses become typed domain exceptions; a 200 returns the JSON body."""
    client = GiteaClient(token="user-token")
    request = Request("GET", "https://gitea.example.com/api/v1/user")

    status_to_error = [
        (401, GiteaAuthenticationError),
        (403, GiteaAuthorizationError),
        (404, GiteaNotFoundError),
    ]
    for index, (status, expected_error) in enumerate(status_to_error, start=1):
        with pytest.raises(expected_error):
            client._handle_response(
                Response(status, request=request),
                correlation_id=f"c{index}",
            )

    # A generic server error surfaces the "message" field of the JSON body.
    server_error = Response(500, request=request, json={"message": "boom"})
    with pytest.raises(GiteaError, match="boom"):
        client._handle_response(server_error, correlation_id="c4")

    success = Response(200, request=request, json={"ok": True})
    assert client._handle_response(success, "c5") == {"ok": True}
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_public_methods_delegate_to_request_and_normalize() -> None:
    """Every public wrapper routes through ``_request`` and returns the expected shape.

    ``_request`` is replaced by a fake that answers from a table keyed by
    endpoint, so URL construction and result handling run without network I/O.
    """
    client = GiteaClient(token="user-token")
    owner, repo = "acme", "demo"
    base = "/api/v1/repos/acme/demo"

    # endpoint -> factory(method) that builds a fresh payload on every call.
    # "issues" and "issues/1" answer GET differently from other methods.
    canned = {
        "/api/v1/user": lambda _: {"login": "alice"},
        "/api/v1/user/repos": lambda _: [{"name": "repo"}],
        base: lambda _: {"name": "demo"},
        f"{base}/contents/README.md": lambda _: {
            "size": 8,
            "content": "aGVsbG8=",
            "encoding": "base64",
        },
        f"{base}/git/trees/main": lambda _: {"tree": [{"path": "README.md"}]},
        f"{base}/search": lambda _: {"hits": []},
        f"{base}/commits": lambda _: [{"sha": "abc"}],
        f"{base}/git/commits/abc": lambda _: {"sha": "abc"},
        f"{base}/compare/main...feature": lambda _: {"total_commits": 1},
        f"{base}/issues": lambda method: (
            [{"number": 1}] if method == "GET" else {"number": 12}
        ),
        f"{base}/issues/1": lambda method: (
            {"number": 1} if method == "GET" else {"number": 1, "state": "closed"}
        ),
        f"{base}/pulls": lambda _: [{"number": 2}],
        f"{base}/pulls/2": lambda _: {"number": 2},
        f"{base}/labels": lambda _: [{"id": 1, "name": "bug"}],
        f"{base}/tags": lambda _: [{"name": "v1"}],
        f"{base}/releases": lambda _: [{"id": 1}],
        f"{base}/issues/1/comments": lambda _: {"id": 9},
        f"{base}/issues/1/labels": lambda _: {"labels": [{"name": "bug"}]},
        f"{base}/issues/1/assignees": lambda _: {"assignees": [{"login": "alice"}]},
    }

    async def fake_request(method: str, endpoint: str, **kwargs):
        build = canned.get(endpoint)
        return build(method) if build is not None else {}

    client._request = AsyncMock(side_effect=fake_request)  # type: ignore[method-assign]

    # get_* / list_* / search / compare wrappers.
    assert (await client.get_current_user())["login"] == "alice"
    assert len(await client.list_repositories()) == 1
    assert (await client.get_repository(owner, repo))["name"] == "demo"
    assert (await client.get_file_contents(owner, repo, "README.md"))["size"] == 8
    assert len((await client.get_tree(owner, repo))["tree"]) == 1
    search_result = await client.search_code(owner, repo, "needle", ref="main", page=1, limit=5)
    assert isinstance(search_result, dict)
    assert len(await client.list_commits(owner, repo, ref="main", page=1, limit=5)) == 1
    assert (await client.get_commit_diff(owner, repo, "abc"))["sha"] == "abc"
    assert isinstance(await client.compare_refs(owner, repo, "main", "feature"), dict)
    assert len(await client.list_issues(owner, repo, state="open", page=1, limit=10)) == 1
    assert (await client.get_issue(owner, repo, 1))["number"] == 1
    assert len(await client.list_pull_requests(owner, repo, state="open", page=1, limit=10)) == 1
    assert (await client.get_pull_request(owner, repo, 2))["number"] == 2
    assert len(await client.list_labels(owner, repo, page=1, limit=10)) == 1
    assert len(await client.list_tags(owner, repo, page=1, limit=10)) == 1
    assert len(await client.list_releases(owner, repo, page=1, limit=10)) == 1

    # create_* / update_* / label / assignee wrappers.
    created = await client.create_issue(owner, repo, title="Hi", body="Body")
    assert created["number"] == 12
    updated = await client.update_issue(owner, repo, 1, state="closed")
    assert updated["state"] == "closed"
    assert (await client.create_issue_comment(owner, repo, 1, "comment"))["id"] == 9
    assert (await client.create_pr_comment(owner, repo, 1, "comment"))["id"] == 9
    assert isinstance(await client.add_labels(owner, repo, 1, ["bug"]), dict)
    assert isinstance(await client.assign_issue(owner, repo, 1, ["alice"]), dict)
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_get_file_contents_blocks_oversized_payload(monkeypatch: pytest.MonkeyPatch) -> None:
    """A file whose reported size exceeds MAX_FILE_SIZE_BYTES is refused."""
    monkeypatch.setenv("MAX_FILE_SIZE_BYTES", "5")
    # Reset so the lowered limit is picked up (settings appear to be cached).
    reset_settings()
    client = GiteaClient(token="user-token")

    oversized_payload = {"size": 50, "content": "x", "encoding": "base64"}
    client._request = AsyncMock(return_value=oversized_payload)  # type: ignore[method-assign]

    with pytest.raises(GiteaError, match="exceeds limit"):
        await client.get_file_contents("acme", "demo", "big.bin")
|
|
|
|
|
|
_MALICIOUS_REFS = [
|
|
"../../../x/y",
|
|
"..",
|
|
"/etc/passwd",
|
|
"a\x00b",
|
|
"a?b",
|
|
"a#b",
|
|
]
|
|
|
|
|
|
@pytest.mark.parametrize("value", _MALICIOUS_REFS)
def test_file_tree_args_reject_traversal_ref(value: str) -> None:
    """Layer 1: an unsafe ``ref`` makes FileTreeArgs validation fail."""
    with pytest.raises(ValidationError):
        FileTreeArgs(ref=value, owner="o", repo="r")
|
|
|
|
|
|
@pytest.mark.parametrize("value", _MALICIOUS_REFS)
def test_commit_diff_args_reject_traversal_sha(value: str) -> None:
    """Layer 1: an unsafe ``sha`` makes CommitDiffArgs validation fail."""
    # ".." is also shorter than the 7-char min_length; either check may reject it.
    with pytest.raises(ValidationError):
        CommitDiffArgs(sha=value, owner="o", repo="r")
|
|
|
|
|
|
@pytest.mark.parametrize("value", _MALICIOUS_REFS)
def test_compare_refs_args_reject_traversal_base(value: str) -> None:
    """Layer 1: an unsafe ``base`` ref makes CompareRefsArgs validation fail."""
    with pytest.raises(ValidationError):
        CompareRefsArgs(base=value, head="main", owner="o", repo="r")
|
|
|
|
|
|
@pytest.mark.parametrize("value", _MALICIOUS_REFS)
def test_compare_refs_args_reject_traversal_head(value: str) -> None:
    """Layer 1: an unsafe ``head`` ref makes CompareRefsArgs validation fail."""
    with pytest.raises(ValidationError):
        CompareRefsArgs(head=value, base="main", owner="o", repo="r")
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_list_user_repositories_scopes_by_uid() -> None:
    """The username is resolved to a uid, which then filters the repo search."""
    client = GiteaClient(token="service-pat")
    recorded: dict = {}

    async def fake_request(method: str, endpoint: str, **kwargs):
        if endpoint == "/api/v1/repos/search":
            recorded["params"] = kwargs.get("params")
            # The non-dict item checks that malformed entries are dropped.
            return {"ok": True, "data": [{"full_name": "alice/demo"}, "not-a-dict"]}
        if endpoint == "/api/v1/users/alice":
            return {"id": 7, "login": "alice"}
        return {}

    client._request = AsyncMock(side_effect=fake_request)  # type: ignore[method-assign]

    repositories = await client.list_user_repositories("alice")

    assert recorded["params"]["uid"] == 7
    assert repositories == [{"full_name": "alice/demo"}]
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_list_user_repositories_unknown_user_returns_empty() -> None:
    """If the user lookup yields no id, listing returns [] instead of raising."""
    client = GiteaClient(token="service-pat")

    async def fake_request(method: str, endpoint: str, **kwargs):
        # Every endpoint answers with an empty object, so no "id" is available.
        return {}

    client._request = AsyncMock(side_effect=fake_request)  # type: ignore[method-assign]
    repositories = await client.list_user_repositories("ghost")
    assert repositories == []
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_resolve_label_ids_maps_names_case_insensitively() -> None:
    """Label names map to their ids even when the requested casing differs."""
    client = GiteaClient(token="user-token")

    async def fake_request(method: str, endpoint: str, **kwargs):
        # Stored casings ("Bug", "wontfix") differ from the requested ones below.
        return [{"id": 3, "name": "Bug"}, {"id": 9, "name": "wontfix"}]

    client._request = AsyncMock(side_effect=fake_request)  # type: ignore[method-assign]

    requested = ["bug", "WONTFIX"]
    assert await client._resolve_label_ids("o", "r", requested, correlation_id="c") == [3, 9]
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_resolve_label_ids_rejects_unknown_label() -> None:
    """Requesting a label that does not exist raises instead of being ignored."""
    client = GiteaClient(token="user-token")

    async def fake_request(method: str, endpoint: str, **kwargs):
        # Only "bug" exists in the repository.
        return [{"id": 3, "name": "bug"}]

    client._request = AsyncMock(side_effect=fake_request)  # type: ignore[method-assign]

    missing = ["ghost"]
    with pytest.raises(GiteaError, match="Unknown label"):
        await client._resolve_label_ids("o", "r", missing, correlation_id="c")
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_resolve_milestone_id_passes_through_integer() -> None:
    """A numeric milestone reference is returned unchanged, with no API call."""
    client = GiteaClient(token="user-token")
    request_mock = AsyncMock()
    client._request = request_mock  # type: ignore[method-assign]

    resolved = await client._resolve_milestone_id("o", "r", 7, correlation_id="c")

    assert resolved == 7
    # A numeric id must not trigger a milestone lookup.
    request_mock.assert_not_called()
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_resolve_milestone_id_maps_title_case_insensitively() -> None:
    """A milestone title resolves to its id even when the casing differs."""
    client = GiteaClient(token="user-token")

    async def fake_request(method: str, endpoint: str, **kwargs):
        # The stored title "Sprint 1" differs in case from the lookup below.
        return [{"id": 11, "title": "Sprint 1"}, {"id": 12, "title": "Backlog"}]

    client._request = AsyncMock(side_effect=fake_request)  # type: ignore[method-assign]

    lowercase_title = "sprint 1"
    assert await client._resolve_milestone_id("o", "r", lowercase_title, correlation_id="c") == 11
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_resolve_milestone_id_rejects_unknown_title() -> None:
    """A title that matches no milestone raises a descriptive GiteaError."""
    client = GiteaClient(token="user-token")

    async def fake_request(method: str, endpoint: str, **kwargs):
        # Only "Sprint 1" exists, so "Sprint 2" cannot be resolved.
        return [{"id": 11, "title": "Sprint 1"}]

    client._request = AsyncMock(side_effect=fake_request)  # type: ignore[method-assign]

    missing_title = "Sprint 2"
    with pytest.raises(GiteaError, match="Unknown milestone"):
        await client._resolve_milestone_id("o", "r", missing_title, correlation_id="c")
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_create_issue_resolves_milestone_title() -> None:
    """create_issue sends the resolved milestone id, not the title, in its POST body."""
    client = GiteaClient(token="user-token")
    posted: dict = {}

    async def fake_request(method: str, endpoint: str, **kwargs):
        if method == "GET" and endpoint.endswith("/milestones"):
            return [{"id": 11, "title": "Sprint 1"}]
        if method == "POST" and endpoint.endswith("/issues"):
            posted["payload"] = kwargs.get("json_body")
            return {"number": 1, "title": "Issue", "state": "open"}
        return {}

    client._request = AsyncMock(side_effect=fake_request)  # type: ignore[method-assign]

    await client.create_issue("o", "r", title="Issue", body="", milestone="Sprint 1")

    body = posted["payload"]
    assert body["milestone"] == 11
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_update_issue_clears_milestone_with_zero() -> None:
    """Milestone id 0 is sent unchanged in the update payload to clear the milestone."""
    client = GiteaClient(token="user-token")
    sent: dict = {}

    async def fake_request(method: str, endpoint: str, **kwargs):
        # Keep the JSON body of the most recent call.
        sent["payload"] = kwargs.get("json_body")
        return {"number": 1, "title": "Issue", "state": "open"}

    client._request = AsyncMock(side_effect=fake_request)  # type: ignore[method-assign]
    await client.update_issue("o", "r", 1, milestone=0)

    body = sent["payload"]
    # 0 is falsy; this catches the value being dropped by a truthiness check.
    assert body["milestone"] == 0
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_add_labels_resolves_names_to_ids() -> None:
    """add_labels looks up label ids by name and POSTs the ids to Gitea."""
    client = GiteaClient(token="user-token")
    posted: dict = {}

    async def fake_request(method: str, endpoint: str, **kwargs):
        if method == "GET" and endpoint.endswith("/labels"):
            return [{"id": 42, "name": "bug"}]
        if method == "POST" and endpoint.endswith("/issues/1/labels"):
            posted["body"] = kwargs.get("json_body")
            return {"labels": [{"name": "bug"}]}
        return {}

    client._request = AsyncMock(side_effect=fake_request)  # type: ignore[method-assign]

    await client.add_labels("o", "r", 1, ["bug"])

    expected_body = {"labels": [42]}
    assert posted["body"] == expected_body
|
|
|
|
|
|
def test_git_refs_allow_slash_containing_refs() -> None:
    """Refs such as ``feature/foo`` and ``release/1.0`` validate and are kept verbatim."""
    assert FileTreeArgs(owner="o", repo="r", ref="feature/foo").ref == "feature/foo"

    compare = CompareRefsArgs(owner="o", repo="r", base="release/1.0", head="main")
    assert (compare.base, compare.head) == ("release/1.0", "main")
|
|
|
|
|
|
def test_git_ref_length_bounds_still_enforced() -> None:
    """Field min/max length limits still apply alongside the AfterValidator."""
    for out_of_bounds_ref in ("", "a" * 201):
        with pytest.raises(ValidationError):
            FileTreeArgs(owner="o", repo="r", ref=out_of_bounds_ref)

    # "abc" is below the 7-character sha minimum.
    with pytest.raises(ValidationError):
        CommitDiffArgs(owner="o", repo="r", sha="abc")
|
|
|
|
|
|
def test_layer2_encoding_confines_unsafe_chars_to_path_segment() -> None:
    """Layer 2 defense-in-depth: quote(..., safe='/') confines unsafe chars.

    A sha containing ``?``, ``#``, whitespace, or a backslash that hypothetically
    bypassed Layer 1 is percent-encoded so it cannot split off a query string,
    fragment, or extra path component -- it stays inside the single ref segment
    under the declared ``owner/repo``. (Note: ``..`` collapse is closed by Layer 1
    validation, since ``quote`` intentionally leaves ``.`` and ``/`` literal to
    preserve refs like ``feature/foo``.)
    """
    from urllib.parse import quote

    owner, repo = "alice", "repoA"
    prefix = f"/api/v1/repos/{owner}/{repo}/git/commits/"
    # Loop-invariant: the encoded owner/repo part of every endpoint.
    repo_base = f"/api/v1/repos/{quote(owner, safe='')}/{quote(repo, safe='')}"

    # build_request only constructs a Request (no network I/O), so a single
    # client can serve every case instead of opening one per iteration.
    with Client(base_url="https://gitea.example.com") as http_client:
        for malicious_sha in ["abc?injected=1", "abc#frag", "abc def", "abc\\..\\evil"]:
            endpoint = f"{repo_base}/git/commits/{quote(malicious_sha, safe='/')}"
            request = http_client.build_request("GET", endpoint)

            # Stays scoped to declared repo, no query/fragment broke out.
            assert request.url.path.startswith(prefix), malicious_sha
            assert request.url.query == b"", malicious_sha
            assert request.url.fragment == "", malicious_sha
            # Exactly one ref segment after the prefix (no path splitting).
            assert "/" not in request.url.path[len(prefix) :], malicious_sha
|