595 lines
20 KiB
Python
595 lines
20 KiB
Python
"""Main MCP server implementation with hardened security controls."""
|
|
|
|
from __future__ import annotations
|
|
|
|
import asyncio
|
|
import json
|
|
import logging
|
|
import uuid
|
|
from collections.abc import AsyncGenerator, Awaitable, Callable
|
|
from typing import Any
|
|
|
|
from fastapi import FastAPI, HTTPException, Request, Response
|
|
from fastapi.responses import JSONResponse, PlainTextResponse, StreamingResponse
|
|
from pydantic import BaseModel, Field, ValidationError
|
|
|
|
from aegis_gitea_mcp.audit import get_audit_logger
|
|
from aegis_gitea_mcp.auth import get_validator
|
|
from aegis_gitea_mcp.automation import AutomationError, AutomationManager
|
|
from aegis_gitea_mcp.config import get_settings
|
|
from aegis_gitea_mcp.gitea_client import (
|
|
GiteaAuthenticationError,
|
|
GiteaAuthorizationError,
|
|
GiteaClient,
|
|
)
|
|
from aegis_gitea_mcp.logging_utils import configure_logging
|
|
from aegis_gitea_mcp.mcp_protocol import (
|
|
AVAILABLE_TOOLS,
|
|
MCPListToolsResponse,
|
|
MCPToolCallRequest,
|
|
MCPToolCallResponse,
|
|
get_tool_by_name,
|
|
)
|
|
from aegis_gitea_mcp.observability import get_metrics_registry, monotonic_seconds
|
|
from aegis_gitea_mcp.policy import PolicyError, get_policy_engine
|
|
from aegis_gitea_mcp.rate_limit import get_rate_limiter
|
|
from aegis_gitea_mcp.request_context import set_request_id
|
|
from aegis_gitea_mcp.security import sanitize_data
|
|
from aegis_gitea_mcp.tools.arguments import extract_repository, extract_target_path
|
|
from aegis_gitea_mcp.tools.read_tools import (
|
|
compare_refs_tool,
|
|
get_commit_diff_tool,
|
|
get_issue_tool,
|
|
get_pull_request_tool,
|
|
list_commits_tool,
|
|
list_issues_tool,
|
|
list_labels_tool,
|
|
list_pull_requests_tool,
|
|
list_releases_tool,
|
|
list_tags_tool,
|
|
search_code_tool,
|
|
)
|
|
from aegis_gitea_mcp.tools.repository import (
|
|
get_file_contents_tool,
|
|
get_file_tree_tool,
|
|
get_repository_info_tool,
|
|
list_repositories_tool,
|
|
)
|
|
from aegis_gitea_mcp.tools.write_tools import (
|
|
add_labels_tool,
|
|
assign_issue_tool,
|
|
create_issue_comment_tool,
|
|
create_issue_tool,
|
|
create_pr_comment_tool,
|
|
update_issue_tool,
|
|
)
|
|
|
|
# Module-level logger named after this module.
logger = logging.getLogger(__name__)

# ASGI application object. Middleware, lifecycle hooks and routes below are
# all registered on this instance through its decorators.
app = FastAPI(
    title="AegisGitea MCP Server",
    description="Security-first MCP server for controlled AI access to self-hosted Gitea",
    version="0.2.0",
)
|
|
|
|
|
|
class AutomationWebhookRequest(BaseModel):
    """Body schema for ``POST /automation/webhook``."""

    # Required event identifier, 1 to 128 characters.
    event_type: str = Field(..., min_length=1, max_length=128)
    # Event data; each instance gets its own empty dict when omitted.
    payload: dict[str, Any] = Field(default_factory=dict)
    # Optional repository reference; no format is enforced at this layer.
    repository: str | None = None
