"""Main MCP server implementation with hardened security controls."""

from __future__ import annotations

import asyncio
import json
import logging
import uuid
from collections.abc import AsyncGenerator, Awaitable, Callable
from typing import Any

from fastapi import FastAPI, HTTPException, Request, Response
from fastapi.responses import JSONResponse, PlainTextResponse, StreamingResponse
from pydantic import BaseModel, Field, ValidationError

from aegis_gitea_mcp.audit import get_audit_logger
from aegis_gitea_mcp.auth import get_validator
from aegis_gitea_mcp.automation import AutomationError, AutomationManager
from aegis_gitea_mcp.config import get_settings
from aegis_gitea_mcp.gitea_client import (
    GiteaAuthenticationError,
    GiteaAuthorizationError,
    GiteaClient,
)
from aegis_gitea_mcp.logging_utils import configure_logging
from aegis_gitea_mcp.mcp_protocol import (
    AVAILABLE_TOOLS,
    MCPListToolsResponse,
    MCPToolCallRequest,
    MCPToolCallResponse,
    get_tool_by_name,
)
from aegis_gitea_mcp.observability import get_metrics_registry, monotonic_seconds
from aegis_gitea_mcp.policy import PolicyError, get_policy_engine
from aegis_gitea_mcp.rate_limit import get_rate_limiter
from aegis_gitea_mcp.request_context import set_request_id
from aegis_gitea_mcp.security import sanitize_data
from aegis_gitea_mcp.tools.arguments import extract_repository, extract_target_path
from aegis_gitea_mcp.tools.read_tools import (
    compare_refs_tool,
    get_commit_diff_tool,
    get_issue_tool,
    get_pull_request_tool,
    list_commits_tool,
    list_issues_tool,
    list_labels_tool,
    list_pull_requests_tool,
    list_releases_tool,
    list_tags_tool,
    search_code_tool,
)
from aegis_gitea_mcp.tools.repository import (
    get_file_contents_tool,
    get_file_tree_tool,
    get_repository_info_tool,
    list_repositories_tool,
)
from aegis_gitea_mcp.tools.write_tools import (
    add_labels_tool,
    assign_issue_tool,
    create_issue_comment_tool,
    create_issue_tool,
    create_pr_comment_tool,
    update_issue_tool,
)

logger = logging.getLogger(__name__)

app = FastAPI(
    title="AegisGitea MCP Server",
    description="Security-first MCP server for controlled AI access to self-hosted Gitea",
    version="0.2.0",
)


class AutomationWebhookRequest(BaseModel):
    """Request body for automation webhook ingestion."""

    # Event identifier; required, 1-128 characters.
    event_type: str = Field(..., min_length=1, max_length=128)
    # Arbitrary event payload; defaults to an empty dict.
    payload: dict[str, Any] = Field(default_factory=dict)
    # Optional repository reference forwarded to the automation manager.
    repository: str | None = Field(default=None)


class AutomationJobRequest(BaseModel):
    """Request body for automation job execution."""

    # Name of the job to run; required, 1-128 characters.
    job_name: str = Field(..., min_length=1, max_length=128)
    # Repository owner and name; each required, 1-100 characters.
    owner: str = Field(..., min_length=1, max_length=100)
    repo: str = Field(..., min_length=1, max_length=100)
    # Optional finding details passed through to the job.
    finding_title: str | None = Field(default=None, max_length=256)
    finding_body: str | None = Field(default=None, max_length=10_000)


# Signature shared by every tool implementation: (client, arguments) -> result dict.
ToolHandler = Callable[[GiteaClient, dict[str, Any]], Awaitable[dict[str, Any]]]

# Dispatch table mapping MCP tool names to their handler coroutines.
TOOL_HANDLERS: dict[str, ToolHandler] = {
    # Baseline read tools
    "list_repositories": list_repositories_tool,
    "get_repository_info": get_repository_info_tool,
    "get_file_tree": get_file_tree_tool,
    "get_file_contents": get_file_contents_tool,
    # Expanded read tools
    "search_code": search_code_tool,
    "list_commits": list_commits_tool,
    "get_commit_diff": get_commit_diff_tool,
    "compare_refs": compare_refs_tool,
    "list_issues": list_issues_tool,
    "get_issue": get_issue_tool,
    "list_pull_requests": list_pull_requests_tool,
    "get_pull_request": get_pull_request_tool,
    "list_labels": list_labels_tool,
    "list_tags": list_tags_tool,
    "list_releases": list_releases_tool,
    # Write-mode tools
    "create_issue": create_issue_tool,
    "update_issue": update_issue_tool,
    "create_issue_comment": create_issue_comment_tool,
    "create_pr_comment": create_pr_comment_tool,
    "add_labels": add_labels_tool,
    "assign_issue": assign_issue_tool,
}


