Files
AegisGitea-MCP/src/aegis_gitea_mcp/server.py
2026-02-14 18:18:34 +01:00

595 lines
20 KiB
Python

"""Main MCP server implementation with hardened security controls."""
from __future__ import annotations
import asyncio
import json
import logging
import uuid
from collections.abc import AsyncGenerator, Awaitable, Callable
from typing import Any
from fastapi import FastAPI, HTTPException, Request, Response
from fastapi.responses import JSONResponse, PlainTextResponse, StreamingResponse
from pydantic import BaseModel, Field, ValidationError
from aegis_gitea_mcp.audit import get_audit_logger
from aegis_gitea_mcp.auth import get_validator
from aegis_gitea_mcp.automation import AutomationError, AutomationManager
from aegis_gitea_mcp.config import get_settings
from aegis_gitea_mcp.gitea_client import (
GiteaAuthenticationError,
GiteaAuthorizationError,
GiteaClient,
)
from aegis_gitea_mcp.logging_utils import configure_logging
from aegis_gitea_mcp.mcp_protocol import (
AVAILABLE_TOOLS,
MCPListToolsResponse,
MCPToolCallRequest,
MCPToolCallResponse,
get_tool_by_name,
)
from aegis_gitea_mcp.observability import get_metrics_registry, monotonic_seconds
from aegis_gitea_mcp.policy import PolicyError, get_policy_engine
from aegis_gitea_mcp.rate_limit import get_rate_limiter
from aegis_gitea_mcp.request_context import set_request_id
from aegis_gitea_mcp.security import sanitize_data
from aegis_gitea_mcp.tools.arguments import extract_repository, extract_target_path
from aegis_gitea_mcp.tools.read_tools import (
compare_refs_tool,
get_commit_diff_tool,
get_issue_tool,
get_pull_request_tool,
list_commits_tool,
list_issues_tool,
list_labels_tool,
list_pull_requests_tool,
list_releases_tool,
list_tags_tool,
search_code_tool,
)
from aegis_gitea_mcp.tools.repository import (
get_file_contents_tool,
get_file_tree_tool,
get_repository_info_tool,
list_repositories_tool,
)
from aegis_gitea_mcp.tools.write_tools import (
add_labels_tool,
assign_issue_tool,
create_issue_comment_tool,
create_issue_tool,
create_pr_comment_tool,
update_issue_tool,
)
logger = logging.getLogger(__name__)

# Application instance; every route and middleware in this module registers on it.
app = FastAPI(
    version="0.2.0",
    title="AegisGitea MCP Server",
    description="Security-first MCP server for controlled AI access to self-hosted Gitea",
)
class AutomationWebhookRequest(BaseModel):
    """JSON body accepted by ``POST /automation/webhook``.

    ``event_type`` is required (1-128 characters). ``payload`` defaults to an empty
    dict, and ``repository`` is optional.
    """

    event_type: str = Field(min_length=1, max_length=128)
    payload: dict[str, Any] = Field(default_factory=dict)
    repository: str | None = None
class AutomationJobRequest(BaseModel):
    """JSON body accepted by ``POST /automation/jobs/run``.

    Names the job and its target repository (``owner``/``repo``). The optional
    ``finding_title``/``finding_body`` are handed to the automation manager.
    """

    job_name: str = Field(min_length=1, max_length=128)
    owner: str = Field(min_length=1, max_length=100)
    repo: str = Field(min_length=1, max_length=100)
    finding_title: str | None = Field(None, max_length=256)
    finding_body: str | None = Field(None, max_length=10_000)
# Common signature of every tool implementation: an open Gitea client and the raw
# tool arguments go in; an awaitable result dict comes out.
ToolHandler = Callable[[GiteaClient, dict[str, Any]], Awaitable[dict[str, Any]]]

# Baseline read tools
_BASELINE_READ_HANDLERS: dict[str, ToolHandler] = {
    "list_repositories": list_repositories_tool,
    "get_repository_info": get_repository_info_tool,
    "get_file_tree": get_file_tree_tool,
    "get_file_contents": get_file_contents_tool,
}

# Expanded read tools
_EXPANDED_READ_HANDLERS: dict[str, ToolHandler] = {
    "search_code": search_code_tool,
    "list_commits": list_commits_tool,
    "get_commit_diff": get_commit_diff_tool,
    "compare_refs": compare_refs_tool,
    "list_issues": list_issues_tool,
    "get_issue": get_issue_tool,
    "list_pull_requests": list_pull_requests_tool,
    "get_pull_request": get_pull_request_tool,
    "list_labels": list_labels_tool,
    "list_tags": list_tags_tool,
    "list_releases": list_releases_tool,
}