|
|
|
|
|
|
class AutomationJobRequest(BaseModel):
    """Body schema for ``POST /automation/jobs/run``."""

    # Required job identifier, 1 to 128 characters.
    job_name: str = Field(..., min_length=1, max_length=128)
    # Required repository owner, 1 to 100 characters.
    owner: str = Field(..., min_length=1, max_length=100)
    # Required repository name, 1 to 100 characters.
    repo: str = Field(..., min_length=1, max_length=100)
    # Optional finding title, capped at 256 characters.
    finding_title: str | None = Field(None, max_length=256)
    # Optional finding body, capped at 10,000 characters.
    finding_body: str | None = Field(None, max_length=10_000)
|
|
|
|
|
|
# Signature shared by every tool implementation: an open Gitea client plus the
# raw argument mapping in, an awaitable result mapping out.
ToolHandler = Callable[[GiteaClient, dict[str, Any]], Awaitable[dict[str, Any]]]

# Dispatch table from MCP tool name to its implementation. A tool is only
# callable when it has an entry here; ``_execute_tool_call`` answers 500 for a
# known tool that is missing from this table.
TOOL_HANDLERS: dict[str, ToolHandler] = {
    # --- Baseline read tools ---
    "list_repositories": list_repositories_tool,
    "get_repository_info": get_repository_info_tool,
    "get_file_tree": get_file_tree_tool,
    "get_file_contents": get_file_contents_tool,
    # --- Expanded read tools ---
    "search_code": search_code_tool,
    "list_commits": list_commits_tool,
    "get_commit_diff": get_commit_diff_tool,
    "compare_refs": compare_refs_tool,
    "list_issues": list_issues_tool,
    "get_issue": get_issue_tool,
    "list_pull_requests": list_pull_requests_tool,
    "get_pull_request": get_pull_request_tool,
    "list_labels": list_labels_tool,
    "list_tags": list_tags_tool,
    "list_releases": list_releases_tool,
    # --- Write-mode tools ---
    "create_issue": create_issue_tool,
    "update_issue": update_issue_tool,
    "create_issue_comment": create_issue_comment_tool,
    "create_pr_comment": create_pr_comment_tool,
    "add_labels": add_labels_tool,
    "assign_issue": assign_issue_tool,
}
|
|
|
|
|
|
@app.middleware("http")
async def request_context_middleware(
    request: Request,
    call_next: Callable[[Request], Awaitable[Response]],
) -> Response:
    """Tag the request with a correlation ID and record per-request metrics.

    The ID is taken from the incoming ``x-request-id`` header when present and
    non-empty; otherwise a random UUID4 is generated. It is stored in the
    request context, on ``request.state.request_id`` and echoed back in the
    ``X-Request-ID`` response header.

    Whether the downstream handler returns or raises, a debug log line and an
    HTTP metric are emitted with the observed status code.
    """
    supplied_id = request.headers.get("x-request-id")
    request_id = supplied_id if supplied_id else str(uuid.uuid4())
    set_request_id(request_id)
    request.state.request_id = request_id

    t_start = monotonic_seconds()
    # Remains 500 when the downstream call raises before yielding a response.
    observed_status = 500

    try:
        downstream = await call_next(request)
        observed_status = downstream.status_code
        downstream.headers["X-Request-ID"] = request_id
        return downstream
    finally:
        # Clamp to zero so a clock anomaly can never yield a negative duration.
        elapsed = max(monotonic_seconds() - t_start, 0.0)
        log_fields = {
            "method": request.method,
            "path": request.url.path,
            "duration_seconds": elapsed,
            "status_code": observed_status,
        }
        logger.debug("request_completed", extra=log_fields)
        get_metrics_registry().record_http_request(
            request.method, request.url.path, observed_status
        )