@app.middleware("http")
async def request_context_middleware(
    request: Request,
    call_next: Callable[[Request], Awaitable[Response]],
) -> Response:
    """Attach request correlation context and collect request metrics.

    The request ID is taken from the ``x-request-id`` header when present,
    otherwise a fresh UUID4 is generated. It is stored via ``set_request_id``,
    on ``request.state.request_id``, and echoed back in the ``X-Request-ID``
    response header.

    Args:
        request: Incoming HTTP request.
        call_next: Callable that forwards the request down the middleware chain.

    Returns:
        The downstream response with the ``X-Request-ID`` header added.
    """
    request_id = request.headers.get("x-request-id") or str(uuid.uuid4())
    set_request_id(request_id)
    request.state.request_id = request_id

    started_at = monotonic_seconds()
    # Assume failure until a response is actually produced, so that an
    # exception escaping call_next is recorded as a 500 in the metrics.
    status_code = 500

    try:
        response = await call_next(request)
        status_code = response.status_code
        response.headers["X-Request-ID"] = request_id
        return response
    finally:
        # Clamp to zero so a non-monotonic reading never yields a negative duration.
        duration = max(monotonic_seconds() - started_at, 0.0)
        logger.debug(
            "request_completed",
            extra={
                "method": request.method,
                "path": request.url.path,
                "duration_seconds": duration,
                "status_code": status_code,
            },
        )
        metrics = get_metrics_registry()
        metrics.record_http_request(request.method, request.url.path, status_code)


@app.middleware("http")
async def authenticate_and_rate_limit(
    request: Request,
    call_next: Callable[[Request], Awaitable[Response]],
) -> Response:
    """Apply rate-limiting and authentication for MCP endpoints.

    Only paths under ``/mcp/`` and ``/automation/`` are protected; ``/``,
    ``/health``, ``/metrics`` (when enabled) and any other path pass through.
    Protected requests are rate-limited first; ``/mcp/tools`` is then allowed
    without authentication, and everything else must present a valid API key.

    Args:
        request: Incoming HTTP request.
        call_next: Callable that forwards the request down the middleware chain.

    Returns:
        The downstream response, a 429 JSON response when the rate limit is
        exceeded, or a 401 JSON response when authentication fails.
    """
    # NOTE(review): Starlette appears to run later-registered middleware
    # outermost, so this middleware would execute before
    # request_context_middleware and request.state.request_id would still be
    # unset here (hence the "-" fallback below) - confirm the ordering is intended.
    settings = get_settings()
    path = request.url.path

    if path in {"/", "/health"}:
        return await call_next(request)

    if path == "/metrics" and settings.metrics_enabled:
        # Metrics endpoint is intentionally left unauthenticated for pull-based scraping.
        return await call_next(request)

    if not path.startswith(("/mcp/", "/automation/")):
        return await call_next(request)

    validator = get_validator()
    limiter = get_rate_limiter()

    client_ip = request.client.host if request.client else "unknown"
    user_agent = request.headers.get("user-agent", "unknown")

    auth_header = request.headers.get("authorization")
    api_key = validator.extract_bearer_token(auth_header)
    # Query-string keys are accepted only on these two endpoints.
    if not api_key and path in {"/mcp/tool/call", "/mcp/sse"}:
        api_key = request.query_params.get("api_key")

    rate_limit = limiter.check(client_ip=client_ip, token=api_key)
    if not rate_limit.allowed:
        return JSONResponse(
            status_code=429,
            content={
                "error": "Rate limit exceeded",
                "message": rate_limit.reason,
                "request_id": getattr(request.state, "request_id", "-"),
            },
        )

    # Mixed mode: tool discovery remains public to preserve MCP client compatibility.
    if path == "/mcp/tools":
        return await call_next(request)

    is_valid, error_message = validator.validate_api_key(api_key, client_ip, user_agent)
    if not is_valid:
        return JSONResponse(
            status_code=401,
            content={
                "error": "Authentication failed",
                "message": error_message,
                "detail": "Provide Authorization: Bearer or ?api_key=",
                "request_id": getattr(request.state, "request_id", "-"),
            },
        )

    return await call_next(request)


