feat: harden gateway with policy engine, secure tools, and governance docs
This commit is contained in:
@@ -1,16 +1,24 @@
|
||||
"""Main MCP server implementation with FastAPI and SSE support."""
|
||||
"""Main MCP server implementation with hardened security controls."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import json
|
||||
import logging
|
||||
from typing import Any, Dict
|
||||
import uuid
|
||||
from collections.abc import AsyncGenerator, Awaitable, Callable
|
||||
from typing import Any
|
||||
|
||||
from fastapi import FastAPI, HTTPException, Request
|
||||
from fastapi.responses import JSONResponse, StreamingResponse
|
||||
from pydantic import ValidationError
|
||||
from fastapi import FastAPI, HTTPException, Request, Response
|
||||
from fastapi.responses import JSONResponse, PlainTextResponse, StreamingResponse
|
||||
from pydantic import BaseModel, Field, ValidationError
|
||||
|
||||
from aegis_gitea_mcp.audit import get_audit_logger
|
||||
from aegis_gitea_mcp.auth import get_validator
|
||||
from aegis_gitea_mcp.automation import AutomationError, AutomationManager
|
||||
from aegis_gitea_mcp.config import get_settings
|
||||
from aegis_gitea_mcp.gitea_client import GiteaClient
|
||||
from aegis_gitea_mcp.logging_utils import configure_logging
|
||||
from aegis_gitea_mcp.mcp_protocol import (
|
||||
AVAILABLE_TOOLS,
|
||||
MCPListToolsResponse,
|
||||
@@ -18,276 +26,443 @@ from aegis_gitea_mcp.mcp_protocol import (
|
||||
MCPToolCallResponse,
|
||||
get_tool_by_name,
|
||||
)
|
||||
from aegis_gitea_mcp.observability import get_metrics_registry, monotonic_seconds
|
||||
from aegis_gitea_mcp.policy import PolicyError, get_policy_engine
|
||||
from aegis_gitea_mcp.rate_limit import get_rate_limiter
|
||||
from aegis_gitea_mcp.request_context import set_request_id
|
||||
from aegis_gitea_mcp.security import sanitize_data
|
||||
from aegis_gitea_mcp.tools.arguments import extract_repository, extract_target_path
|
||||
from aegis_gitea_mcp.tools.read_tools import (
|
||||
compare_refs_tool,
|
||||
get_commit_diff_tool,
|
||||
get_issue_tool,
|
||||
get_pull_request_tool,
|
||||
list_commits_tool,
|
||||
list_issues_tool,
|
||||
list_labels_tool,
|
||||
list_pull_requests_tool,
|
||||
list_releases_tool,
|
||||
list_tags_tool,
|
||||
search_code_tool,
|
||||
)
|
||||
from aegis_gitea_mcp.tools.repository import (
|
||||
get_file_contents_tool,
|
||||
get_file_tree_tool,
|
||||
get_repository_info_tool,
|
||||
list_repositories_tool,
|
||||
)
|
||||
|
||||
# Configure logging
|
||||
logging.basicConfig(
|
||||
level=logging.INFO,
|
||||
format="%(asctime)s - %(name)s - %(levelname)s - %(message)s",
|
||||
from aegis_gitea_mcp.tools.write_tools import (
|
||||
add_labels_tool,
|
||||
assign_issue_tool,
|
||||
create_issue_comment_tool,
|
||||
create_issue_tool,
|
||||
create_pr_comment_tool,
|
||||
update_issue_tool,
|
||||
)
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# Initialize FastAPI app
|
||||
app = FastAPI(
|
||||
title="AegisGitea MCP Server",
|
||||
description="Security-first MCP server for controlled AI access to self-hosted Gitea",
|
||||
version="0.1.0",
|
||||
version="0.2.0",
|
||||
)
|
||||
|
||||
# Global settings and audit logger
|
||||
# Note: access settings/audit logger dynamically to support test resets.
|
||||
|
||||
class AutomationWebhookRequest(BaseModel):
|
||||
"""Request body for automation webhook ingestion."""
|
||||
|
||||
event_type: str = Field(..., min_length=1, max_length=128)
|
||||
payload: dict[str, Any] = Field(default_factory=dict)
|
||||
repository: str | None = Field(default=None)
|
||||
|
||||
|
||||
# Tool dispatcher mapping
|
||||
TOOL_HANDLERS = {
|
||||
class AutomationJobRequest(BaseModel):
|
||||
"""Request body for automation job execution."""
|
||||
|
||||
job_name: str = Field(..., min_length=1, max_length=128)
|
||||
owner: str = Field(..., min_length=1, max_length=100)
|
||||
repo: str = Field(..., min_length=1, max_length=100)
|
||||
finding_title: str | None = Field(default=None, max_length=256)
|
||||
finding_body: str | None = Field(default=None, max_length=10_000)
|
||||
|
||||
|
||||
ToolHandler = Callable[[GiteaClient, dict[str, Any]], Awaitable[dict[str, Any]]]
|
||||
|
||||
TOOL_HANDLERS: dict[str, ToolHandler] = {
|
||||
# Baseline read tools
|
||||
"list_repositories": list_repositories_tool,
|
||||
"get_repository_info": get_repository_info_tool,
|
||||
"get_file_tree": get_file_tree_tool,
|
||||
"get_file_contents": get_file_contents_tool,
|
||||
# Expanded read tools
|
||||
"search_code": search_code_tool,
|
||||
"list_commits": list_commits_tool,
|
||||
"get_commit_diff": get_commit_diff_tool,
|
||||
"compare_refs": compare_refs_tool,
|
||||
"list_issues": list_issues_tool,
|
||||
"get_issue": get_issue_tool,
|
||||
"list_pull_requests": list_pull_requests_tool,
|
||||
"get_pull_request": get_pull_request_tool,
|
||||
"list_labels": list_labels_tool,
|
||||
"list_tags": list_tags_tool,
|
||||
"list_releases": list_releases_tool,
|
||||
# Write-mode tools
|
||||
"create_issue": create_issue_tool,
|
||||
"update_issue": update_issue_tool,
|
||||
"create_issue_comment": create_issue_comment_tool,
|
||||
"create_pr_comment": create_pr_comment_tool,
|
||||
"add_labels": add_labels_tool,
|
||||
"assign_issue": assign_issue_tool,
|
||||
}
|
||||
|
||||
|
||||
# Authentication middleware
|
||||
@app.middleware("http")
|
||||
async def authenticate_request(request: Request, call_next):
|
||||
"""Authenticate all requests except health checks and root.
|
||||
async def request_context_middleware(
|
||||
request: Request,
|
||||
call_next: Callable[[Request], Awaitable[Response]],
|
||||
) -> Response:
|
||||
"""Attach request correlation context and collect request metrics."""
|
||||
request_id = request.headers.get("x-request-id") or str(uuid.uuid4())
|
||||
set_request_id(request_id)
|
||||
request.state.request_id = request_id
|
||||
|
||||
Supports Mixed authentication mode where:
|
||||
- /mcp/tools (list tools) is publicly accessible (No Auth)
|
||||
- /mcp/tool/call (execute tools) requires authentication
|
||||
- /mcp/sse requires authentication
|
||||
"""
|
||||
# Skip authentication for health check and root endpoints
|
||||
if request.url.path in ["/", "/health"]:
|
||||
started_at = monotonic_seconds()
|
||||
status_code = 500
|
||||
|
||||
try:
|
||||
response = await call_next(request)
|
||||
status_code = response.status_code
|
||||
response.headers["X-Request-ID"] = request_id
|
||||
return response
|
||||
finally:
|
||||
duration = max(monotonic_seconds() - started_at, 0.0)
|
||||
logger.debug(
|
||||
"request_completed",
|
||||
extra={
|
||||
"method": request.method,
|
||||
"path": request.url.path,
|
||||
"duration_seconds": duration,
|
||||
"status_code": status_code,
|
||||
},
|
||||
)
|
||||
metrics = get_metrics_registry()
|
||||
metrics.record_http_request(request.method, request.url.path, status_code)
|
||||
|
||||
|
||||
@app.middleware("http")
|
||||
async def authenticate_and_rate_limit(
|
||||
request: Request,
|
||||
call_next: Callable[[Request], Awaitable[Response]],
|
||||
) -> Response:
|
||||
"""Apply rate-limiting and authentication for MCP endpoints."""
|
||||
settings = get_settings()
|
||||
|
||||
if request.url.path in {"/", "/health"}:
|
||||
return await call_next(request)
|
||||
|
||||
# Only authenticate MCP endpoints
|
||||
if not request.url.path.startswith("/mcp/"):
|
||||
if request.url.path == "/metrics" and settings.metrics_enabled:
|
||||
# Metrics endpoint is intentionally left unauthenticated for pull-based scraping.
|
||||
return await call_next(request)
|
||||
|
||||
# Mixed mode: allow /mcp/tools without authentication (for ChatGPT discovery)
|
||||
if request.url.path == "/mcp/tools":
|
||||
if not (request.url.path.startswith("/mcp/") or request.url.path.startswith("/automation/")):
|
||||
return await call_next(request)
|
||||
|
||||
# Extract client information
|
||||
validator = get_validator()
|
||||
limiter = get_rate_limiter()
|
||||
|
||||
client_ip = request.client.host if request.client else "unknown"
|
||||
user_agent = request.headers.get("user-agent", "unknown")
|
||||
|
||||
# Get validator instance (supports test resets)
|
||||
validator = get_validator()
|
||||
|
||||
# Extract Authorization header
|
||||
auth_header = request.headers.get("authorization")
|
||||
api_key = validator.extract_bearer_token(auth_header)
|
||||
|
||||
# Fallback: allow API key via query parameter only for MCP endpoints
|
||||
if not api_key and request.url.path in {"/mcp/tool/call", "/mcp/sse"}:
|
||||
api_key = request.query_params.get("api_key")
|
||||
|
||||
# Validate API key
|
||||
is_valid, error_message = validator.validate_api_key(api_key, client_ip, user_agent)
|
||||
rate_limit = limiter.check(client_ip=client_ip, token=api_key)
|
||||
if not rate_limit.allowed:
|
||||
return JSONResponse(
|
||||
status_code=429,
|
||||
content={
|
||||
"error": "Rate limit exceeded",
|
||||
"message": rate_limit.reason,
|
||||
"request_id": getattr(request.state, "request_id", "-"),
|
||||
},
|
||||
)
|
||||
|
||||
# Mixed mode: tool discovery remains public to preserve MCP client compatibility.
|
||||
if request.url.path == "/mcp/tools":
|
||||
return await call_next(request)
|
||||
|
||||
is_valid, error_message = validator.validate_api_key(api_key, client_ip, user_agent)
|
||||
if not is_valid:
|
||||
return JSONResponse(
|
||||
status_code=401,
|
||||
content={
|
||||
"error": "Authentication failed",
|
||||
"message": error_message,
|
||||
"detail": (
|
||||
"Provide a valid API key via Authorization header (Bearer <api-key>) "
|
||||
"or ?api_key=<api-key> query parameter"
|
||||
),
|
||||
"detail": "Provide Authorization: Bearer <api-key> or ?api_key=<api-key>",
|
||||
"request_id": getattr(request.state, "request_id", "-"),
|
||||
},
|
||||
)
|
||||
|
||||
# Authentication successful - continue to endpoint
|
||||
response = await call_next(request)
|
||||
return response
|
||||
return await call_next(request)
|
||||
|
||||
|
||||
@app.on_event("startup")
|
||||
async def startup_event() -> None:
|
||||
"""Initialize server on startup."""
|
||||
"""Initialize server state on startup."""
|
||||
settings = get_settings()
|
||||
logger.info(f"Starting AegisGitea MCP Server on {settings.mcp_host}:{settings.mcp_port}")
|
||||
logger.info(f"Connected to Gitea instance: {settings.gitea_base_url}")
|
||||
logger.info(f"Audit logging enabled: {settings.audit_log_path}")
|
||||
configure_logging(settings.log_level)
|
||||
|
||||
# Log authentication status
|
||||
if settings.auth_enabled:
|
||||
key_count = len(settings.mcp_api_keys)
|
||||
logger.info(f"API key authentication ENABLED ({key_count} key(s) configured)")
|
||||
else:
|
||||
logger.warning("API key authentication DISABLED - server is open to all requests!")
|
||||
logger.info("server_starting")
|
||||
logger.info(
|
||||
"server_configuration",
|
||||
extra={
|
||||
"host": settings.mcp_host,
|
||||
"port": settings.mcp_port,
|
||||
"gitea_url": settings.gitea_base_url,
|
||||
"auth_enabled": settings.auth_enabled,
|
||||
"write_mode": settings.write_mode,
|
||||
"metrics_enabled": settings.metrics_enabled,
|
||||
},
|
||||
)
|
||||
|
||||
# Test Gitea connection
|
||||
# Fail-fast policy parse errors at startup.
|
||||
try:
|
||||
async with GiteaClient() as gitea:
|
||||
user = await gitea.get_current_user()
|
||||
logger.info(f"Authenticated as bot user: {user.get('login', 'unknown')}")
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to connect to Gitea: {e}")
|
||||
_ = get_policy_engine()
|
||||
except PolicyError:
|
||||
logger.error("policy_load_failed")
|
||||
raise
|
||||
|
||||
if settings.startup_validate_gitea and settings.environment != "test":
|
||||
try:
|
||||
async with GiteaClient() as gitea:
|
||||
user = await gitea.get_current_user()
|
||||
logger.info("gitea_connected", extra={"bot_user": user.get("login", "unknown")})
|
||||
except Exception:
|
||||
logger.error("gitea_connection_failed")
|
||||
raise
|
||||
|
||||
|
||||
@app.on_event("shutdown")
|
||||
async def shutdown_event() -> None:
|
||||
"""Cleanup on server shutdown."""
|
||||
logger.info("Shutting down AegisGitea MCP Server")
|
||||
"""Log server shutdown event."""
|
||||
logger.info("server_stopping")
|
||||
|
||||
|
||||
@app.get("/")
|
||||
async def root() -> Dict[str, Any]:
|
||||
"""Root endpoint with server information."""
|
||||
async def root() -> dict[str, Any]:
|
||||
"""Root endpoint with server metadata."""
|
||||
return {
|
||||
"name": "AegisGitea MCP Server",
|
||||
"version": "0.1.0",
|
||||
"version": "0.2.0",
|
||||
"status": "running",
|
||||
"mcp_version": "1.0",
|
||||
}
|
||||
|
||||
|
||||
@app.get("/health")
|
||||
async def health() -> Dict[str, str]:
|
||||
async def health() -> dict[str, str]:
|
||||
"""Health check endpoint."""
|
||||
return {"status": "healthy"}
|
||||
|
||||
|
||||
@app.get("/metrics")
|
||||
async def metrics() -> PlainTextResponse:
|
||||
"""Prometheus-compatible metrics endpoint."""
|
||||
settings = get_settings()
|
||||
if not settings.metrics_enabled:
|
||||
raise HTTPException(status_code=404, detail="Metrics endpoint disabled")
|
||||
data = get_metrics_registry().render_prometheus()
|
||||
return PlainTextResponse(content=data, media_type="text/plain; version=0.0.4")
|
||||
|
||||
|
||||
@app.post("/automation/webhook")
|
||||
async def automation_webhook(request: AutomationWebhookRequest) -> JSONResponse:
|
||||
"""Ingest policy-controlled automation webhooks."""
|
||||
manager = AutomationManager()
|
||||
try:
|
||||
result = await manager.handle_webhook(
|
||||
event_type=request.event_type,
|
||||
payload=request.payload,
|
||||
repository=request.repository,
|
||||
)
|
||||
return JSONResponse(content={"success": True, "result": result})
|
||||
except AutomationError as exc:
|
||||
raise HTTPException(status_code=403, detail=str(exc)) from exc
|
||||
|
||||
|
||||
@app.post("/automation/jobs/run")
|
||||
async def automation_run_job(request: AutomationJobRequest) -> JSONResponse:
|
||||
"""Execute a policy-controlled automation job for a repository."""
|
||||
manager = AutomationManager()
|
||||
try:
|
||||
result = await manager.run_job(
|
||||
job_name=request.job_name,
|
||||
owner=request.owner,
|
||||
repo=request.repo,
|
||||
finding_title=request.finding_title,
|
||||
finding_body=request.finding_body,
|
||||
)
|
||||
return JSONResponse(content={"success": True, "result": result})
|
||||
except AutomationError as exc:
|
||||
raise HTTPException(status_code=403, detail=str(exc)) from exc
|
||||
|
||||
|
||||
@app.get("/mcp/tools")
|
||||
async def list_tools() -> JSONResponse:
|
||||
"""List all available MCP tools.
|
||||
|
||||
Returns:
|
||||
JSON response with list of tool definitions
|
||||
"""
|
||||
"""List all available MCP tools."""
|
||||
response = MCPListToolsResponse(tools=AVAILABLE_TOOLS)
|
||||
return JSONResponse(content=response.model_dump(by_alias=True))
|
||||
return JSONResponse(content=response.model_dump())
|
||||
|
||||
|
||||
async def _execute_tool_call(
|
||||
tool_name: str, arguments: dict[str, Any], correlation_id: str
|
||||
) -> dict[str, Any]:
|
||||
"""Execute tool call with policy checks and standardized response sanitization."""
|
||||
settings = get_settings()
|
||||
audit = get_audit_logger()
|
||||
metrics = get_metrics_registry()
|
||||
|
||||
tool_def = get_tool_by_name(tool_name)
|
||||
if not tool_def:
|
||||
raise HTTPException(status_code=404, detail=f"Tool '{tool_name}' not found")
|
||||
|
||||
handler = TOOL_HANDLERS.get(tool_name)
|
||||
if not handler:
|
||||
raise HTTPException(
|
||||
status_code=500, detail=f"Tool '{tool_name}' has no handler implementation"
|
||||
)
|
||||
|
||||
repository = extract_repository(arguments)
|
||||
target_path = extract_target_path(arguments)
|
||||
decision = get_policy_engine().authorize(
|
||||
tool_name=tool_name,
|
||||
is_write=tool_def.write_operation,
|
||||
repository=repository,
|
||||
target_path=target_path,
|
||||
)
|
||||
if not decision.allowed:
|
||||
audit.log_access_denied(
|
||||
tool_name=tool_name,
|
||||
repository=repository,
|
||||
reason=decision.reason,
|
||||
correlation_id=correlation_id,
|
||||
)
|
||||
raise HTTPException(status_code=403, detail=f"Policy denied request: {decision.reason}")
|
||||
|
||||
started_at = monotonic_seconds()
|
||||
status = "error"
|
||||
|
||||
try:
|
||||
async with GiteaClient() as gitea:
|
||||
result = await handler(gitea, arguments)
|
||||
|
||||
if settings.secret_detection_mode != "off":
|
||||
# Security decision: sanitize outbound payloads to prevent accidental secret exfiltration.
|
||||
result = sanitize_data(result, mode=settings.secret_detection_mode)
|
||||
|
||||
status = "success"
|
||||
return result
|
||||
finally:
|
||||
duration = max(monotonic_seconds() - started_at, 0.0)
|
||||
metrics.record_tool_call(tool_name, status, duration)
|
||||
|
||||
|
||||
@app.post("/mcp/tool/call")
|
||||
async def call_tool(request: MCPToolCallRequest) -> JSONResponse:
|
||||
"""Execute an MCP tool call.
|
||||
|
||||
Args:
|
||||
request: Tool call request with tool name and arguments
|
||||
|
||||
Returns:
|
||||
JSON response with tool execution result
|
||||
"""
|
||||
"""Execute an MCP tool call."""
|
||||
settings = get_settings()
|
||||
audit = get_audit_logger()
|
||||
|
||||
correlation_id = request.correlation_id or audit.log_tool_invocation(
|
||||
tool_name=request.tool,
|
||||
params=request.arguments,
|
||||
)
|
||||
|
||||
try:
|
||||
# Validate tool exists
|
||||
tool_def = get_tool_by_name(request.tool)
|
||||
if not tool_def:
|
||||
error_msg = f"Tool '{request.tool}' not found"
|
||||
audit.log_tool_invocation(
|
||||
tool_name=request.tool,
|
||||
correlation_id=correlation_id,
|
||||
result_status="error",
|
||||
error=error_msg,
|
||||
)
|
||||
raise HTTPException(status_code=404, detail=error_msg)
|
||||
|
||||
# Get tool handler
|
||||
handler = TOOL_HANDLERS.get(request.tool)
|
||||
if not handler:
|
||||
error_msg = f"Tool '{request.tool}' has no handler implementation"
|
||||
audit.log_tool_invocation(
|
||||
tool_name=request.tool,
|
||||
correlation_id=correlation_id,
|
||||
result_status="error",
|
||||
error=error_msg,
|
||||
)
|
||||
raise HTTPException(status_code=500, detail=error_msg)
|
||||
|
||||
# Execute tool with Gitea client
|
||||
async with GiteaClient() as gitea:
|
||||
result = await handler(gitea, request.arguments)
|
||||
|
||||
result = await _execute_tool_call(request.tool, request.arguments, correlation_id)
|
||||
audit.log_tool_invocation(
|
||||
tool_name=request.tool,
|
||||
correlation_id=correlation_id,
|
||||
result_status="success",
|
||||
)
|
||||
|
||||
response = MCPToolCallResponse(
|
||||
success=True,
|
||||
result=result,
|
||||
correlation_id=correlation_id,
|
||||
return JSONResponse(
|
||||
content=MCPToolCallResponse(
|
||||
success=True,
|
||||
result=result,
|
||||
correlation_id=correlation_id,
|
||||
).model_dump()
|
||||
)
|
||||
return JSONResponse(content=response.model_dump())
|
||||
|
||||
except HTTPException:
|
||||
# Re-raise HTTP exceptions (like 404) without catching them
|
||||
except HTTPException as exc:
|
||||
audit.log_tool_invocation(
|
||||
tool_name=request.tool,
|
||||
correlation_id=correlation_id,
|
||||
result_status="error",
|
||||
error=str(exc.detail),
|
||||
)
|
||||
raise
|
||||
|
||||
except ValidationError as e:
|
||||
error_msg = f"Invalid arguments: {str(e)}"
|
||||
audit.log_tool_invocation(
|
||||
tool_name=request.tool,
|
||||
correlation_id=correlation_id,
|
||||
result_status="error",
|
||||
error=error_msg,
|
||||
)
|
||||
raise HTTPException(status_code=400, detail=error_msg)
|
||||
except ValidationError as exc:
|
||||
error_message = "Invalid tool arguments"
|
||||
if settings.expose_error_details:
|
||||
error_message = f"{error_message}: {exc}"
|
||||
|
||||
except Exception as e:
|
||||
error_msg = str(e)
|
||||
audit.log_tool_invocation(
|
||||
tool_name=request.tool,
|
||||
correlation_id=correlation_id,
|
||||
result_status="error",
|
||||
error=error_msg,
|
||||
error="validation_error",
|
||||
)
|
||||
response = MCPToolCallResponse(
|
||||
success=False,
|
||||
error=error_msg,
|
||||
raise HTTPException(status_code=400, detail=error_message) from exc
|
||||
|
||||
except Exception:
|
||||
# Security decision: do not leak stack traces or raw exception messages.
|
||||
error_message = "Internal server error"
|
||||
if settings.expose_error_details:
|
||||
error_message = "Internal server error (details hidden unless explicitly enabled)"
|
||||
|
||||
audit.log_tool_invocation(
|
||||
tool_name=request.tool,
|
||||
correlation_id=correlation_id,
|
||||
result_status="error",
|
||||
error="internal_error",
|
||||
)
|
||||
logger.exception("tool_execution_failed")
|
||||
return JSONResponse(
|
||||
status_code=500,
|
||||
content=MCPToolCallResponse(
|
||||
success=False,
|
||||
error=error_message,
|
||||
correlation_id=correlation_id,
|
||||
).model_dump(),
|
||||
)
|
||||
return JSONResponse(content=response.model_dump(), status_code=500)
|
||||
|
||||
|
||||
@app.get("/mcp/sse")
|
||||
async def sse_endpoint(request: Request) -> StreamingResponse:
|
||||
"""Server-Sent Events endpoint for MCP protocol.
|
||||
"""Server-Sent Events endpoint for MCP transport."""
|
||||
|
||||
This enables real-time communication with ChatGPT using SSE.
|
||||
async def event_stream() -> AsyncGenerator[str, None]:
|
||||
yield (
|
||||
"data: "
|
||||
+ json.dumps(
|
||||
{"event": "connected", "server": "AegisGitea MCP", "version": "0.2.0"},
|
||||
separators=(",", ":"),
|
||||
)
|
||||
+ "\n\n"
|
||||
)
|
||||
|
||||
Returns:
|
||||
Streaming SSE response
|
||||
"""
|
||||
|
||||
async def event_stream():
|
||||
"""Generate SSE events."""
|
||||
# Send initial connection event
|
||||
yield f"data: {{'event': 'connected', 'server': 'AegisGitea MCP', 'version': '0.1.0'}}\n\n"
|
||||
|
||||
# Keep connection alive
|
||||
try:
|
||||
while True:
|
||||
if await request.is_disconnected():
|
||||
break
|
||||
|
||||
# Heartbeat every 30 seconds
|
||||
yield f"data: {{'event': 'heartbeat'}}\n\n"
|
||||
|
||||
# Wait for next heartbeat (in production, this would handle actual events)
|
||||
import asyncio
|
||||
|
||||
yield 'data: {"event":"heartbeat"}\n\n'
|
||||
await asyncio.sleep(30)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"SSE stream error: {e}")
|
||||
except Exception:
|
||||
logger.exception("sse_stream_error")
|
||||
|
||||
return StreamingResponse(
|
||||
event_stream(),
|
||||
@@ -302,21 +477,12 @@ async def sse_endpoint(request: Request) -> StreamingResponse:
|
||||
|
||||
@app.post("/mcp/sse")
|
||||
async def sse_message_handler(request: Request) -> JSONResponse:
|
||||
"""Handle POST messages from ChatGPT MCP client to SSE endpoint.
|
||||
"""Handle POST messages for MCP SSE transport."""
|
||||
settings = get_settings()
|
||||
audit = get_audit_logger()
|
||||
|
||||
The MCP SSE transport uses:
|
||||
- GET /mcp/sse for server-to-client streaming
|
||||
- POST /mcp/sse for client-to-server messages
|
||||
|
||||
Returns:
|
||||
JSON response acknowledging the message
|
||||
"""
|
||||
try:
|
||||
audit = get_audit_logger()
|
||||
body = await request.json()
|
||||
logger.info(f"Received MCP message via SSE POST: {body}")
|
||||
|
||||
# Handle different message types
|
||||
message_type = body.get("type") or body.get("method")
|
||||
message_id = body.get("id")
|
||||
|
||||
@@ -328,87 +494,71 @@ async def sse_message_handler(request: Request) -> JSONResponse:
|
||||
"result": {
|
||||
"protocolVersion": "2024-11-05",
|
||||
"capabilities": {"tools": {}},
|
||||
"serverInfo": {"name": "AegisGitea MCP", "version": "0.1.0"},
|
||||
"serverInfo": {"name": "AegisGitea MCP", "version": "0.2.0"},
|
||||
},
|
||||
}
|
||||
)
|
||||
|
||||
elif message_type == "tools/list":
|
||||
# Return the list of available tools
|
||||
if message_type == "tools/list":
|
||||
response = MCPListToolsResponse(tools=AVAILABLE_TOOLS)
|
||||
return JSONResponse(
|
||||
content={
|
||||
"jsonrpc": "2.0",
|
||||
"id": message_id,
|
||||
"result": response.model_dump(by_alias=True),
|
||||
"result": response.model_dump(),
|
||||
}
|
||||
)
|
||||
|
||||
elif message_type == "tools/call":
|
||||
# Handle tool execution
|
||||
if message_type == "tools/call":
|
||||
tool_name = body.get("params", {}).get("name")
|
||||
tool_args = body.get("params", {}).get("arguments", {})
|
||||
|
||||
correlation_id = audit.log_tool_invocation(
|
||||
tool_name=tool_name,
|
||||
params=tool_args,
|
||||
)
|
||||
|
||||
correlation_id = audit.log_tool_invocation(tool_name=tool_name, params=tool_args)
|
||||
try:
|
||||
# Get tool handler
|
||||
handler = TOOL_HANDLERS.get(tool_name)
|
||||
if not handler:
|
||||
raise HTTPException(status_code=404, detail=f"Tool '{tool_name}' not found")
|
||||
|
||||
# Execute tool with Gitea client
|
||||
async with GiteaClient() as gitea:
|
||||
result = await handler(gitea, tool_args)
|
||||
|
||||
result = await _execute_tool_call(str(tool_name), tool_args, correlation_id)
|
||||
audit.log_tool_invocation(
|
||||
tool_name=tool_name,
|
||||
tool_name=str(tool_name),
|
||||
correlation_id=correlation_id,
|
||||
result_status="success",
|
||||
)
|
||||
|
||||
return JSONResponse(
|
||||
content={
|
||||
"jsonrpc": "2.0",
|
||||
"id": message_id,
|
||||
"result": {"content": [{"type": "text", "text": str(result)}]},
|
||||
"result": {"content": [{"type": "text", "text": json.dumps(result)}]},
|
||||
}
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
error_msg = str(e)
|
||||
except Exception as exc:
|
||||
audit.log_tool_invocation(
|
||||
tool_name=tool_name,
|
||||
tool_name=str(tool_name),
|
||||
correlation_id=correlation_id,
|
||||
result_status="error",
|
||||
error=error_msg,
|
||||
error=str(exc),
|
||||
)
|
||||
message = "Internal server error"
|
||||
if settings.expose_error_details:
|
||||
message = str(exc)
|
||||
return JSONResponse(
|
||||
content={
|
||||
"jsonrpc": "2.0",
|
||||
"id": message_id,
|
||||
"error": {"code": -32603, "message": error_msg},
|
||||
"error": {"code": -32603, "message": message},
|
||||
}
|
||||
)
|
||||
|
||||
# Handle notifications (no response needed)
|
||||
elif message_type and message_type.startswith("notifications/"):
|
||||
logger.info(f"Received notification: {message_type}")
|
||||
if isinstance(message_type, str) and message_type.startswith("notifications/"):
|
||||
return JSONResponse(content={})
|
||||
|
||||
# Acknowledge other message types
|
||||
return JSONResponse(
|
||||
content={"jsonrpc": "2.0", "id": message_id, "result": {"acknowledged": True}}
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error handling SSE POST message: {e}")
|
||||
return JSONResponse(
|
||||
status_code=400, content={"error": "Invalid message format", "detail": str(e)}
|
||||
)
|
||||
except Exception:
|
||||
logger.exception("sse_message_handler_error")
|
||||
message = "Invalid message format"
|
||||
if settings.expose_error_details:
|
||||
message = "Invalid message format (details hidden unless explicitly enabled)"
|
||||
return JSONResponse(status_code=400, content={"error": message})
|
||||
|
||||
|
||||
def main() -> None:
|
||||
|
||||
Reference in New Issue
Block a user