|
|
|
|
|
|
@app.middleware("http")
async def authenticate_and_rate_limit(
    request: Request,
    call_next: Callable[[Request], Awaitable[Response]],
) -> Response:
    """Rate-limit and authenticate requests to ``/mcp/*`` and ``/automation/*``.

    Every other path is passed straight through. For protected paths the order
    is:

    1. Extract the API key from the ``Authorization`` bearer header, falling
       back to the ``api_key`` query parameter on ``/mcp/tool/call`` and
       ``/mcp/sse`` only.
    2. Apply the rate limiter; a rejected request gets HTTP 429.
    3. Let ``/mcp/tools`` through without authentication (it is still rate
       limited by step 2).
    4. Validate the API key; a rejected request gets HTTP 401.
    """
    settings = get_settings()
    path = request.url.path

    # Public informational endpoints.
    if path in {"/", "/health"}:
        return await call_next(request)

    # Metrics endpoint is intentionally left unauthenticated for pull-based scraping.
    if path == "/metrics" and settings.metrics_enabled:
        return await call_next(request)

    # Anything outside the protected prefixes is not this middleware's concern.
    if not path.startswith(("/mcp/", "/automation/")):
        return await call_next(request)

    validator = get_validator()
    limiter = get_rate_limiter()

    if request.client:
        client_ip = request.client.host
    else:
        client_ip = "unknown"
    user_agent = request.headers.get("user-agent", "unknown")

    api_key = validator.extract_bearer_token(request.headers.get("authorization"))
    if not api_key and path in {"/mcp/tool/call", "/mcp/sse"}:
        # Query-string fallback, accepted on these two endpoints only.
        api_key = request.query_params.get("api_key")

    rate_limit = limiter.check(client_ip=client_ip, token=api_key)
    if not rate_limit.allowed:
        throttled_body = {
            "error": "Rate limit exceeded",
            "message": rate_limit.reason,
            "request_id": getattr(request.state, "request_id", "-"),
        }
        return JSONResponse(status_code=429, content=throttled_body)

    # Mixed mode: tool discovery remains public to preserve MCP client compatibility.
    if path == "/mcp/tools":
        return await call_next(request)

    is_valid, error_message = validator.validate_api_key(api_key, client_ip, user_agent)
    if is_valid:
        return await call_next(request)

    rejected_body = {
        "error": "Authentication failed",
        "message": error_message,
        "detail": "Provide Authorization: Bearer <api-key> or ?api_key=<api-key>",
        "request_id": getattr(request.state, "request_id", "-"),
    }
    return JSONResponse(status_code=401, content=rejected_body)
|
|
|
|
|
|
@app.on_event("startup")
async def startup_event() -> None:
    """Configure logging and run fail-fast checks before serving traffic.

    Steps:
      * configure logging from ``settings.log_level`` and log the effective
        configuration;
      * load the policy engine, re-raising ``PolicyError`` so startup aborts;
      * unless disabled or running in the ``test`` environment, fetch the
        current Gitea user to verify connectivity and credentials.

    Raises:
        PolicyError: the policy engine could not be loaded.
        RuntimeError: the Gitea connectivity check failed; the original
            exception is chained as the cause.
    """
    settings = get_settings()
    configure_logging(settings.log_level)

    logger.info("server_starting")
    config_snapshot = {
        "host": settings.mcp_host,
        "port": settings.mcp_port,
        "gitea_url": settings.gitea_base_url,
        "auth_enabled": settings.auth_enabled,
        "write_mode": settings.write_mode,
        "metrics_enabled": settings.metrics_enabled,
    }
    logger.info("server_configuration", extra=config_snapshot)

    # Fail-fast policy parse errors at startup.
    try:
        get_policy_engine()
    except PolicyError:
        logger.error("policy_load_failed")
        raise

    # Skip the live Gitea probe when it is switched off or under test.
    if not settings.startup_validate_gitea or settings.environment == "test":
        return

    try:
        async with GiteaClient() as gitea:
            bot = await gitea.get_current_user()
            logger.info("gitea_connected", extra={"bot_user": bot.get("login", "unknown")})
    except GiteaAuthenticationError as auth_exc:
        logger.error("gitea_connection_failed_authentication")
        raise RuntimeError(
            "Startup validation failed: Gitea authentication was rejected. Check GITEA_TOKEN."
        ) from auth_exc
    except GiteaAuthorizationError as authz_exc:
        logger.error("gitea_connection_failed_authorization")
        raise RuntimeError(
            "Startup validation failed: Gitea token lacks permission for /api/v1/user."
        ) from authz_exc
    except Exception as other_exc:
        logger.error("gitea_connection_failed")
        raise RuntimeError(
            "Startup validation failed: unable to connect to Gitea."
        ) from other_exc