@app.on_event("startup")
async def startup_event() -> None:
    """Initialize server state on startup.

    Configures logging, logs the effective configuration, loads the policy
    engine, and (outside the ``test`` environment, when enabled) verifies
    connectivity to Gitea.

    Raises:
        PolicyError: If the policy engine cannot be loaded.
        RuntimeError: If Gitea startup validation fails.
    """
    settings = get_settings()
    configure_logging(settings.log_level)

    logger.info("server_starting")
    logger.info(
        "server_configuration",
        extra={
            "host": settings.mcp_host,
            "port": settings.mcp_port,
            "gitea_url": settings.gitea_base_url,
            "auth_enabled": settings.auth_enabled,
            "write_mode": settings.write_mode,
            "metrics_enabled": settings.metrics_enabled,
        },
    )

    # Fail-fast policy parse errors at startup.
    try:
        _ = get_policy_engine()
    except PolicyError:
        logger.error("policy_load_failed")
        raise

    if settings.startup_validate_gitea and settings.environment != "test":
        try:
            async with GiteaClient() as gitea:
                user = await gitea.get_current_user()
                logger.info("gitea_connected", extra={"bot_user": user.get("login", "unknown")})
        except GiteaAuthenticationError as exc:
            logger.error("gitea_connection_failed_authentication")
            raise RuntimeError(
                "Startup validation failed: Gitea authentication was rejected. Check GITEA_TOKEN."
            ) from exc
        except GiteaAuthorizationError as exc:
            logger.error("gitea_connection_failed_authorization")
            raise RuntimeError(
                "Startup validation failed: Gitea token lacks permission for /api/v1/user."
            ) from exc
        except Exception as exc:
            logger.error("gitea_connection_failed")
            raise RuntimeError("Startup validation failed: unable to connect to Gitea.") from exc


@app.on_event("shutdown")
async def shutdown_event() -> None:
    """Log server shutdown event."""
    logger.info("server_stopping")


@app.get("/")
async def root() -> dict[str, Any]:
    """Root endpoint with server metadata."""
    return {
        "name": "AegisGitea MCP Server",
        "version": "0.2.0",
        "status": "running",
        "mcp_version": "1.0",
    }


@app.get("/health")
async def health() -> dict[str, str]:
    """Health check endpoint."""
    return {"status": "healthy"}


@app.get("/metrics")
async def metrics() -> PlainTextResponse:
    """Prometheus-compatible metrics endpoint.

    Raises:
        HTTPException: 404 when metrics are disabled in settings.
    """
    settings = get_settings()
    if not settings.metrics_enabled:
        raise HTTPException(status_code=404, detail="Metrics endpoint disabled")

    data = get_metrics_registry().render_prometheus()
    return PlainTextResponse(content=data, media_type="text/plain; version=0.0.4")


@app.post("/automation/webhook")
async def automation_webhook(request: AutomationWebhookRequest) -> JSONResponse:
    """Ingest policy-controlled automation webhooks.

    Raises:
        HTTPException: 403 when the automation manager raises ``AutomationError``.
    """
    manager = AutomationManager()
    try:
        result = await manager.handle_webhook(
            event_type=request.event_type,
            payload=request.payload,
            repository=request.repository,
        )
        return JSONResponse(content={"success": True, "result": result})
    except AutomationError as exc:
        raise HTTPException(status_code=403, detail=str(exc)) from exc


@app.post("/automation/jobs/run")
async def automation_run_job(request: AutomationJobRequest) -> JSONResponse:
    """Execute a policy-controlled automation job for a repository.

    Raises:
        HTTPException: 403 when the automation manager raises ``AutomationError``.
    """
    manager = AutomationManager()
    try:
        result = await manager.run_job(
            job_name=request.job_name,
            owner=request.owner,
            repo=request.repo,
            finding_title=request.finding_title,
            finding_body=request.finding_body,
        )
        return JSONResponse(content={"success": True, "result": result})
    except AutomationError as exc:
        raise HTTPException(status_code=403, detail=str(exc)) from exc


@app.get("/mcp/tools")
async def list_tools() -> JSONResponse:
    """List all available MCP tools."""
    response = MCPListToolsResponse(tools=AVAILABLE_TOOLS)
    return JSONResponse(content=response.model_dump(by_alias=True))


