feat: harden gateway with policy engine, secure tools, and governance docs
This commit is contained in:
+159
-102
@@ -1,50 +1,110 @@
|
||||
"""Audit logging system for MCP tool invocations."""
|
||||
"""Tamper-evident audit logging for MCP tool invocations and security events."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import hashlib
|
||||
import json
|
||||
import threading
|
||||
import uuid
|
||||
from datetime import datetime, timezone
|
||||
from pathlib import Path
|
||||
from typing import Any, Dict, Optional
|
||||
|
||||
import structlog
|
||||
from typing import Any
|
||||
|
||||
from aegis_gitea_mcp.config import get_settings
|
||||
from aegis_gitea_mcp.request_context import get_request_id
|
||||
from aegis_gitea_mcp.security import sanitize_data
|
||||
|
||||
# Sentinel "previous hash" used when a chain has no usable prior entry
# (empty or missing log, or an unreadable trailing line).
_GENESIS_HASH = "GENESIS"
|
||||
|
||||
|
||||
class AuditLogger:
|
||||
"""Audit logger for tracking all MCP tool invocations."""
|
||||
"""Append-only tamper-evident audit logger.
|
||||
|
||||
def __init__(self, log_path: Optional[Path] = None) -> None:
|
||||
Every line in the audit file is hash-chained to the previous line. This makes
|
||||
post-hoc modifications detectable by integrity validation.
|
||||
"""
|
||||
|
||||
def __init__(self, log_path: Path | None = None) -> None:
    """Initialize audit logger.

    Opens the audit file in append mode and loads the hash of its last
    entry so that new entries continue the existing hash chain.

    Args:
        log_path: Path to audit log file (defaults to config value).
    """
    self.settings = get_settings()
    self.log_path = log_path or self.settings.audit_log_path

    # Ensure log directory exists
    self.log_path.parent.mkdir(parents=True, exist_ok=True)

    # Guards the write path so appended lines are not interleaved.
    self._lock = threading.Lock()

    # Open exactly one handle; a second open here would leak a descriptor.
    self._log_file = open(self.log_path, "a+", encoding="utf-8")
    try:
        self._last_hash = self._read_last_hash()
    except Exception:
        # Do not leave the handle dangling if the existing log is unreadable.
        self._log_file.close()
        raise
|
||||
def _read_last_hash(self) -> str:
|
||||
"""Read the previous hash from the last log entry."""
|
||||
try:
|
||||
entries = self.log_path.read_text(encoding="utf-8").splitlines()
|
||||
except FileNotFoundError:
|
||||
return _GENESIS_HASH
|
||||
|
||||
if not entries:
|
||||
return _GENESIS_HASH
|
||||
|
||||
last_line = entries[-1]
|
||||
try:
|
||||
payload = json.loads(last_line)
|
||||
entry_hash = payload.get("entry_hash")
|
||||
if isinstance(entry_hash, str) and entry_hash:
|
||||
return entry_hash
|
||||
except json.JSONDecodeError:
|
||||
pass
|
||||
|
||||
# Corrupt trailing line forces a new chain segment.
|
||||
return _GENESIS_HASH
|
||||
|
||||
@staticmethod
|
||||
def _compute_entry_hash(prev_hash: str, payload: dict[str, Any]) -> str:
|
||||
"""Compute deterministic hash for an audit entry payload."""
|
||||
canonical = json.dumps(payload, sort_keys=True, separators=(",", ":"), ensure_ascii=True)
|
||||
digest = hashlib.sha256(f"{prev_hash}:{canonical}".encode()).hexdigest()
|
||||
return digest
|
||||
|
||||
def _append_entry(self, event_type: str, payload: dict[str, Any]) -> str:
    """Append a hash-chained entry to audit log.

    Args:
        event_type: Event category.
        payload: Event payload data. It is not modified; a missing or
            non-string ``correlation_id`` is filled in on a copy.

    Returns:
        Correlation ID for the appended entry.
    """
    correlation_id = payload.get("correlation_id")
    if not isinstance(correlation_id, str) or not correlation_id:
        correlation_id = str(uuid.uuid4())
    # Work on a shallow copy so the caller's dict is left untouched.
    payload = {**payload, "correlation_id": correlation_id}

    # Security decision: sanitize all audit payloads before persistence.
    mode = "mask" if self.settings.secret_detection_mode != "off" else "off"
    safe_payload = payload if mode == "off" else sanitize_data(payload, mode=mode)
    request_id = get_request_id()

    # Reading the previous hash, hashing, writing and advancing the chain
    # must be one critical section; otherwise two threads can both chain
    # onto the same prev_hash and break the log's integrity.
    with self._lock:
        base_entry: dict[str, Any] = {
            "timestamp": datetime.now(timezone.utc).isoformat(),
            "event_type": event_type,
            "request_id": request_id,
            "payload": safe_payload,
            "prev_hash": self._last_hash,
        }
        entry_hash = self._compute_entry_hash(self._last_hash, base_entry)
        base_entry["entry_hash"] = entry_hash

        serialized = json.dumps(
            base_entry, sort_keys=True, separators=(",", ":"), ensure_ascii=True
        )
        self._log_file.write(serialized + "\n")
        self._log_file.flush()
        self._last_hash = entry_hash

    return correlation_id
|
||||
|
||||
def close(self) -> None:
|
||||
"""Close open audit log resources."""
|
||||
@@ -56,111 +116,108 @@ class AuditLogger:
|
||||
def log_tool_invocation(
|
||||
self,
|
||||
tool_name: str,
|
||||
repository: Optional[str] = None,
|
||||
target: Optional[str] = None,
|
||||
params: Optional[Dict[str, Any]] = None,
|
||||
correlation_id: Optional[str] = None,
|
||||
repository: str | None = None,
|
||||
target: str | None = None,
|
||||
params: dict[str, Any] | None = None,
|
||||
correlation_id: str | None = None,
|
||||
result_status: str = "pending",
|
||||
error: Optional[str] = None,
|
||||
error: str | None = None,
|
||||
) -> str:
|
||||
"""Log an MCP tool invocation.
|
||||
|
||||
Args:
|
||||
tool_name: Name of the MCP tool being invoked
|
||||
repository: Repository identifier (owner/repo)
|
||||
target: Target path, commit hash, issue number, etc.
|
||||
params: Additional parameters passed to the tool
|
||||
correlation_id: Request correlation ID (auto-generated if not provided)
|
||||
result_status: Status of the invocation (pending, success, error)
|
||||
error: Error message if invocation failed
|
||||
|
||||
Returns:
|
||||
Correlation ID for this invocation
|
||||
"""
|
||||
if correlation_id is None:
|
||||
correlation_id = str(uuid.uuid4())
|
||||
|
||||
audit_entry = {
|
||||
"timestamp": datetime.now(timezone.utc).isoformat(),
|
||||
"correlation_id": correlation_id,
|
||||
"""Log an MCP tool invocation."""
|
||||
payload: dict[str, Any] = {
|
||||
"correlation_id": correlation_id or str(uuid.uuid4()),
|
||||
"tool_name": tool_name,
|
||||
"repository": repository,
|
||||
"target": target,
|
||||
"params": params or {},
|
||||
"result_status": result_status,
|
||||
}
|
||||
|
||||
if error:
|
||||
audit_entry["error"] = error
|
||||
|
||||
self.logger.info("tool_invocation", **audit_entry)
|
||||
return correlation_id
|
||||
payload["error"] = error
|
||||
return self._append_entry("tool_invocation", payload)
|
||||
|
||||
def log_access_denied(
|
||||
self,
|
||||
tool_name: str,
|
||||
repository: Optional[str] = None,
|
||||
repository: str | None = None,
|
||||
reason: str = "unauthorized",
|
||||
correlation_id: Optional[str] = None,
|
||||
correlation_id: str | None = None,
|
||||
) -> str:
|
||||
"""Log an access denial event.
|
||||
|
||||
Args:
|
||||
tool_name: Name of the tool that was denied access
|
||||
repository: Repository identifier that access was denied to
|
||||
reason: Reason for denial
|
||||
correlation_id: Request correlation ID
|
||||
|
||||
Returns:
|
||||
Correlation ID for this event
|
||||
"""
|
||||
if correlation_id is None:
|
||||
correlation_id = str(uuid.uuid4())
|
||||
|
||||
self.logger.warning(
|
||||
"""Log an access denial event."""
|
||||
return self._append_entry(
|
||||
"access_denied",
|
||||
timestamp=datetime.now(timezone.utc).isoformat(),
|
||||
correlation_id=correlation_id,
|
||||
tool_name=tool_name,
|
||||
repository=repository,
|
||||
reason=reason,
|
||||
{
|
||||
"correlation_id": correlation_id or str(uuid.uuid4()),
|
||||
"tool_name": tool_name,
|
||||
"repository": repository,
|
||||
"reason": reason,
|
||||
},
|
||||
)
|
||||
return correlation_id
|
||||
|
||||
def log_security_event(
|
||||
self,
|
||||
event_type: str,
|
||||
description: str,
|
||||
severity: str = "medium",
|
||||
metadata: Optional[Dict[str, Any]] = None,
|
||||
metadata: dict[str, Any] | None = None,
|
||||
) -> str:
|
||||
"""Log a security-related event.
|
||||
|
||||
Args:
|
||||
event_type: Type of security event (e.g., rate_limit, invalid_input)
|
||||
description: Human-readable description of the event
|
||||
severity: Severity level (low, medium, high, critical)
|
||||
metadata: Additional metadata about the event
|
||||
|
||||
Returns:
|
||||
Correlation ID for this event
|
||||
"""
|
||||
correlation_id = str(uuid.uuid4())
|
||||
|
||||
self.logger.warning(
|
||||
"""Log a security event."""
|
||||
return self._append_entry(
|
||||
"security_event",
|
||||
timestamp=datetime.now(timezone.utc).isoformat(),
|
||||
correlation_id=correlation_id,
|
||||
event_type=event_type,
|
||||
description=description,
|
||||
severity=severity,
|
||||
metadata=metadata or {},
|
||||
{
|
||||
"event_type": event_type,
|
||||
"description": description,
|
||||
"severity": severity,
|
||||
"metadata": metadata or {},
|
||||
},
|
||||
)
|
||||
return correlation_id
|
||||
|
||||
|
||||
# Global audit logger instance
|
||||
_audit_logger: Optional[AuditLogger] = None
|
||||
def validate_audit_log_integrity(log_path: Path) -> tuple[bool, list[str]]:
    """Validate audit log hash chain integrity.

    Walks the log line by line, checking that each entry's ``prev_hash``
    matches the running chain hash and that its ``entry_hash`` matches the
    hash recomputed from the entry's remaining fields. Blank lines are
    skipped; unparsable or non-object lines are reported and do not advance
    the chain.

    Args:
        log_path: Path to an audit log file.

    Returns:
        Tuple of (is_valid, errors). A missing file is considered valid.
    """
    try:
        raw_lines = log_path.read_text(encoding="utf-8").splitlines()
    except FileNotFoundError:
        # A log that was never written has nothing to violate.
        return True, []

    errors: list[str] = []
    prev_hash = _GENESIS_HASH

    for line_number, raw_line in enumerate(raw_lines, start=1):
        if not raw_line.strip():
            continue
        try:
            entry = json.loads(raw_line)
        except json.JSONDecodeError:
            errors.append(f"line {line_number}: invalid JSON")
            continue
        if not isinstance(entry, dict):
            # Valid JSON but not an object (e.g. a list or number): report it
            # instead of crashing on the field lookups below.
            errors.append(f"line {line_number}: entry is not a JSON object")
            continue

        if entry.get("prev_hash") != prev_hash:
            errors.append(f"line {line_number}: prev_hash mismatch")

        # Recompute hash after removing the stored entry hash.
        unsigned = {key: value for key, value in entry.items() if key != "entry_hash"}
        expected_hash = AuditLogger._compute_entry_hash(prev_hash, unsigned)
        if entry.get("entry_hash") != expected_hash:
            errors.append(f"line {line_number}: entry_hash mismatch")

        # Continue from the recomputed hash; it equals the stored hash
        # whenever the entry is intact.
        prev_hash = expected_hash

    return not errors, errors
|
||||
|
||||
|
||||
# Module-level AuditLogger slot; unset (None) at import time.
# NOTE(review): presumably populated lazily by get_audit_logger() — confirm.
_audit_logger: AuditLogger | None = None
|
||||
|
||||
|
||||
def get_audit_logger() -> AuditLogger:
|
||||
|
||||
Reference in New Issue
Block a user