|
|
|
|
|
|
@app.on_event("shutdown")
async def shutdown_event() -> None:
    """Emit a single ``server_stopping`` log record when the app shuts down."""
    logger.info("server_stopping")
|
|
|
|
|
|
@app.get("/")
async def root() -> dict[str, Any]:
    """Return static identification metadata for this server."""
    server_info: dict[str, Any] = {
        "name": "AegisGitea MCP Server",
        "version": "0.2.0",
        "status": "running",
        "mcp_version": "1.0",
    }
    return server_info
|
|
|
|
|
|
@app.get("/health")
async def health() -> dict[str, str]:
    """Liveness probe: always answers with a fixed ``healthy`` status."""
    return dict(status="healthy")
|
|
|
|
|
|
@app.get("/metrics")
async def metrics() -> PlainTextResponse:
    """Serve the metrics registry in Prometheus text exposition format.

    Raises:
        HTTPException: 404 when ``settings.metrics_enabled`` is false.
    """
    if get_settings().metrics_enabled:
        rendered = get_metrics_registry().render_prometheus()
        return PlainTextResponse(content=rendered, media_type="text/plain; version=0.0.4")

    raise HTTPException(status_code=404, detail="Metrics endpoint disabled")
|
|
|
|
|
|
@app.post("/automation/webhook")
async def automation_webhook(request: AutomationWebhookRequest) -> JSONResponse:
    """Forward a validated webhook event to the automation manager.

    Returns:
        ``{"success": True, "result": ...}`` with the manager's result.

    Raises:
        HTTPException: 403 carrying the ``AutomationError`` message when the
            manager rejects the event.
    """
    manager = AutomationManager()
    try:
        outcome = await manager.handle_webhook(
            event_type=request.event_type,
            payload=request.payload,
            repository=request.repository,
        )
    except AutomationError as exc:
        raise HTTPException(status_code=403, detail=str(exc)) from exc
    else:
        return JSONResponse(content={"success": True, "result": outcome})
|
|
|
|
|
|
@app.post("/automation/jobs/run")
async def automation_run_job(request: AutomationJobRequest) -> JSONResponse:
    """Run a named automation job against one repository.

    Returns:
        ``{"success": True, "result": ...}`` with the manager's result.

    Raises:
        HTTPException: 403 carrying the ``AutomationError`` message when the
            manager rejects the job.
    """
    manager = AutomationManager()
    try:
        outcome = await manager.run_job(
            job_name=request.job_name,
            owner=request.owner,
            repo=request.repo,
            finding_title=request.finding_title,
            finding_body=request.finding_body,
        )
    except AutomationError as exc:
        raise HTTPException(status_code=403, detail=str(exc)) from exc
    else:
        return JSONResponse(content={"success": True, "result": outcome})
|
|
|
|
|
|
@app.get("/mcp/tools")
async def list_tools() -> JSONResponse:
    """Return the catalogue of available MCP tools, serialized by alias."""
    catalogue = MCPListToolsResponse(tools=AVAILABLE_TOOLS).model_dump(by_alias=True)
    return JSONResponse(content=catalogue)
