feat: harden gateway with policy engine, secure tools, and governance docs

This commit is contained in:
2026-02-14 16:05:56 +01:00
parent e17d34e6d7
commit 5969892af3
55 changed files with 4711 additions and 1587 deletions

View File

@@ -1,16 +1,24 @@
"""Main MCP server implementation with FastAPI and SSE support."""
"""Main MCP server implementation with hardened security controls."""
from __future__ import annotations
import asyncio
import json
import logging
from typing import Any, Dict
import uuid
from collections.abc import AsyncGenerator, Awaitable, Callable
from typing import Any
from fastapi import FastAPI, HTTPException, Request
from fastapi.responses import JSONResponse, StreamingResponse
from pydantic import ValidationError
from fastapi import FastAPI, HTTPException, Request, Response
from fastapi.responses import JSONResponse, PlainTextResponse, StreamingResponse
from pydantic import BaseModel, Field, ValidationError
from aegis_gitea_mcp.audit import get_audit_logger
from aegis_gitea_mcp.auth import get_validator
from aegis_gitea_mcp.automation import AutomationError, AutomationManager
from aegis_gitea_mcp.config import get_settings
from aegis_gitea_mcp.gitea_client import GiteaClient
from aegis_gitea_mcp.logging_utils import configure_logging
from aegis_gitea_mcp.mcp_protocol import (
AVAILABLE_TOOLS,
MCPListToolsResponse,
@@ -18,276 +26,443 @@ from aegis_gitea_mcp.mcp_protocol import (
MCPToolCallResponse,
get_tool_by_name,
)
from aegis_gitea_mcp.observability import get_metrics_registry, monotonic_seconds
from aegis_gitea_mcp.policy import PolicyError, get_policy_engine
from aegis_gitea_mcp.rate_limit import get_rate_limiter
from aegis_gitea_mcp.request_context import set_request_id
from aegis_gitea_mcp.security import sanitize_data
from aegis_gitea_mcp.tools.arguments import extract_repository, extract_target_path
from aegis_gitea_mcp.tools.read_tools import (
compare_refs_tool,
get_commit_diff_tool,
get_issue_tool,
get_pull_request_tool,
list_commits_tool,
list_issues_tool,
list_labels_tool,
list_pull_requests_tool,
list_releases_tool,
list_tags_tool,
search_code_tool,
)
from aegis_gitea_mcp.tools.repository import (
get_file_contents_tool,
get_file_tree_tool,
get_repository_info_tool,
list_repositories_tool,
)
# Configure logging
logging.basicConfig(
level=logging.INFO,
format="%(asctime)s - %(name)s - %(levelname)s - %(message)s",
from aegis_gitea_mcp.tools.write_tools import (
add_labels_tool,
assign_issue_tool,
create_issue_comment_tool,
create_issue_tool,
create_pr_comment_tool,
update_issue_tool,
)
logger = logging.getLogger(__name__)
# Initialize FastAPI app
app = FastAPI(
title="AegisGitea MCP Server",
description="Security-first MCP server for controlled AI access to self-hosted Gitea",
version="0.1.0",
version="0.2.0",
)
# Global settings and audit logger
# Note: access settings/audit logger dynamically to support test resets.
class AutomationWebhookRequest(BaseModel):
"""Request body for automation webhook ingestion."""
event_type: str = Field(..., min_length=1, max_length=128)
payload: dict[str, Any] = Field(default_factory=dict)
repository: str | None = Field(default=None)
# Tool dispatcher mapping
TOOL_HANDLERS = {
class AutomationJobRequest(BaseModel):
"""Request body for automation job execution."""
job_name: str = Field(..., min_length=1, max_length=128)
owner: str = Field(..., min_length=1, max_length=100)
repo: str = Field(..., min_length=1, max_length=100)
finding_title: str | None = Field(default=None, max_length=256)
finding_body: str | None = Field(default=None, max_length=10_000)
ToolHandler = Callable[[GiteaClient, dict[str, Any]], Awaitable[dict[str, Any]]]
TOOL_HANDLERS: dict[str, ToolHandler] = {
# Baseline read tools
"list_repositories": list_repositories_tool,
"get_repository_info": get_repository_info_tool,
"get_file_tree": get_file_tree_tool,
"get_file_contents": get_file_contents_tool,
# Expanded read tools
"search_code": search_code_tool,
"list_commits": list_commits_tool,
"get_commit_diff": get_commit_diff_tool,
"compare_refs": compare_refs_tool,
"list_issues": list_issues_tool,
"get_issue": get_issue_tool,
"list_pull_requests": list_pull_requests_tool,
"get_pull_request": get_pull_request_tool,
"list_labels": list_labels_tool,
"list_tags": list_tags_tool,
"list_releases": list_releases_tool,
# Write-mode tools
"create_issue": create_issue_tool,
"update_issue": update_issue_tool,
"create_issue_comment": create_issue_comment_tool,
"create_pr_comment": create_pr_comment_tool,
"add_labels": add_labels_tool,
"assign_issue": assign_issue_tool,
}
# Authentication middleware
@app.middleware("http")
async def authenticate_request(request: Request, call_next):
"""Authenticate all requests except health checks and root.
async def request_context_middleware(
request: Request,
call_next: Callable[[Request], Awaitable[Response]],
) -> Response:
"""Attach request correlation context and collect request metrics."""
request_id = request.headers.get("x-request-id") or str(uuid.uuid4())
set_request_id(request_id)
request.state.request_id = request_id
Supports Mixed authentication mode where:
- /mcp/tools (list tools) is publicly accessible (No Auth)
- /mcp/tool/call (execute tools) requires authentication
- /mcp/sse requires authentication
"""
# Skip authentication for health check and root endpoints
if request.url.path in ["/", "/health"]:
started_at = monotonic_seconds()
status_code = 500
try:
response = await call_next(request)
status_code = response.status_code
response.headers["X-Request-ID"] = request_id
return response
finally:
duration = max(monotonic_seconds() - started_at, 0.0)
logger.debug(
"request_completed",
extra={
"method": request.method,
"path": request.url.path,
"duration_seconds": duration,
"status_code": status_code,
},
)
metrics = get_metrics_registry()
metrics.record_http_request(request.method, request.url.path, status_code)
@app.middleware("http")
async def authenticate_and_rate_limit(
request: Request,
call_next: Callable[[Request], Awaitable[Response]],
) -> Response:
"""Apply rate-limiting and authentication for MCP endpoints."""
settings = get_settings()
if request.url.path in {"/", "/health"}:
return await call_next(request)
# Only authenticate MCP endpoints
if not request.url.path.startswith("/mcp/"):
if request.url.path == "/metrics" and settings.metrics_enabled:
# Metrics endpoint is intentionally left unauthenticated for pull-based scraping.
return await call_next(request)
# Mixed mode: allow /mcp/tools without authentication (for ChatGPT discovery)
if request.url.path == "/mcp/tools":
if not (request.url.path.startswith("/mcp/") or request.url.path.startswith("/automation/")):
return await call_next(request)
# Extract client information
validator = get_validator()
limiter = get_rate_limiter()
client_ip = request.client.host if request.client else "unknown"
user_agent = request.headers.get("user-agent", "unknown")
# Get validator instance (supports test resets)
validator = get_validator()
# Extract Authorization header
auth_header = request.headers.get("authorization")
api_key = validator.extract_bearer_token(auth_header)
# Fallback: allow API key via query parameter only for MCP endpoints
if not api_key and request.url.path in {"/mcp/tool/call", "/mcp/sse"}:
api_key = request.query_params.get("api_key")
# Validate API key
is_valid, error_message = validator.validate_api_key(api_key, client_ip, user_agent)
rate_limit = limiter.check(client_ip=client_ip, token=api_key)
if not rate_limit.allowed:
return JSONResponse(
status_code=429,
content={
"error": "Rate limit exceeded",
"message": rate_limit.reason,
"request_id": getattr(request.state, "request_id", "-"),
},
)
# Mixed mode: tool discovery remains public to preserve MCP client compatibility.
if request.url.path == "/mcp/tools":
return await call_next(request)
is_valid, error_message = validator.validate_api_key(api_key, client_ip, user_agent)
if not is_valid:
return JSONResponse(
status_code=401,
content={
"error": "Authentication failed",
"message": error_message,
"detail": (
"Provide a valid API key via Authorization header (Bearer <api-key>) "
"or ?api_key=<api-key> query parameter"
),
"detail": "Provide Authorization: Bearer <api-key> or ?api_key=<api-key>",
"request_id": getattr(request.state, "request_id", "-"),
},
)
# Authentication successful - continue to endpoint
response = await call_next(request)
return response
return await call_next(request)
@app.on_event("startup")
async def startup_event() -> None:
"""Initialize server on startup."""
"""Initialize server state on startup."""
settings = get_settings()
logger.info(f"Starting AegisGitea MCP Server on {settings.mcp_host}:{settings.mcp_port}")
logger.info(f"Connected to Gitea instance: {settings.gitea_base_url}")
logger.info(f"Audit logging enabled: {settings.audit_log_path}")
configure_logging(settings.log_level)
# Log authentication status
if settings.auth_enabled:
key_count = len(settings.mcp_api_keys)
logger.info(f"API key authentication ENABLED ({key_count} key(s) configured)")
else:
logger.warning("API key authentication DISABLED - server is open to all requests!")
logger.info("server_starting")
logger.info(
"server_configuration",
extra={
"host": settings.mcp_host,
"port": settings.mcp_port,
"gitea_url": settings.gitea_base_url,
"auth_enabled": settings.auth_enabled,
"write_mode": settings.write_mode,
"metrics_enabled": settings.metrics_enabled,
},
)
# Test Gitea connection
# Fail-fast policy parse errors at startup.
try:
async with GiteaClient() as gitea:
user = await gitea.get_current_user()
logger.info(f"Authenticated as bot user: {user.get('login', 'unknown')}")
except Exception as e:
logger.error(f"Failed to connect to Gitea: {e}")
_ = get_policy_engine()
except PolicyError:
logger.error("policy_load_failed")
raise
if settings.startup_validate_gitea and settings.environment != "test":
try:
async with GiteaClient() as gitea:
user = await gitea.get_current_user()
logger.info("gitea_connected", extra={"bot_user": user.get("login", "unknown")})
except Exception:
logger.error("gitea_connection_failed")
raise
@app.on_event("shutdown")
async def shutdown_event() -> None:
"""Cleanup on server shutdown."""
logger.info("Shutting down AegisGitea MCP Server")
"""Log server shutdown event."""
logger.info("server_stopping")
@app.get("/")
async def root() -> Dict[str, Any]:
"""Root endpoint with server information."""
async def root() -> dict[str, Any]:
"""Root endpoint with server metadata."""
return {
"name": "AegisGitea MCP Server",
"version": "0.1.0",
"version": "0.2.0",
"status": "running",
"mcp_version": "1.0",
}
@app.get("/health")
async def health() -> Dict[str, str]:
async def health() -> dict[str, str]:
"""Health check endpoint."""
return {"status": "healthy"}
@app.get("/metrics")
async def metrics() -> PlainTextResponse:
"""Prometheus-compatible metrics endpoint."""
settings = get_settings()
if not settings.metrics_enabled:
raise HTTPException(status_code=404, detail="Metrics endpoint disabled")
data = get_metrics_registry().render_prometheus()
return PlainTextResponse(content=data, media_type="text/plain; version=0.0.4")
@app.post("/automation/webhook")
async def automation_webhook(request: AutomationWebhookRequest) -> JSONResponse:
"""Ingest policy-controlled automation webhooks."""
manager = AutomationManager()
try:
result = await manager.handle_webhook(
event_type=request.event_type,
payload=request.payload,
repository=request.repository,
)
return JSONResponse(content={"success": True, "result": result})
except AutomationError as exc:
raise HTTPException(status_code=403, detail=str(exc)) from exc
@app.post("/automation/jobs/run")
async def automation_run_job(request: AutomationJobRequest) -> JSONResponse:
"""Execute a policy-controlled automation job for a repository."""
manager = AutomationManager()
try:
result = await manager.run_job(
job_name=request.job_name,
owner=request.owner,
repo=request.repo,
finding_title=request.finding_title,
finding_body=request.finding_body,
)
return JSONResponse(content={"success": True, "result": result})
except AutomationError as exc:
raise HTTPException(status_code=403, detail=str(exc)) from exc
@app.get("/mcp/tools")
async def list_tools() -> JSONResponse:
"""List all available MCP tools.
Returns:
JSON response with list of tool definitions
"""
"""List all available MCP tools."""
response = MCPListToolsResponse(tools=AVAILABLE_TOOLS)
return JSONResponse(content=response.model_dump(by_alias=True))
return JSONResponse(content=response.model_dump())
async def _execute_tool_call(
tool_name: str, arguments: dict[str, Any], correlation_id: str
) -> dict[str, Any]:
"""Execute tool call with policy checks and standardized response sanitization."""
settings = get_settings()
audit = get_audit_logger()
metrics = get_metrics_registry()
tool_def = get_tool_by_name(tool_name)
if not tool_def:
raise HTTPException(status_code=404, detail=f"Tool '{tool_name}' not found")
handler = TOOL_HANDLERS.get(tool_name)
if not handler:
raise HTTPException(
status_code=500, detail=f"Tool '{tool_name}' has no handler implementation"
)
repository = extract_repository(arguments)
target_path = extract_target_path(arguments)
decision = get_policy_engine().authorize(
tool_name=tool_name,
is_write=tool_def.write_operation,
repository=repository,
target_path=target_path,
)
if not decision.allowed:
audit.log_access_denied(
tool_name=tool_name,
repository=repository,
reason=decision.reason,
correlation_id=correlation_id,
)
raise HTTPException(status_code=403, detail=f"Policy denied request: {decision.reason}")
started_at = monotonic_seconds()
status = "error"
try:
async with GiteaClient() as gitea:
result = await handler(gitea, arguments)
if settings.secret_detection_mode != "off":
# Security decision: sanitize outbound payloads to prevent accidental secret exfiltration.
result = sanitize_data(result, mode=settings.secret_detection_mode)
status = "success"
return result
finally:
duration = max(monotonic_seconds() - started_at, 0.0)
metrics.record_tool_call(tool_name, status, duration)
@app.post("/mcp/tool/call")
async def call_tool(request: MCPToolCallRequest) -> JSONResponse:
"""Execute an MCP tool call.
Args:
request: Tool call request with tool name and arguments
Returns:
JSON response with tool execution result
"""
"""Execute an MCP tool call."""
settings = get_settings()
audit = get_audit_logger()
correlation_id = request.correlation_id or audit.log_tool_invocation(
tool_name=request.tool,
params=request.arguments,
)
try:
# Validate tool exists
tool_def = get_tool_by_name(request.tool)
if not tool_def:
error_msg = f"Tool '{request.tool}' not found"
audit.log_tool_invocation(
tool_name=request.tool,
correlation_id=correlation_id,
result_status="error",
error=error_msg,
)
raise HTTPException(status_code=404, detail=error_msg)
# Get tool handler
handler = TOOL_HANDLERS.get(request.tool)
if not handler:
error_msg = f"Tool '{request.tool}' has no handler implementation"
audit.log_tool_invocation(
tool_name=request.tool,
correlation_id=correlation_id,
result_status="error",
error=error_msg,
)
raise HTTPException(status_code=500, detail=error_msg)
# Execute tool with Gitea client
async with GiteaClient() as gitea:
result = await handler(gitea, request.arguments)
result = await _execute_tool_call(request.tool, request.arguments, correlation_id)
audit.log_tool_invocation(
tool_name=request.tool,
correlation_id=correlation_id,
result_status="success",
)
response = MCPToolCallResponse(
success=True,
result=result,
correlation_id=correlation_id,
return JSONResponse(
content=MCPToolCallResponse(
success=True,
result=result,
correlation_id=correlation_id,
).model_dump()
)
return JSONResponse(content=response.model_dump())
except HTTPException:
# Re-raise HTTP exceptions (like 404) without catching them
except HTTPException as exc:
audit.log_tool_invocation(
tool_name=request.tool,
correlation_id=correlation_id,
result_status="error",
error=str(exc.detail),
)
raise
except ValidationError as e:
error_msg = f"Invalid arguments: {str(e)}"
audit.log_tool_invocation(
tool_name=request.tool,
correlation_id=correlation_id,
result_status="error",
error=error_msg,
)
raise HTTPException(status_code=400, detail=error_msg)
except ValidationError as exc:
error_message = "Invalid tool arguments"
if settings.expose_error_details:
error_message = f"{error_message}: {exc}"
except Exception as e:
error_msg = str(e)
audit.log_tool_invocation(
tool_name=request.tool,
correlation_id=correlation_id,
result_status="error",
error=error_msg,
error="validation_error",
)
response = MCPToolCallResponse(
success=False,
error=error_msg,
raise HTTPException(status_code=400, detail=error_message) from exc
except Exception:
# Security decision: do not leak stack traces or raw exception messages.
error_message = "Internal server error"
if settings.expose_error_details:
error_message = "Internal server error (details hidden unless explicitly enabled)"
audit.log_tool_invocation(
tool_name=request.tool,
correlation_id=correlation_id,
result_status="error",
error="internal_error",
)
logger.exception("tool_execution_failed")
return JSONResponse(
status_code=500,
content=MCPToolCallResponse(
success=False,
error=error_message,
correlation_id=correlation_id,
).model_dump(),
)
return JSONResponse(content=response.model_dump(), status_code=500)
@app.get("/mcp/sse")
async def sse_endpoint(request: Request) -> StreamingResponse:
"""Server-Sent Events endpoint for MCP protocol.
"""Server-Sent Events endpoint for MCP transport."""
This enables real-time communication with ChatGPT using SSE.
async def event_stream() -> AsyncGenerator[str, None]:
yield (
"data: "
+ json.dumps(
{"event": "connected", "server": "AegisGitea MCP", "version": "0.2.0"},
separators=(",", ":"),
)
+ "\n\n"
)
Returns:
Streaming SSE response
"""
async def event_stream():
"""Generate SSE events."""
# Send initial connection event
yield f"data: {{'event': 'connected', 'server': 'AegisGitea MCP', 'version': '0.1.0'}}\n\n"
# Keep connection alive
try:
while True:
if await request.is_disconnected():
break
# Heartbeat every 30 seconds
yield f"data: {{'event': 'heartbeat'}}\n\n"
# Wait for next heartbeat (in production, this would handle actual events)
import asyncio
yield 'data: {"event":"heartbeat"}\n\n'
await asyncio.sleep(30)
except Exception as e:
logger.error(f"SSE stream error: {e}")
except Exception:
logger.exception("sse_stream_error")
return StreamingResponse(
event_stream(),
@@ -302,21 +477,12 @@ async def sse_endpoint(request: Request) -> StreamingResponse:
@app.post("/mcp/sse")
async def sse_message_handler(request: Request) -> JSONResponse:
"""Handle POST messages from ChatGPT MCP client to SSE endpoint.
"""Handle POST messages for MCP SSE transport."""
settings = get_settings()
audit = get_audit_logger()
The MCP SSE transport uses:
- GET /mcp/sse for server-to-client streaming
- POST /mcp/sse for client-to-server messages
Returns:
JSON response acknowledging the message
"""
try:
audit = get_audit_logger()
body = await request.json()
logger.info(f"Received MCP message via SSE POST: {body}")
# Handle different message types
message_type = body.get("type") or body.get("method")
message_id = body.get("id")
@@ -328,87 +494,71 @@ async def sse_message_handler(request: Request) -> JSONResponse:
"result": {
"protocolVersion": "2024-11-05",
"capabilities": {"tools": {}},
"serverInfo": {"name": "AegisGitea MCP", "version": "0.1.0"},
"serverInfo": {"name": "AegisGitea MCP", "version": "0.2.0"},
},
}
)
elif message_type == "tools/list":
# Return the list of available tools
if message_type == "tools/list":
response = MCPListToolsResponse(tools=AVAILABLE_TOOLS)
return JSONResponse(
content={
"jsonrpc": "2.0",
"id": message_id,
"result": response.model_dump(by_alias=True),
"result": response.model_dump(),
}
)
elif message_type == "tools/call":
# Handle tool execution
if message_type == "tools/call":
tool_name = body.get("params", {}).get("name")
tool_args = body.get("params", {}).get("arguments", {})
correlation_id = audit.log_tool_invocation(
tool_name=tool_name,
params=tool_args,
)
correlation_id = audit.log_tool_invocation(tool_name=tool_name, params=tool_args)
try:
# Get tool handler
handler = TOOL_HANDLERS.get(tool_name)
if not handler:
raise HTTPException(status_code=404, detail=f"Tool '{tool_name}' not found")
# Execute tool with Gitea client
async with GiteaClient() as gitea:
result = await handler(gitea, tool_args)
result = await _execute_tool_call(str(tool_name), tool_args, correlation_id)
audit.log_tool_invocation(
tool_name=tool_name,
tool_name=str(tool_name),
correlation_id=correlation_id,
result_status="success",
)
return JSONResponse(
content={
"jsonrpc": "2.0",
"id": message_id,
"result": {"content": [{"type": "text", "text": str(result)}]},
"result": {"content": [{"type": "text", "text": json.dumps(result)}]},
}
)
except Exception as e:
error_msg = str(e)
except Exception as exc:
audit.log_tool_invocation(
tool_name=tool_name,
tool_name=str(tool_name),
correlation_id=correlation_id,
result_status="error",
error=error_msg,
error=str(exc),
)
message = "Internal server error"
if settings.expose_error_details:
message = str(exc)
return JSONResponse(
content={
"jsonrpc": "2.0",
"id": message_id,
"error": {"code": -32603, "message": error_msg},
"error": {"code": -32603, "message": message},
}
)
# Handle notifications (no response needed)
elif message_type and message_type.startswith("notifications/"):
logger.info(f"Received notification: {message_type}")
if isinstance(message_type, str) and message_type.startswith("notifications/"):
return JSONResponse(content={})
# Acknowledge other message types
return JSONResponse(
content={"jsonrpc": "2.0", "id": message_id, "result": {"acknowledged": True}}
)
except Exception as e:
logger.error(f"Error handling SSE POST message: {e}")
return JSONResponse(
status_code=400, content={"error": "Invalid message format", "detail": str(e)}
)
except Exception:
logger.exception("sse_message_handler_error")
message = "Invalid message format"
if settings.expose_error_details:
message = "Invalid message format (details hidden unless explicitly enabled)"
return JSONResponse(status_code=400, content={"error": message})
def main() -> None: