"""Unit tests for Gitea client request behavior."""

from __future__ import annotations

from collections.abc import Iterator
from typing import Any
from unittest.mock import AsyncMock, patch
from urllib.parse import quote

import pytest
from httpx import Client, Request, Response
from pydantic import ValidationError

from aegis_gitea_mcp.config import reset_settings
from aegis_gitea_mcp.gitea_client import (
    GiteaAuthenticationError,
    GiteaAuthorizationError,
    GiteaClient,
    GiteaError,
    GiteaNotFoundError,
)
from aegis_gitea_mcp.tools.arguments import (
    CommitDiffArgs,
    CompareRefsArgs,
    FileTreeArgs,
)


@pytest.fixture(autouse=True)
def gitea_env(monkeypatch: pytest.MonkeyPatch) -> Iterator[None]:
    """Provide minimal environment for client initialization.

    ``reset_settings()`` is called before the environment variables are
    patched and again after the test body finishes (after the ``yield``).
    NOTE(review): presumably this drops a cached settings object so each test
    re-reads the patched environment -- confirm against ``config``.
    """
    reset_settings()
    monkeypatch.setenv("GITEA_URL", "https://gitea.example.com")
    monkeypatch.setenv("GITEA_TOKEN", "legacy-token")
    monkeypatch.setenv("MCP_API_KEYS", "a" * 64)
    monkeypatch.setenv("ENVIRONMENT", "test")
    yield
    reset_settings()


@pytest.mark.asyncio
async def test_client_context_uses_bearer_header() -> None:
    """HTTP client is created with bearer token and closed on exit."""
    with patch("aegis_gitea_mcp.gitea_client.AsyncClient") as mock_async_client:
        mock_instance = AsyncMock()
        mock_async_client.return_value = mock_instance

        async with GiteaClient(token="user-oauth-token"):
            pass

    # The mocks keep their recorded calls after the patch is undone.
    _, kwargs = mock_async_client.call_args
    assert kwargs["headers"]["Authorization"] == "Bearer user-oauth-token"
    mock_instance.aclose.assert_awaited_once()


def test_client_requires_non_empty_token() -> None:
    """Client construction fails when token is missing."""
    with pytest.raises(ValueError, match="non-empty"):
        GiteaClient(token=" ")


def test_handle_response_maps_error_codes() -> None:
    """HTTP status codes map to explicit domain exceptions."""
    client = GiteaClient(token="user-token")
    request = Request("GET", "https://gitea.example.com/api/v1/user")

    with pytest.raises(GiteaAuthenticationError):
        client._handle_response(Response(401, request=request), correlation_id="c1")

    with pytest.raises(GiteaAuthorizationError):
        client._handle_response(Response(403, request=request), correlation_id="c2")

    with pytest.raises(GiteaNotFoundError):
        client._handle_response(Response(404, request=request), correlation_id="c3")

    with pytest.raises(GiteaError, match="boom"):
        client._handle_response(
            Response(500, request=request, json={"message": "boom"}),
            correlation_id="c4",
        )

    # A 2xx response is returned as its decoded JSON body.
    assert client._handle_response(Response(200, request=request, json={"ok": True}), "c5") == {
        "ok": True
    }