# Write-mode tools
_WRITE_HANDLERS: dict[str, ToolHandler] = {
    "create_issue": create_issue_tool,
    "update_issue": update_issue_tool,
    "create_issue_comment": create_issue_comment_tool,
    "create_pr_comment": create_pr_comment_tool,
    "add_labels": add_labels_tool,
    "assign_issue": assign_issue_tool,
}

# Dispatch table from MCP tool name to implementation. A tool that exists in the
# catalogue but is missing here is reported as HTTP 500 by _execute_tool_call.
TOOL_HANDLERS: dict[str, ToolHandler] = {
    **_BASELINE_READ_HANDLERS,
    **_EXPANDED_READ_HANDLERS,
    **_WRITE_HANDLERS,
}
@app.middleware("http")
async def request_context_middleware(
    request: Request,
    call_next: Callable[[Request], Awaitable[Response]],
) -> Response:
    """Attach request correlation context and collect request metrics.

    The request id is chosen in this order of preference:
    - an id already stored on ``request.state`` by an outer middleware;
    - the client's ``X-Request-ID`` header;
    - a fresh UUID4.

    The id is stored on ``request.state`` and published through ``set_request_id``.
    It is also echoed in the ``X-Request-ID`` response header.

    Every request that reaches this middleware has its method, path, status and
    duration logged at DEBUG and recorded in the metrics registry.
    """
    # Starlette runs the most recently registered HTTP middleware outermost, so an
    # outer middleware may already have assigned an id. Reuse it so a single request
    # never ends up with two different correlation ids.
    request_id = (
        getattr(request.state, "request_id", None)
        or request.headers.get("x-request-id")
        or str(uuid.uuid4())
    )
    set_request_id(request_id)
    request.state.request_id = request_id

    started_at = monotonic_seconds()
    # Pessimistic default: if call_next raises, the request is recorded as a 500.
    status_code = 500
    try:
        response = await call_next(request)
        status_code = response.status_code
        response.headers["X-Request-ID"] = request_id
        return response
    finally:
        # Clamp to zero in case the clock reading goes backwards.
        duration = max(monotonic_seconds() - started_at, 0.0)
        logger.debug(
            "request_completed",
            extra={
                "method": request.method,
                "path": request.url.path,
                "duration_seconds": duration,
                "status_code": status_code,
            },
        )
        metrics = get_metrics_registry()
        metrics.record_http_request(request.method, request.url.path, status_code)
@app.middleware("http")
async def authenticate_and_rate_limit(
    request: Request,
    call_next: Callable[[Request], Awaitable[Response]],
) -> Response:
    """Apply rate limiting and API-key authentication to MCP and automation paths.

    Which paths are checked:
    - ``/`` and ``/health`` are always public.
    - ``/metrics`` is public while metrics are enabled.
    - Any path outside ``/mcp/`` and ``/automation/`` passes through untouched.

    For protected paths the request is rate-limited first (HTTP 429) and then
    authenticated (HTTP 401). ``/mcp/tools`` is the exception: it is rate-limited
    but public.

    The API key comes from ``Authorization: Bearer``. ``/mcp/tool/call`` and
    ``/mcp/sse`` also accept an ``api_key`` query parameter.

    Rejection responses carry the request id in both the JSON body and the
    ``X-Request-ID`` header.
    """
    settings = get_settings()
    path = request.url.path

    if path in {"/", "/health"}:
        return await call_next(request)
    if path == "/metrics" and settings.metrics_enabled:
        # Metrics endpoint is intentionally left unauthenticated for pull-based scraping.
        return await call_next(request)
    if not path.startswith(("/mcp/", "/automation/")):
        return await call_next(request)

    # Starlette runs the most recently registered HTTP middleware outermost. This
    # middleware is registered after request_context_middleware, so it executes
    # first and no request id exists yet. Assign one here so that:
    # - rejection responses, which never reach the inner middleware, report a real
    #   id instead of "-";
    # - anything logged during validation can pick the id up.
    request_id = (
        getattr(request.state, "request_id", None)
        or request.headers.get("x-request-id")
        or str(uuid.uuid4())
    )
    request.state.request_id = request_id
    set_request_id(request_id)
    id_header = {"X-Request-ID": request_id}

    validator = get_validator()
    limiter = get_rate_limiter()
    client_ip = request.client.host if request.client else "unknown"
    user_agent = request.headers.get("user-agent", "unknown")

    api_key = validator.extract_bearer_token(request.headers.get("authorization"))
    if not api_key and path in {"/mcp/tool/call", "/mcp/sse"}:
        api_key = request.query_params.get("api_key")

    # Rate limiting runs before authentication so unauthenticated floods are throttled too.
    rate_limit = limiter.check(client_ip=client_ip, token=api_key)
    if not rate_limit.allowed:
        return JSONResponse(
            status_code=429,
            content={
                "error": "Rate limit exceeded",
                "message": rate_limit.reason,
                "request_id": request_id,
            },
            headers=id_header,
        )

    # Mixed mode: tool discovery remains public to preserve MCP client compatibility.
    if path == "/mcp/tools":
        return await call_next(request)

    is_valid, error_message = validator.validate_api_key(api_key, client_ip, user_agent)
    if not is_valid:
        return JSONResponse(
            status_code=401,
            content={
                "error": "Authentication failed",
                "message": error_message,
                "detail": "Provide Authorization: Bearer <api-key> or ?api_key=<api-key>",
                "request_id": request_id,
            },
            headers=id_header,
        )
    return await call_next(request)