|
|
|
|
|
|
async def _execute_tool_call(
    tool_name: str, arguments: dict[str, Any], correlation_id: str
) -> dict[str, Any]:
    """Resolve, authorize and run one tool, returning its (sanitized) result.

    Args:
        tool_name: Name of the MCP tool to run.
        arguments: Raw tool arguments supplied by the caller.
        correlation_id: Audit correlation ID attached to a policy-denial record.

    Returns:
        The handler's result, passed through ``sanitize_data`` unless
        ``settings.secret_detection_mode`` is ``"off"``.

    Raises:
        HTTPException: 404 for an unknown tool, 500 for a known tool without a
            registered handler, 403 when the policy engine denies the call.

    A tool-call metric is recorded for every call that gets past the policy
    check, labelled ``"success"`` or ``"error"``.
    """
    settings = get_settings()
    audit = get_audit_logger()
    metrics = get_metrics_registry()

    definition = get_tool_by_name(tool_name)
    if not definition:
        raise HTTPException(status_code=404, detail=f"Tool '{tool_name}' not found")

    tool_impl = TOOL_HANDLERS.get(tool_name)
    if not tool_impl:
        raise HTTPException(
            status_code=500, detail=f"Tool '{tool_name}' has no handler implementation"
        )

    repository = extract_repository(arguments)
    target_path = extract_target_path(arguments)
    verdict = get_policy_engine().authorize(
        tool_name=tool_name,
        is_write=definition.write_operation,
        repository=repository,
        target_path=target_path,
    )
    if not verdict.allowed:
        audit.log_access_denied(
            tool_name=tool_name,
            repository=repository,
            reason=verdict.reason,
            correlation_id=correlation_id,
        )
        raise HTTPException(status_code=403, detail=f"Policy denied request: {verdict.reason}")

    t_start = monotonic_seconds()
    # Pessimistic default; flipped only after the handler and sanitizer finish.
    outcome = "error"

    try:
        async with GiteaClient() as gitea:
            payload = await tool_impl(gitea, arguments)

        if settings.secret_detection_mode != "off":
            # Security decision: sanitize outbound payloads to prevent accidental secret exfiltration.
            payload = sanitize_data(payload, mode=settings.secret_detection_mode)

        outcome = "success"
        return payload
    finally:
        elapsed = max(monotonic_seconds() - t_start, 0.0)
        metrics.record_tool_call(tool_name, outcome, elapsed)
|
|
|
|
|
|
@app.post("/mcp/tool/call")
async def call_tool(request: MCPToolCallRequest) -> JSONResponse:
    """Run one MCP tool call and wrap the outcome in ``MCPToolCallResponse``.

    Outcomes:
      * success: HTTP 200 with ``success=True`` and the tool result;
      * ``HTTPException`` from execution: audited, then re-raised unchanged;
      * pydantic ``ValidationError``: audited, converted to HTTP 400;
      * any other exception: audited and logged, answered with HTTP 500 and a
        generic error message.
    """
    settings = get_settings()
    audit = get_audit_logger()

    # NOTE(review): when the client supplies its own correlation_id the initial
    # audit record (the only one carrying params) is skipped. Confirm this is intended.
    correlation_id = request.correlation_id
    if not correlation_id:
        correlation_id = audit.log_tool_invocation(
            tool_name=request.tool,
            params=request.arguments,
        )

    try:
        tool_result = await _execute_tool_call(request.tool, request.arguments, correlation_id)
        audit.log_tool_invocation(
            tool_name=request.tool,
            correlation_id=correlation_id,
            result_status="success",
        )
        ok_envelope = MCPToolCallResponse(
            success=True,
            result=tool_result,
            correlation_id=correlation_id,
        )
        return JSONResponse(content=ok_envelope.model_dump())

    except HTTPException as http_exc:
        audit.log_tool_invocation(
            tool_name=request.tool,
            correlation_id=correlation_id,
            result_status="error",
            error=str(http_exc.detail),
        )
        raise

    except ValidationError as validation_exc:
        if settings.expose_error_details:
            error_message = f"Invalid tool arguments: {validation_exc}"
        else:
            error_message = "Invalid tool arguments"

        audit.log_tool_invocation(
            tool_name=request.tool,
            correlation_id=correlation_id,
            result_status="error",
            error="validation_error",
        )
        raise HTTPException(status_code=400, detail=error_message) from validation_exc

    except Exception:
        # Security decision: do not leak stack traces or raw exception messages.
        # Both variants below are deliberately generic.
        if settings.expose_error_details:
            error_message = "Internal server error (details hidden unless explicitly enabled)"
        else:
            error_message = "Internal server error"

        audit.log_tool_invocation(
            tool_name=request.tool,
            correlation_id=correlation_id,
            result_status="error",
            error="internal_error",
        )
        logger.exception("tool_execution_failed")
        failure_envelope = MCPToolCallResponse(
            success=False,
            error=error_message,
            correlation_id=correlation_id,
        )
        return JSONResponse(status_code=500, content=failure_envelope.model_dump())