async def _execute_tool_call(
    tool_name: str, arguments: dict[str, Any], correlation_id: str
) -> dict[str, Any]:
    """Execute tool call with policy checks and standardized response sanitization.

    Args:
        tool_name: Name of the MCP tool to run.
        arguments: Tool arguments forwarded to the handler.
        correlation_id: Audit correlation ID recorded on policy denials.

    Returns:
        The handler result, sanitized unless secret detection is ``"off"``.

    Raises:
        HTTPException: 404 if the tool is unknown, 500 if it has no handler,
            403 if the policy engine denies the request.
    """
    settings = get_settings()
    audit = get_audit_logger()
    metrics = get_metrics_registry()

    tool_def = get_tool_by_name(tool_name)
    if not tool_def:
        raise HTTPException(status_code=404, detail=f"Tool '{tool_name}' not found")

    handler = TOOL_HANDLERS.get(tool_name)
    if not handler:
        raise HTTPException(
            status_code=500, detail=f"Tool '{tool_name}' has no handler implementation"
        )

    repository = extract_repository(arguments)
    target_path = extract_target_path(arguments)
    decision = get_policy_engine().authorize(
        tool_name=tool_name,
        is_write=tool_def.write_operation,
        repository=repository,
        target_path=target_path,
    )
    if not decision.allowed:
        audit.log_access_denied(
            tool_name=tool_name,
            repository=repository,
            reason=decision.reason,
            correlation_id=correlation_id,
        )
        raise HTTPException(status_code=403, detail=f"Policy denied request: {decision.reason}")

    started_at = monotonic_seconds()
    # Stays "error" unless the handler and sanitization both complete.
    status = "error"

    try:
        async with GiteaClient() as gitea:
            result = await handler(gitea, arguments)

        if settings.secret_detection_mode != "off":
            # Security decision: sanitize outbound payloads to prevent accidental secret exfiltration.
            result = sanitize_data(result, mode=settings.secret_detection_mode)

        status = "success"
        return result
    finally:
        duration = max(monotonic_seconds() - started_at, 0.0)
        metrics.record_tool_call(tool_name, status, duration)


@app.post("/mcp/tool/call")
async def call_tool(request: MCPToolCallRequest) -> JSONResponse:
    """Execute an MCP tool call.

    Args:
        request: Parsed tool call request (tool name, arguments, optional
            client-supplied correlation ID).

    Returns:
        A JSON response wrapping ``MCPToolCallResponse``; unexpected failures
        are returned as a 500 response with a generic error message.

    Raises:
        HTTPException: Re-raised from tool execution (unknown tool, policy
            denial, ...) or 400 when tool arguments fail validation.
    """
    settings = get_settings()
    audit = get_audit_logger()

    # Always record the invocation (with its arguments) in the audit log.
    # Previously a client-supplied correlation_id short-circuited this call,
    # so the request parameters were never audited for such requests.
    if request.correlation_id:
        correlation_id = request.correlation_id
        audit.log_tool_invocation(
            tool_name=request.tool,
            params=request.arguments,
            correlation_id=correlation_id,
        )
    else:
        correlation_id = audit.log_tool_invocation(
            tool_name=request.tool,
            params=request.arguments,
        )

    try:
        result = await _execute_tool_call(request.tool, request.arguments, correlation_id)
        audit.log_tool_invocation(
            tool_name=request.tool,
            correlation_id=correlation_id,
            result_status="success",
        )
        return JSONResponse(
            content=MCPToolCallResponse(
                success=True,
                result=result,
                correlation_id=correlation_id,
            ).model_dump()
        )

    except HTTPException as exc:
        audit.log_tool_invocation(
            tool_name=request.tool,
            correlation_id=correlation_id,
            result_status="error",
            error=str(exc.detail),
        )
        raise

    except ValidationError as exc:
        error_message = "Invalid tool arguments"
        if settings.expose_error_details:
            error_message = f"{error_message}: {exc}"

        audit.log_tool_invocation(
            tool_name=request.tool,
            correlation_id=correlation_id,
            result_status="error",
            error="validation_error",
        )
        raise HTTPException(status_code=400, detail=error_message) from exc

    except Exception:
        # Security decision: do not leak stack traces or raw exception messages.
        error_message = "Internal server error"
        if settings.expose_error_details:
            error_message = "Internal server error (details hidden unless explicitly enabled)"

        audit.log_tool_invocation(
            tool_name=request.tool,
            correlation_id=correlation_id,
            result_status="error",
            error="internal_error",
        )
        logger.exception("tool_execution_failed")
        return JSONResponse(
            status_code=500,
            content=MCPToolCallResponse(
                success=False,
                error=error_message,
                correlation_id=correlation_id,
            ).model_dump(),
        )