@app.on_event("startup")
async def startup_event() -> None:
    """Configure logging and validate critical dependencies before serving traffic.

    Steps, in order:
    1. Configure logging from settings.
    2. Log the effective configuration.
    3. Load the policy engine.
    4. Probe Gitea with the configured token, unless ``startup_validate_gitea`` is
       off or ``environment`` is ``"test"``.

    Raises:
        PolicyError: the policy could not be loaded (re-raised after logging).
        RuntimeError: the Gitea probe failed (authentication, authorization or
            connectivity). The underlying exception is chained as ``__cause__``.
    """
    settings = get_settings()
    configure_logging(settings.log_level)

    logger.info("server_starting")
    effective_config = {
        "host": settings.mcp_host,
        "port": settings.mcp_port,
        "gitea_url": settings.gitea_base_url,
        "auth_enabled": settings.auth_enabled,
        "write_mode": settings.write_mode,
        "metrics_enabled": settings.metrics_enabled,
    }
    logger.info("server_configuration", extra=effective_config)

    # Fail fast: surface policy parse errors now instead of on the first request.
    try:
        get_policy_engine()
    except PolicyError:
        logger.error("policy_load_failed")
        raise

    if not settings.startup_validate_gitea or settings.environment == "test":
        return

    try:
        async with GiteaClient() as gitea:
            bot_account = await gitea.get_current_user()
            logger.info(
                "gitea_connected", extra={"bot_user": bot_account.get("login", "unknown")}
            )
    except GiteaAuthenticationError as exc:
        logger.error("gitea_connection_failed_authentication")
        raise RuntimeError(
            "Startup validation failed: Gitea authentication was rejected. Check GITEA_TOKEN."
        ) from exc
    except GiteaAuthorizationError as exc:
        logger.error("gitea_connection_failed_authorization")
        raise RuntimeError(
            "Startup validation failed: Gitea token lacks permission for /api/v1/user."
        ) from exc
    except Exception as exc:
        logger.error("gitea_connection_failed")
        raise RuntimeError("Startup validation failed: unable to connect to Gitea.") from exc
@app.on_event("shutdown")
async def shutdown_event() -> None:
    """Emit a ``server_stopping`` log record when the application shuts down."""
    logger.info("server_stopping")
@app.get("/")
async def root() -> dict[str, Any]:
    """Describe the server: name, version, running status and MCP protocol version."""
    metadata: dict[str, Any] = {
        "name": "AegisGitea MCP Server",
        "version": "0.2.0",
        "status": "running",
        "mcp_version": "1.0",
    }
    return metadata
@app.get("/health")
async def health() -> dict[str, str]:
    """Liveness probe; always reports ``healthy`` (the auth middleware exempts this path)."""
    status_report: dict[str, str] = {"status": "healthy"}
    return status_report
@app.get("/metrics")
async def metrics() -> PlainTextResponse:
    """Serve the metrics registry in Prometheus text exposition format 0.0.4.

    Raises:
        HTTPException: 404 when metrics are disabled in settings.
    """
    if get_settings().metrics_enabled:
        exposition = get_metrics_registry().render_prometheus()
        return PlainTextResponse(content=exposition, media_type="text/plain; version=0.0.4")
    raise HTTPException(status_code=404, detail="Metrics endpoint disabled")