@pytest.mark.asyncio
async def test_public_methods_delegate_to_request_and_normalize() -> None:
    """Wrapper methods call shared request logic and normalize return types."""
    client = GiteaClient(token="user-token")

    async def fake_request(method: str, endpoint: str, **kwargs: Any) -> Any:
        """Return a canned payload for each endpoint the wrappers hit."""
        if endpoint == "/api/v1/user":
            return {"login": "alice"}
        if endpoint == "/api/v1/user/repos":
            return [{"name": "repo"}]
        if endpoint == "/api/v1/repos/acme/demo":
            return {"name": "demo"}
        if endpoint == "/api/v1/repos/acme/demo/contents/README.md":
            return {"size": 8, "content": "aGVsbG8=", "encoding": "base64"}
        if endpoint == "/api/v1/repos/acme/demo/git/trees/main":
            return {"tree": [{"path": "README.md"}]}
        if endpoint == "/api/v1/repos/acme/demo/search":
            return {"hits": []}
        if endpoint == "/api/v1/repos/acme/demo/commits":
            return [{"sha": "abc"}]
        if endpoint == "/api/v1/repos/acme/demo/git/commits/abc":
            return {"sha": "abc"}
        if endpoint == "/api/v1/repos/acme/demo/compare/main...feature":
            return {"total_commits": 1}
        if endpoint == "/api/v1/repos/acme/demo/issues":
            # Same endpoint serves listing (GET) and creation (any other verb).
            if method == "GET":
                return [{"number": 1}]
            return {"number": 12}
        if endpoint == "/api/v1/repos/acme/demo/issues/1":
            # GET fetches the issue; any other verb is treated as an update.
            if method == "GET":
                return {"number": 1}
            return {"number": 1, "state": "closed"}
        if endpoint == "/api/v1/repos/acme/demo/pulls":
            return [{"number": 2}]
        if endpoint == "/api/v1/repos/acme/demo/pulls/2":
            return {"number": 2}
        if endpoint == "/api/v1/repos/acme/demo/labels":
            return [{"name": "bug"}]
        if endpoint == "/api/v1/repos/acme/demo/tags":
            return [{"name": "v1"}]
        if endpoint == "/api/v1/repos/acme/demo/releases":
            return [{"id": 1}]
        if endpoint == "/api/v1/repos/acme/demo/issues/1/comments":
            return {"id": 9}
        if endpoint == "/api/v1/repos/acme/demo/issues/1/labels":
            return {"labels": [{"name": "bug"}]}
        if endpoint == "/api/v1/repos/acme/demo/issues/1/assignees":
            return {"assignees": [{"login": "alice"}]}
        return {}

    client._request = AsyncMock(side_effect=fake_request)  # type: ignore[method-assign]

    assert (await client.get_current_user())["login"] == "alice"
    assert len(await client.list_repositories()) == 1
    assert (await client.get_repository("acme", "demo"))["name"] == "demo"
    assert (await client.get_file_contents("acme", "demo", "README.md"))["size"] == 8
    assert len((await client.get_tree("acme", "demo"))["tree"]) == 1
    assert isinstance(
        await client.search_code("acme", "demo", "needle", ref="main", page=1, limit=5), dict
    )
    assert len(await client.list_commits("acme", "demo", ref="main", page=1, limit=5)) == 1
    assert (await client.get_commit_diff("acme", "demo", "abc"))["sha"] == "abc"
    assert isinstance(await client.compare_refs("acme", "demo", "main", "feature"), dict)
    assert len(await client.list_issues("acme", "demo", state="open", page=1, limit=10)) == 1
    assert (await client.get_issue("acme", "demo", 1))["number"] == 1
    assert len(await client.list_pull_requests("acme", "demo", state="open", page=1, limit=10)) == 1
    assert (await client.get_pull_request("acme", "demo", 2))["number"] == 2
    assert len(await client.list_labels("acme", "demo", page=1, limit=10)) == 1
    assert len(await client.list_tags("acme", "demo", page=1, limit=10)) == 1
    assert len(await client.list_releases("acme", "demo", page=1, limit=10)) == 1
    assert (await client.create_issue("acme", "demo", title="Hi", body="Body"))["number"] == 12
    assert (await client.update_issue("acme", "demo", 1, state="closed"))["state"] == "closed"
    assert (await client.create_issue_comment("acme", "demo", 1, "comment"))["id"] == 9
    assert (await client.create_pr_comment("acme", "demo", 1, "comment"))["id"] == 9
    assert isinstance(await client.add_labels("acme", "demo", 1, ["bug"]), dict)
    assert isinstance(await client.assign_issue("acme", "demo", 1, ["alice"]), dict)


@pytest.mark.asyncio
async def test_get_file_contents_blocks_oversized_payload(monkeypatch: pytest.MonkeyPatch) -> None:
    """File size limits are enforced before returning content."""
    monkeypatch.setenv("MAX_FILE_SIZE_BYTES", "5")
    reset_settings()

    client = GiteaClient(token="user-token")
    # Reported size (50) is above the 5-byte limit configured above.
    client._request = AsyncMock(  # type: ignore[method-assign]
        return_value={"size": 50, "content": "x", "encoding": "base64"}
    )

    with pytest.raises(GiteaError, match="exceeds limit"):
        await client.get_file_contents("acme", "demo", "big.bin")


# Ref/sha values that the argument models are expected to reject: path
# traversal, an absolute path, a NUL byte, and URL query/fragment delimiters.
_MALICIOUS_REFS = [
    "../../../x/y",
    "..",
    "/etc/passwd",
    "a\x00b",
    "a?b",
    "a#b",
]


@pytest.mark.parametrize("value", _MALICIOUS_REFS)
def test_file_tree_args_reject_traversal_ref(value: str) -> None:
    """Layer 1: FileTreeArgs.ref rejects traversal/unsafe values."""
    with pytest.raises(ValidationError):
        FileTreeArgs(owner="o", repo="r", ref=value)


@pytest.mark.parametrize("value", _MALICIOUS_REFS)
def test_commit_diff_args_reject_traversal_sha(value: str) -> None:
    """Layer 1: CommitDiffArgs.sha rejects traversal/unsafe values."""
    # ".." is shorter than the 7-char min_length; still rejected (length or ref check).
    with pytest.raises(ValidationError):
        CommitDiffArgs(owner="o", repo="r", sha=value)