@app.get("/mcp/sse")
async def sse_endpoint(request: Request) -> StreamingResponse:
    """Server-Sent Events endpoint for MCP transport.

    Emits one ``connected`` event, then a ``heartbeat`` event every 30 seconds
    until the client disconnects.
    """

    async def event_stream() -> AsyncGenerator[str, None]:
        yield (
            "data: "
            + json.dumps(
                {"event": "connected", "server": "AegisGitea MCP", "version": "0.2.0"},
                separators=(",", ":"),
            )
            + "\n\n"
        )

        try:
            while True:
                if await request.is_disconnected():
                    break
                yield 'data: {"event":"heartbeat"}\n\n'
                await asyncio.sleep(30)
        except Exception:
            logger.exception("sse_stream_error")

    return StreamingResponse(
        event_stream(),
        media_type="text/event-stream",
        headers={
            "Cache-Control": "no-cache",
            "Connection": "keep-alive",
            "X-Accel-Buffering": "no",
        },
    )


@app.post("/mcp/sse")
async def sse_message_handler(request: Request) -> JSONResponse:
    """Handle POST messages for MCP SSE transport.

    Dispatches on the message ``type`` (or ``method``): ``initialize``,
    ``tools/list``, ``tools/call``, and ``notifications/*``. Any other message
    is acknowledged. A body that cannot be parsed or navigated yields a 400.

    Args:
        request: Incoming HTTP request carrying a JSON message body.

    Returns:
        A JSON-RPC style response for the message.
    """
    settings = get_settings()
    audit = get_audit_logger()

    try:
        body = await request.json()
        message_type = body.get("type") or body.get("method")
        message_id = body.get("id")

        if message_type == "initialize":
            return JSONResponse(
                content={
                    "jsonrpc": "2.0",
                    "id": message_id,
                    "result": {
                        "protocolVersion": "2024-11-05",
                        "capabilities": {"tools": {}},
                        "serverInfo": {"name": "AegisGitea MCP", "version": "0.2.0"},
                    },
                }
            )

        if message_type == "tools/list":
            response = MCPListToolsResponse(tools=AVAILABLE_TOOLS)
            return JSONResponse(
                content={
                    "jsonrpc": "2.0",
                    "id": message_id,
                    "result": response.model_dump(by_alias=True),
                }
            )

        if message_type == "tools/call":
            params = body.get("params", {})
            tool_name = params.get("name")
            tool_args = params.get("arguments", {})

            correlation_id = audit.log_tool_invocation(tool_name=tool_name, params=tool_args)

            # Reject malformed calls up front. Without this check a missing
            # name was coerced to the literal string "None" and looked up as a
            # tool, and non-object arguments were passed through to handlers.
            if not isinstance(tool_name, str) or not tool_name or not isinstance(tool_args, dict):
                audit.log_tool_invocation(
                    tool_name=str(tool_name),
                    correlation_id=correlation_id,
                    result_status="error",
                    error="invalid_params",
                )
                return JSONResponse(
                    content={
                        "jsonrpc": "2.0",
                        "id": message_id,
                        "error": {
                            # -32602 is the JSON-RPC 2.0 "Invalid params" code.
                            "code": -32602,
                            "message": (
                                "Invalid params: 'name' must be a non-empty string "
                                "and 'arguments' must be an object"
                            ),
                        },
                    }
                )

            try:
                result = await _execute_tool_call(tool_name, tool_args, correlation_id)
                audit.log_tool_invocation(
                    tool_name=tool_name,
                    correlation_id=correlation_id,
                    result_status="success",
                )
                return JSONResponse(
                    content={
                        "jsonrpc": "2.0",
                        "id": message_id,
                        "result": {"content": [{"type": "text", "text": json.dumps(result)}]},
                    }
                )
            except Exception as exc:
                audit.log_tool_invocation(
                    tool_name=tool_name,
                    correlation_id=correlation_id,
                    result_status="error",
                    error=str(exc),
                )
                message = "Internal server error"
                if settings.expose_error_details:
                    message = str(exc)
                return JSONResponse(
                    content={
                        "jsonrpc": "2.0",
                        "id": message_id,
                        "error": {"code": -32603, "message": message},
                    }
                )

        if isinstance(message_type, str) and message_type.startswith("notifications/"):
            return JSONResponse(content={})

        return JSONResponse(
            content={"jsonrpc": "2.0", "id": message_id, "result": {"acknowledged": True}}
        )

    except Exception:
        logger.exception("sse_message_handler_error")
        message = "Invalid message format"
        if settings.expose_error_details:
            message = "Invalid message format (details hidden unless explicitly enabled)"
        return JSONResponse(status_code=400, content={"error": message})


def main() -> None:
    """Run the MCP server."""
    # Imported lazily so importing this module does not require uvicorn.
    import uvicorn

    settings = get_settings()

    uvicorn.run(
        "aegis_gitea_mcp.server:app",
        host=settings.mcp_host,
        port=settings.mcp_port,
        log_level=settings.log_level.lower(),
        reload=False,
    )


if __name__ == "__main__":
    main()