@app.post("/automation/webhook")
async def automation_webhook(request: AutomationWebhookRequest) -> JSONResponse:
    """Pass a webhook event to the automation manager and return its result.

    Raises:
        HTTPException: 403 carrying the ``AutomationError`` message when the manager
            rejects the event.
    """
    manager = AutomationManager()
    try:
        outcome = await manager.handle_webhook(
            event_type=request.event_type,
            payload=request.payload,
            repository=request.repository,
        )
    except AutomationError as exc:
        raise HTTPException(status_code=403, detail=str(exc)) from exc
    return JSONResponse(content={"success": True, "result": outcome})
@app.post("/automation/jobs/run")
async def automation_run_job(request: AutomationJobRequest) -> JSONResponse:
    """Run a named automation job against ``owner/repo`` and return its result.

    Raises:
        HTTPException: 403 carrying the ``AutomationError`` message when the manager
            refuses or fails the job.
    """
    manager = AutomationManager()
    try:
        outcome = await manager.run_job(
            job_name=request.job_name,
            owner=request.owner,
            repo=request.repo,
            finding_title=request.finding_title,
            finding_body=request.finding_body,
        )
    except AutomationError as exc:
        raise HTTPException(status_code=403, detail=str(exc)) from exc
    return JSONResponse(content={"success": True, "result": outcome})
@app.get("/mcp/tools")
async def list_tools() -> JSONResponse:
    """Return the static MCP tool catalogue, serialized with field aliases.

    The auth middleware lets this path through without an API key.
    """
    catalogue = MCPListToolsResponse(tools=AVAILABLE_TOOLS).model_dump(by_alias=True)
    return JSONResponse(content=catalogue)
async def _execute_tool_call(
    tool_name: str, arguments: dict[str, Any], correlation_id: str
) -> dict[str, Any]:
    """Run one tool call through lookup, policy authorization and output sanitization.

    Args:
        tool_name: Name of the MCP tool to execute.
        arguments: Raw tool arguments, passed to the handler unchanged.
        correlation_id: Audit correlation id, recorded when the policy denies the call.

    Returns:
        The handler's result. It is passed through ``sanitize_data`` unless
        ``secret_detection_mode`` is ``"off"``.

    Raises:
        HTTPException: 404 for an unknown tool, 500 when a known tool has no handler,
            and 403 when the policy engine denies the call. Exceptions raised by the
            handler itself propagate unchanged.
    """
    settings = get_settings()
    audit = get_audit_logger()
    registry = get_metrics_registry()

    definition = get_tool_by_name(tool_name)
    if not definition:
        raise HTTPException(status_code=404, detail=f"Tool '{tool_name}' not found")

    implementation = TOOL_HANDLERS.get(tool_name)
    if not implementation:
        raise HTTPException(
            status_code=500, detail=f"Tool '{tool_name}' has no handler implementation"
        )

    # Authorization is scoped to the repository and path named in the arguments.
    repository = extract_repository(arguments)
    target_path = extract_target_path(arguments)
    verdict = get_policy_engine().authorize(
        tool_name=tool_name,
        is_write=definition.write_operation,
        repository=repository,
        target_path=target_path,
    )
    if not verdict.allowed:
        audit.log_access_denied(
            tool_name=tool_name,
            repository=repository,
            reason=verdict.reason,
            correlation_id=correlation_id,
        )
        raise HTTPException(status_code=403, detail=f"Policy denied request: {verdict.reason}")

    # Only authorized calls are timed. The outcome stays "error" unless the
    # handler and sanitization both complete.
    started_at = monotonic_seconds()
    outcome = "error"
    try:
        async with GiteaClient() as gitea:
            payload = await implementation(gitea, arguments)
        detection_mode = settings.secret_detection_mode
        if detection_mode != "off":
            # Security decision: sanitize outbound payloads to prevent accidental secret exfiltration.
            payload = sanitize_data(payload, mode=detection_mode)
        outcome = "success"
        return payload
    finally:
        elapsed = max(monotonic_seconds() - started_at, 0.0)
        registry.record_tool_call(tool_name, outcome, elapsed)