|
|
|
|
|
|
@app.get("/mcp/sse")
async def sse_endpoint(request: Request) -> StreamingResponse:
    """Open a Server-Sent Events stream for the MCP transport.

    The stream sends one ``connected`` event, then a ``heartbeat`` event every
    30 seconds until the client is seen to have disconnected.
    """

    async def event_stream() -> AsyncGenerator[str, None]:
        # Compact separators keep the greeting on a single short data line.
        greeting = json.dumps(
            {"event": "connected", "server": "AegisGitea MCP", "version": "0.2.0"},
            separators=(",", ":"),
        )
        yield "data: " + greeting + "\n\n"

        try:
            # Disconnection is polled once per heartbeat interval.
            while not await request.is_disconnected():
                yield 'data: {"event":"heartbeat"}\n\n'
                await asyncio.sleep(30)
        except Exception:
            logger.exception("sse_stream_error")

    stream_headers = {
        "Cache-Control": "no-cache",
        "Connection": "keep-alive",
        "X-Accel-Buffering": "no",
    }
    return StreamingResponse(
        event_stream(),
        media_type="text/event-stream",
        headers=stream_headers,
    )
|
|
|
|
|
|
def _jsonrpc_result(message_id: Any, result: Any) -> JSONResponse:
    """Build a JSON-RPC 2.0 success envelope."""
    return JSONResponse(content={"jsonrpc": "2.0", "id": message_id, "result": result})


def _jsonrpc_error(message_id: Any, code: int, message: str) -> JSONResponse:
    """Build a JSON-RPC 2.0 error envelope (sent with HTTP 200, as before)."""
    return JSONResponse(
        content={
            "jsonrpc": "2.0",
            "id": message_id,
            "error": {"code": code, "message": message},
        }
    )


def _jsonrpc_code_for_status(status_code: int) -> int:
    """Map an HTTP status raised during tool execution to a JSON-RPC error code."""
    if status_code in (400, 404, 422):
        return -32602  # Invalid params (covers unknown tool / bad arguments).
    if status_code >= 500:
        return -32603  # Internal error.
    return -32000  # Implementation-defined server error (e.g. policy denial).


