update
This commit is contained in:
@@ -39,10 +39,10 @@ app = FastAPI(
|
||||
version="0.1.0",
|
||||
)
|
||||
|
||||
# Global settings, audit logger, and auth validator
|
||||
# Global settings and audit logger
|
||||
# Note: auth_validator is fetched dynamically in middleware to support test resets
|
||||
settings = get_settings()
|
||||
audit = get_audit_logger()
|
||||
auth_validator = get_validator()
|
||||
|
||||
|
||||
# Tool dispatcher mapping
|
||||
@@ -80,16 +80,19 @@ async def authenticate_request(request: Request, call_next):
|
||||
client_ip = request.client.host if request.client else "unknown"
|
||||
user_agent = request.headers.get("user-agent", "unknown")
|
||||
|
||||
# Get validator instance (supports test resets)
|
||||
validator = get_validator()
|
||||
|
||||
# Extract Authorization header
|
||||
auth_header = request.headers.get("authorization")
|
||||
api_key = auth_validator.extract_bearer_token(auth_header)
|
||||
api_key = validator.extract_bearer_token(auth_header)
|
||||
|
||||
# Fallback: allow API key via query parameter only for MCP endpoints
|
||||
if not api_key and request.url.path in {"/mcp/tool/call", "/mcp/sse"}:
|
||||
api_key = request.query_params.get("api_key")
|
||||
|
||||
# Validate API key
|
||||
is_valid, error_message = auth_validator.validate_api_key(api_key, client_ip, user_agent)
|
||||
is_valid, error_message = validator.validate_api_key(api_key, client_ip, user_agent)
|
||||
|
||||
if not is_valid:
|
||||
return JSONResponse(
|
||||
@@ -224,6 +227,10 @@ async def call_tool(request: MCPToolCallRequest) -> JSONResponse:
|
||||
)
|
||||
return JSONResponse(content=response.model_dump())
|
||||
|
||||
except HTTPException:
|
||||
# Re-raise HTTP exceptions (like 404) without catching them
|
||||
raise
|
||||
|
||||
except ValidationError as e:
|
||||
error_msg = f"Invalid arguments: {str(e)}"
|
||||
audit.log_tool_invocation(
|
||||
@@ -293,6 +300,112 @@ async def sse_endpoint(request: Request) -> StreamingResponse:
|
||||
)
|
||||
|
||||
|
||||
@app.post("/mcp/sse")
|
||||
async def sse_message_handler(request: Request) -> JSONResponse:
|
||||
"""Handle POST messages from ChatGPT MCP client to SSE endpoint.
|
||||
|
||||
The MCP SSE transport uses:
|
||||
- GET /mcp/sse for server-to-client streaming
|
||||
- POST /mcp/sse for client-to-server messages
|
||||
|
||||
Returns:
|
||||
JSON response acknowledging the message
|
||||
"""
|
||||
try:
|
||||
body = await request.json()
|
||||
logger.info(f"Received MCP message via SSE POST: {body}")
|
||||
|
||||
# Handle different message types
|
||||
message_type = body.get("type") or body.get("method")
|
||||
message_id = body.get("id")
|
||||
|
||||
if message_type == "initialize":
|
||||
return JSONResponse(
|
||||
content={
|
||||
"jsonrpc": "2.0",
|
||||
"id": message_id,
|
||||
"result": {
|
||||
"protocolVersion": "2024-11-05",
|
||||
"capabilities": {"tools": {}},
|
||||
"serverInfo": {"name": "AegisGitea MCP", "version": "0.1.0"},
|
||||
},
|
||||
}
|
||||
)
|
||||
|
||||
elif message_type == "tools/list":
|
||||
# Return the list of available tools
|
||||
response = MCPListToolsResponse(tools=AVAILABLE_TOOLS)
|
||||
return JSONResponse(
|
||||
content={"jsonrpc": "2.0", "id": message_id, "result": response.model_dump()}
|
||||
)
|
||||
|
||||
elif message_type == "tools/call":
|
||||
# Handle tool execution
|
||||
tool_name = body.get("params", {}).get("name")
|
||||
tool_args = body.get("params", {}).get("arguments", {})
|
||||
|
||||
correlation_id = audit.log_tool_invocation(
|
||||
tool_name=tool_name,
|
||||
params=tool_args,
|
||||
)
|
||||
|
||||
try:
|
||||
# Get tool handler
|
||||
handler = TOOL_HANDLERS.get(tool_name)
|
||||
if not handler:
|
||||
raise HTTPException(status_code=404, detail=f"Tool '{tool_name}' not found")
|
||||
|
||||
# Execute tool with Gitea client
|
||||
async with GiteaClient() as gitea:
|
||||
result = await handler(gitea, tool_args)
|
||||
|
||||
audit.log_tool_invocation(
|
||||
tool_name=tool_name,
|
||||
correlation_id=correlation_id,
|
||||
result_status="success",
|
||||
)
|
||||
|
||||
return JSONResponse(
|
||||
content={
|
||||
"jsonrpc": "2.0",
|
||||
"id": message_id,
|
||||
"result": {"content": [{"type": "text", "text": str(result)}]},
|
||||
}
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
error_msg = str(e)
|
||||
audit.log_tool_invocation(
|
||||
tool_name=tool_name,
|
||||
correlation_id=correlation_id,
|
||||
result_status="error",
|
||||
error=error_msg,
|
||||
)
|
||||
return JSONResponse(
|
||||
content={
|
||||
"jsonrpc": "2.0",
|
||||
"id": message_id,
|
||||
"error": {"code": -32603, "message": error_msg},
|
||||
}
|
||||
)
|
||||
|
||||
# Handle notifications (no response needed)
|
||||
elif message_type and message_type.startswith("notifications/"):
|
||||
logger.info(f"Received notification: {message_type}")
|
||||
return JSONResponse(content={})
|
||||
|
||||
# Acknowledge other message types
|
||||
return JSONResponse(
|
||||
content={"jsonrpc": "2.0", "id": message_id, "result": {"acknowledged": True}}
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error handling SSE POST message: {e}")
|
||||
return JSONResponse(
|
||||
status_code=400, content={"error": "Invalid message format", "detail": str(e)}
|
||||
)
|
||||
|
||||
|
||||
def main() -> None:
|
||||
"""Run the MCP server."""
|
||||
import uvicorn
|
||||
|
||||
Reference in New Issue
Block a user