@app.post("/mcp/tool/call")
async def call_tool(request: MCPToolCallRequest) -> JSONResponse:
    """Execute an MCP tool call over the plain HTTP transport.

    Audit records:
    - The invocation (tool name and arguments) is always audited before execution.
    - A second record captures the outcome under the same correlation id.

    Returns:
        ``MCPToolCallResponse`` JSON. On success it has ``success=True`` and the
        tool result. Unexpected failures return HTTP 500 with ``success=False``;
        that message includes the exception text only when
        ``expose_error_details`` is enabled.

    Raises:
        HTTPException: re-raised unchanged for unknown tools (404), missing
            handlers (500) and policy denials (403). Argument validation failures
            raise 400.
    """
    settings = get_settings()
    audit = get_audit_logger()

    # Always audit the invocation together with its arguments. A client-supplied
    # correlation id is reused for grouping, but it must not suppress this record.
    if request.correlation_id:
        correlation_id = request.correlation_id
        # NOTE(review): assumes log_tool_invocation accepts `params` and
        # `correlation_id` together; each is used separately elsewhere in this module.
        audit.log_tool_invocation(
            tool_name=request.tool,
            params=request.arguments,
            correlation_id=correlation_id,
        )
    else:
        correlation_id = audit.log_tool_invocation(
            tool_name=request.tool,
            params=request.arguments,
        )

    try:
        result = await _execute_tool_call(request.tool, request.arguments, correlation_id)
        audit.log_tool_invocation(
            tool_name=request.tool,
            correlation_id=correlation_id,
            result_status="success",
        )
        return JSONResponse(
            content=MCPToolCallResponse(
                success=True,
                result=result,
                correlation_id=correlation_id,
            ).model_dump()
        )
    except HTTPException as exc:
        audit.log_tool_invocation(
            tool_name=request.tool,
            correlation_id=correlation_id,
            result_status="error",
            error=str(exc.detail),
        )
        raise
    except ValidationError as exc:
        error_message = "Invalid tool arguments"
        if settings.expose_error_details:
            error_message = f"{error_message}: {exc}"
        audit.log_tool_invocation(
            tool_name=request.tool,
            correlation_id=correlation_id,
            result_status="error",
            error="validation_error",
        )
        raise HTTPException(status_code=400, detail=error_message) from exc
    except Exception as exc:
        # Security decision: raw exception text reaches the client only when
        # expose_error_details is explicitly enabled. Stack traces never do.
        error_message = "Internal server error"
        if settings.expose_error_details:
            error_message = f"{error_message}: {exc}"
        audit.log_tool_invocation(
            tool_name=request.tool,
            correlation_id=correlation_id,
            result_status="error",
            error="internal_error",
        )
        logger.exception("tool_execution_failed")
        return JSONResponse(
            status_code=500,
            content=MCPToolCallResponse(
                success=False,
                error=error_message,
                correlation_id=correlation_id,
            ).model_dump(),
        )
@app.get("/mcp/sse")
async def sse_endpoint(request: Request) -> StreamingResponse:
    """Open the Server-Sent Events stream used by the MCP SSE transport.

    The stream behaves as follows:
    - It first emits a compact ``connected`` event.
    - It then emits a ``heartbeat`` event every 30 seconds until the client
      disconnects.
    - Errors inside the stream are logged and end it quietly.
    """

    async def event_stream() -> AsyncGenerator[str, None]:
        greeting = {"event": "connected", "server": "AegisGitea MCP", "version": "0.2.0"}
        yield "data: " + json.dumps(greeting, separators=(",", ":")) + "\n\n"
        try:
            while not await request.is_disconnected():
                yield 'data: {"event":"heartbeat"}\n\n'
                await asyncio.sleep(30)
        except Exception:
            logger.exception("sse_stream_error")

    # Disable caching and proxy buffering so each event is delivered immediately.
    stream_headers = {
        "Cache-Control": "no-cache",
        "Connection": "keep-alive",
        "X-Accel-Buffering": "no",
    }
    return StreamingResponse(
        event_stream(),
        media_type="text/event-stream",
        headers=stream_headers,
    )