@app.post("/mcp/sse")
async def sse_message_handler(request: Request) -> JSONResponse:
    """Handle JSON-RPC messages POSTed for the MCP SSE transport.

    Supported message types (read from ``type`` or, failing that, ``method``):

    * ``initialize``: returns protocol version, capabilities and server info.
    * ``tools/list``: returns the tool catalogue.
    * ``tools/call``: validates ``params``, then runs the tool through
      ``_execute_tool_call`` with audit logging.
    * ``notifications/*``: acknowledged with an empty JSON object.
    * anything else: generic ``{"acknowledged": True}`` result.

    ``tools/call`` error mapping, aligned with the ``/mcp/tool/call`` endpoint:

    * malformed ``params`` / ``name`` / ``arguments``: ``-32602``;
    * ``HTTPException`` from execution (unknown tool, policy denial, missing
      handler): its ``detail`` is returned to the client, as ``/mcp/tool/call``
      does, with a code derived from the HTTP status;
    * pydantic ``ValidationError``: ``-32602`` "Invalid tool arguments";
    * any other exception: ``-32603`` with a generic message unless
      ``settings.expose_error_details`` is set.

    A body that is not valid JSON, or not a JSON object, yields HTTP 400.
    """
    settings = get_settings()
    audit = get_audit_logger()

    try:
        body = await request.json()
        if not isinstance(body, dict):
            # Reject explicitly instead of relying on an AttributeError below.
            raise TypeError("JSON-RPC message must be a JSON object")

        message_type = body.get("type") or body.get("method")
        message_id = body.get("id")

        if message_type == "initialize":
            return _jsonrpc_result(
                message_id,
                {
                    "protocolVersion": "2024-11-05",
                    "capabilities": {"tools": {}},
                    "serverInfo": {"name": "AegisGitea MCP", "version": "0.2.0"},
                },
            )

        if message_type == "tools/list":
            catalogue = MCPListToolsResponse(tools=AVAILABLE_TOOLS)
            return _jsonrpc_result(message_id, catalogue.model_dump(by_alias=True))

        if message_type == "tools/call":
            params = body.get("params")
            if not isinstance(params, dict):
                return _jsonrpc_error(
                    message_id, -32602, "Invalid params: 'params' must be an object"
                )

            tool_name = params.get("name")
            if not isinstance(tool_name, str) or not tool_name:
                # Previously a missing name was stringified to "None" and executed.
                return _jsonrpc_error(
                    message_id, -32602, "Invalid params: 'name' must be a non-empty string"
                )

            tool_args = params.get("arguments")
            if tool_args is None:
                # Arguments are optional; treat absent/null as "no arguments".
                tool_args = {}
            if not isinstance(tool_args, dict):
                return _jsonrpc_error(
                    message_id, -32602, "Invalid params: 'arguments' must be an object"
                )

            correlation_id = audit.log_tool_invocation(tool_name=tool_name, params=tool_args)
            try:
                result = await _execute_tool_call(tool_name, tool_args, correlation_id)
                audit.log_tool_invocation(
                    tool_name=tool_name,
                    correlation_id=correlation_id,
                    result_status="success",
                )
                return _jsonrpc_result(
                    message_id,
                    {"content": [{"type": "text", "text": json.dumps(result)}]},
                )
            except HTTPException as exc:
                # Deliberate client-facing errors (404/403/500 raised by
                # _execute_tool_call); surface the detail like /mcp/tool/call does.
                detail = str(exc.detail)
                audit.log_tool_invocation(
                    tool_name=tool_name,
                    correlation_id=correlation_id,
                    result_status="error",
                    error=detail,
                )
                return _jsonrpc_error(
                    message_id, _jsonrpc_code_for_status(exc.status_code), detail
                )
            except ValidationError as exc:
                message = "Invalid tool arguments"
                if settings.expose_error_details:
                    message = f"{message}: {exc}"
                audit.log_tool_invocation(
                    tool_name=tool_name,
                    correlation_id=correlation_id,
                    result_status="error",
                    error="validation_error",
                )
                return _jsonrpc_error(message_id, -32602, message)
            except Exception as exc:
                audit.log_tool_invocation(
                    tool_name=tool_name,
                    correlation_id=correlation_id,
                    result_status="error",
                    error=str(exc),
                )
                # Keep a server-side traceback; the client only sees details
                # when explicitly enabled.
                logger.exception("tool_execution_failed")
                message = "Internal server error"
                if settings.expose_error_details:
                    message = str(exc)
                return _jsonrpc_error(message_id, -32603, message)

        if isinstance(message_type, str) and message_type.startswith("notifications/"):
            return JSONResponse(content={})

        return _jsonrpc_result(message_id, {"acknowledged": True})

    except Exception:
        logger.exception("sse_message_handler_error")
        message = "Invalid message format"
        if settings.expose_error_details:
            message = "Invalid message format (details hidden unless explicitly enabled)"
        return JSONResponse(status_code=400, content={"error": message})
|
|
|
|
|
|
def main() -> None:
    """Start uvicorn serving this module's ``app`` with the configured host and port."""
    # Imported here so that merely importing this module does not load uvicorn.
    import uvicorn

    settings = get_settings()
    server_options = {
        "host": settings.mcp_host,
        "port": settings.mcp_port,
        "log_level": settings.log_level.lower(),
        "reload": False,
    }
    uvicorn.run("aegis_gitea_mcp.server:app", **server_options)


if __name__ == "__main__":
    main()
|