@pytest.mark.parametrize("value", _MALICIOUS_REFS)
def test_compare_refs_args_reject_traversal_base(value: str) -> None:
    """Layer 1: CompareRefsArgs.base rejects traversal/unsafe values."""
    with pytest.raises(ValidationError):
        CompareRefsArgs(owner="o", repo="r", base=value, head="main")


@pytest.mark.parametrize("value", _MALICIOUS_REFS)
def test_compare_refs_args_reject_traversal_head(value: str) -> None:
    """Layer 1: CompareRefsArgs.head rejects traversal/unsafe values."""
    with pytest.raises(ValidationError):
        CompareRefsArgs(owner="o", repo="r", base="main", head=value)


@pytest.mark.asyncio
async def test_list_user_repositories_scopes_by_uid() -> None:
    """User-scoped listing resolves the uid and filters repo search by it."""
    client = GiteaClient(token="service-pat")
    captured: dict[str, Any] = {}

    async def fake_request(method: str, endpoint: str, **kwargs: Any) -> Any:
        """Serve the user lookup and record the params of the repo search."""
        if endpoint == "/api/v1/users/alice":
            return {"id": 7, "login": "alice"}
        if endpoint == "/api/v1/repos/search":
            captured["params"] = kwargs.get("params")
            # The non-dict entry checks that malformed items are filtered out.
            return {"ok": True, "data": [{"full_name": "alice/demo"}, "not-a-dict"]}
        return {}

    client._request = AsyncMock(side_effect=fake_request)  # type: ignore[method-assign]

    repos = await client.list_user_repositories("alice")

    assert captured["params"]["uid"] == 7
    assert repos == [{"full_name": "alice/demo"}]


@pytest.mark.asyncio
async def test_list_user_repositories_unknown_user_returns_empty() -> None:
    """A user that cannot be resolved yields an empty list, not an error."""
    client = GiteaClient(token="service-pat")

    async def fake_request(method: str, endpoint: str, **kwargs: Any) -> Any:
        """Answer every request with a payload that lacks an ``id``."""
        return {}  # no id field

    client._request = AsyncMock(side_effect=fake_request)  # type: ignore[method-assign]

    assert await client.list_user_repositories("ghost") == []


def test_git_refs_allow_slash_containing_refs() -> None:
    """Legitimate refs that contain '/' validate successfully."""
    tree = FileTreeArgs(owner="o", repo="r", ref="feature/foo")
    assert tree.ref == "feature/foo"

    compare = CompareRefsArgs(owner="o", repo="r", base="release/1.0", head="main")
    assert compare.base == "release/1.0"
    assert compare.head == "main"


def test_git_ref_length_bounds_still_enforced() -> None:
    """Field min/max length still applies alongside the AfterValidator."""
    with pytest.raises(ValidationError):
        FileTreeArgs(owner="o", repo="r", ref="")

    with pytest.raises(ValidationError):
        FileTreeArgs(owner="o", repo="r", ref="a" * 201)

    with pytest.raises(ValidationError):
        # sha below 7-char minimum
        CommitDiffArgs(owner="o", repo="r", sha="abc")


def test_layer2_encoding_confines_unsafe_chars_to_path_segment() -> None:
    """Layer 2 defense-in-depth: quote(..., safe='/') confines unsafe chars.

    A sha containing ``?``, ``#``, whitespace, or a backslash that
    hypothetically bypassed Layer 1 is percent-encoded so it cannot split off
    a query string, fragment, or extra path component -- it stays inside the
    single ref segment under the declared ``owner/repo``.

    (Note: ``..`` collapse is closed by Layer 1 validation, since ``quote``
    intentionally leaves ``.`` and ``/`` literal to preserve refs like
    ``feature/foo``.)
    """
    owner, repo = "alice", "repoA"
    prefix = f"/api/v1/repos/{owner}/{repo}/git/commits/"

    # One client is enough: it is only used to build (never send) requests.
    with Client(base_url="https://gitea.example.com") as http_client:
        for malicious_sha in ["abc?injected=1", "abc#frag", "abc def", "abc\\..\\evil"]:
            endpoint = (
                f"/api/v1/repos/{quote(owner, safe='')}/{quote(repo, safe='')}"
                f"/git/commits/{quote(malicious_sha, safe='/')}"
            )
            request = http_client.build_request("GET", endpoint)

            # Stays scoped to declared repo, no query/fragment broke out.
            assert request.url.path.startswith(prefix)
            assert request.url.query == b""
            assert request.url.fragment == ""
            # Exactly one ref segment after the prefix (no path splitting).
            assert "/" not in request.url.path[len(prefix) :]