async def _handle_sse_tool_call(
    params: dict[str, Any], message_id: Any, settings: Any, audit: Any
) -> JSONResponse:
    """Run a JSON-RPC ``tools/call`` request and wrap the outcome in a JSON-RPC envelope.

    Args:
        params: The request's ``params`` object (``name`` and optional ``arguments``).
        message_id: JSON-RPC ``id`` echoed back in the response.
        settings: Application settings; only ``expose_error_details`` is read.
        audit: Audit logger used for the invocation and outcome records.

    Returns:
        On success, a ``result`` envelope whose text content is the JSON-encoded
        tool result. On failure, an ``error`` envelope with code -32603. Error
        messages reveal no more than ``/mcp/tool/call`` does:
        - the ``HTTPException`` detail for unknown tools and policy denials;
        - otherwise a fixed text, with exception details appended only when
          ``expose_error_details`` is enabled.
    """
    tool_name = params.get("name")
    # A JSON null ``arguments`` is treated like an omitted one.
    tool_args = params.get("arguments") or {}
    correlation_id = audit.log_tool_invocation(tool_name=tool_name, params=tool_args)
    try:
        result = await _execute_tool_call(str(tool_name), tool_args, correlation_id)
        audit.log_tool_invocation(
            tool_name=str(tool_name),
            correlation_id=correlation_id,
            result_status="success",
        )
        return JSONResponse(
            content={
                "jsonrpc": "2.0",
                "id": message_id,
                "result": {"content": [{"type": "text", "text": json.dumps(result)}]},
            }
        )
    except HTTPException as exc:
        # Lookup/policy failures: the HTTP transport already returns this detail
        # to clients verbatim.
        audit_error = str(exc.detail)
        message = str(exc.detail)
    except ValidationError as exc:
        audit_error = "validation_error"
        message = "Invalid tool arguments"
        if settings.expose_error_details:
            message = f"{message}: {exc}"
    except Exception as exc:
        # Keep raw exception text out of the audit trail; the traceback goes to
        # the application log instead.
        logger.exception("tool_execution_failed")
        audit_error = "internal_error"
        message = "Internal server error"
        if settings.expose_error_details:
            message = f"{message}: {exc}"

    audit.log_tool_invocation(
        tool_name=str(tool_name),
        correlation_id=correlation_id,
        result_status="error",
        error=audit_error,
    )
    return JSONResponse(
        content={
            "jsonrpc": "2.0",
            "id": message_id,
            "error": {"code": -32603, "message": message},
        }
    )


@app.post("/mcp/sse")
async def sse_message_handler(request: Request) -> JSONResponse:
    """Handle JSON-RPC messages POSTed to the MCP SSE transport.

    Supported messages are ``initialize``, ``tools/list`` and ``tools/call``. The
    message type is read from ``type``, or from ``method`` when ``type`` is absent.

    Other responses:
    - Notifications (``notifications/*``) get an empty JSON object.
    - Any other message is acknowledged.
    - A body that cannot be parsed or processed yields HTTP 400. Its message
      includes exception details only when ``expose_error_details`` is enabled.
    """
    settings = get_settings()
    audit = get_audit_logger()
    try:
        body = await request.json()
        message_type = body.get("type") or body.get("method")
        message_id = body.get("id")

        if message_type == "initialize":
            return JSONResponse(
                content={
                    "jsonrpc": "2.0",
                    "id": message_id,
                    "result": {
                        "protocolVersion": "2024-11-05",
                        "capabilities": {"tools": {}},
                        "serverInfo": {"name": "AegisGitea MCP", "version": "0.2.0"},
                    },
                }
            )

        if message_type == "tools/list":
            response = MCPListToolsResponse(tools=AVAILABLE_TOOLS)
            return JSONResponse(
                content={
                    "jsonrpc": "2.0",
                    "id": message_id,
                    "result": response.model_dump(by_alias=True),
                }
            )

        if message_type == "tools/call":
            # A JSON null ``params`` is treated like an omitted one.
            return await _handle_sse_tool_call(
                body.get("params") or {}, message_id, settings, audit
            )

        if isinstance(message_type, str) and message_type.startswith("notifications/"):
            return JSONResponse(content={})

        return JSONResponse(
            content={"jsonrpc": "2.0", "id": message_id, "result": {"acknowledged": True}}
        )
    except Exception as exc:
        logger.exception("sse_message_handler_error")
        message = "Invalid message format"
        if settings.expose_error_details:
            message = f"{message}: {exc}"
        return JSONResponse(status_code=400, content={"error": message})
def main() -> None:
    """Launch the ASGI app under uvicorn with host, port and log level from settings."""
    # Deferred import: uvicorn is only needed when the server is launched from here.
    import uvicorn

    settings = get_settings()
    uvicorn.run(
        "aegis_gitea_mcp.server:app",
        reload=False,
        log_level=settings.log_level.lower(),
        port=settings.mcp_port,
        host=settings.mcp_host,
    )


# Allow running the module directly (python -m / python server.py).
if __name__ == "__main__":
    main()