Files
AegisGitea-MCP/src/aegis_gitea_mcp/audit.py
T

237 lines
7.4 KiB
Python

"""Tamper-evident audit logging for MCP tool invocations and security events."""
from __future__ import annotations
import hashlib
import json
import threading
import uuid
from datetime import datetime, timezone
from pathlib import Path
from typing import Any
from aegis_gitea_mcp.config import get_settings
from aegis_gitea_mcp.request_context import get_request_id
from aegis_gitea_mcp.security import sanitize_data
# Sentinel "previous hash" used to start a hash chain: it is the prev_hash of
# the first entry and the fallback when the last log line cannot be used.
_GENESIS_HASH = "GENESIS"
class AuditLogger:
    """Append-only, tamper-evident audit logger.

    Every line in the audit file is a JSON object that is hash-chained to the
    previous line. This makes post-hoc modifications detectable by integrity
    validation.
    """

    def __init__(self, log_path: Path | None = None) -> None:
        """Initialize the audit logger.

        Args:
            log_path: Path to audit log file (defaults to config value).
        """
        self.settings = get_settings()
        self.log_path = log_path or self.settings.audit_log_path
        self.log_path.parent.mkdir(parents=True, exist_ok=True)
        # Serializes the "read last hash -> write line -> update last hash"
        # sequence so concurrent writers cannot fork the chain.
        self._lock = threading.Lock()
        # Kept open for the lifetime of the logger; released by close().
        self._log_file = open(self.log_path, "a+", encoding="utf-8")
        try:
            self._last_hash = self._read_last_hash()
        except Exception:
            # Do not leak the file handle if the existing log is unreadable.
            self._log_file.close()
            raise

    def _read_last_hash(self) -> str:
        """Return the hash that the next entry must chain to.

        Returns:
            The ``entry_hash`` of the last non-blank line in the log, or the
            genesis sentinel when the file is missing, empty, or its last
            entry is unusable.
        """
        try:
            lines = self.log_path.read_text(encoding="utf-8").splitlines()
        except FileNotFoundError:
            return _GENESIS_HASH

        # Blank lines are ignored by the integrity validator, so they must not
        # be mistaken for a corrupt trailing entry here either.
        last_line = next((line for line in reversed(lines) if line.strip()), None)
        if last_line is None:
            return _GENESIS_HASH

        try:
            entry = json.loads(last_line)
        except json.JSONDecodeError:
            # Corrupt trailing line forces a new chain segment.
            return _GENESIS_HASH

        # Valid JSON that is not an object (e.g. a list) has no entry_hash.
        if isinstance(entry, dict):
            entry_hash = entry.get("entry_hash")
            if isinstance(entry_hash, str) and entry_hash:
                return entry_hash
        return _GENESIS_HASH

    @staticmethod
    def _compute_entry_hash(prev_hash: str, payload: dict[str, Any]) -> str:
        """Compute the deterministic hash for an audit entry.

        Args:
            prev_hash: Hash of the preceding entry (or the genesis sentinel).
            payload: Entry content, without its own ``entry_hash`` field.

        Returns:
            Hex SHA-256 digest of ``"<prev_hash>:<canonical JSON>"``.
        """
        # Sorted keys and fixed separators give a canonical form, so the same
        # content always hashes to the same value.
        canonical = json.dumps(payload, sort_keys=True, separators=(",", ":"), ensure_ascii=True)
        return hashlib.sha256(f"{prev_hash}:{canonical}".encode()).hexdigest()

    def _append_entry(self, event_type: str, payload: dict[str, Any]) -> str:
        """Append a hash-chained entry to the audit log.

        Args:
            event_type: Event category.
            payload: Event payload data. The caller's dict is not modified.

        Returns:
            Correlation ID for the appended entry.
        """
        # Shallow copy: adding the correlation ID must not mutate caller data.
        payload = dict(payload)
        correlation_id = payload.get("correlation_id")
        if not isinstance(correlation_id, str) or not correlation_id:
            correlation_id = str(uuid.uuid4())
            payload["correlation_id"] = correlation_id

        # Security decision: sanitize all audit payloads before persistence.
        mode = "mask" if self.settings.secret_detection_mode != "off" else "off"
        safe_payload = payload if mode == "off" else sanitize_data(payload, mode=mode)
        request_id = get_request_id()

        # The previous hash must be read, used and replaced atomically;
        # otherwise two threads could both chain to the same predecessor.
        with self._lock:
            prev_hash = self._last_hash
            entry: dict[str, Any] = {
                "timestamp": datetime.now(timezone.utc).isoformat(),
                "event_type": event_type,
                "request_id": request_id,
                "payload": safe_payload,
                "prev_hash": prev_hash,
            }
            entry_hash = self._compute_entry_hash(prev_hash, entry)
            entry["entry_hash"] = entry_hash
            serialized = json.dumps(
                entry, sort_keys=True, separators=(",", ":"), ensure_ascii=True
            )
            self._log_file.write(serialized + "\n")
            self._log_file.flush()
            # Only advance the chain once the line has been written.
            self._last_hash = entry_hash
        return correlation_id

    def close(self) -> None:
        """Close open audit log resources (best effort, never raises)."""
        try:
            self._log_file.close()
        except Exception:
            # Deliberately best-effort: a failure while closing must not
            # break shutdown or test teardown.
            pass

    def log_tool_invocation(
        self,
        tool_name: str,
        repository: str | None = None,
        target: str | None = None,
        params: dict[str, Any] | None = None,
        correlation_id: str | None = None,
        result_status: str = "pending",
        error: str | None = None,
    ) -> str:
        """Log an MCP tool invocation.

        Args:
            tool_name: Name of the invoked tool.
            repository: Repository the call refers to, if any.
            target: Target of the call, if any.
            params: Tool parameters (an empty dict is logged when omitted).
            correlation_id: Existing correlation ID; generated when omitted.
            result_status: Outcome status of the invocation.
            error: Error text; only recorded when non-empty.

        Returns:
            Correlation ID of the logged entry.
        """
        payload: dict[str, Any] = {
            "correlation_id": correlation_id or str(uuid.uuid4()),
            "tool_name": tool_name,
            "repository": repository,
            "target": target,
            "params": params or {},
            "result_status": result_status,
        }
        if error:
            payload["error"] = error
        return self._append_entry("tool_invocation", payload)

    def log_access_denied(
        self,
        tool_name: str,
        repository: str | None = None,
        reason: str = "unauthorized",
        correlation_id: str | None = None,
    ) -> str:
        """Log an access denial event.

        Args:
            tool_name: Name of the tool whose use was denied.
            repository: Repository the call refers to, if any.
            reason: Reason for the denial.
            correlation_id: Existing correlation ID; generated when omitted.

        Returns:
            Correlation ID of the logged entry.
        """
        return self._append_entry(
            "access_denied",
            {
                "correlation_id": correlation_id or str(uuid.uuid4()),
                "tool_name": tool_name,
                "repository": repository,
                "reason": reason,
            },
        )

    def log_security_event(
        self,
        event_type: str,
        description: str,
        severity: str = "medium",
        metadata: dict[str, Any] | None = None,
    ) -> str:
        """Log a security event.

        The entry itself is categorized as ``"security_event"``; the given
        ``event_type`` is stored inside the payload.

        Args:
            event_type: Specific kind of security event.
            description: Description of the event.
            severity: Severity label.
            metadata: Extra data (an empty dict is logged when omitted).

        Returns:
            Correlation ID generated for the logged entry.
        """
        return self._append_entry(
            "security_event",
            {
                "event_type": event_type,
                "description": description,
                "severity": severity,
                "metadata": metadata or {},
            },
        )
def validate_audit_log_integrity(log_path: Path) -> tuple[bool, list[str]]:
    """Validate audit log hash chain integrity.

    Each non-blank line must be a JSON object whose ``prev_hash`` equals the
    running chain hash and whose ``entry_hash`` equals the hash recomputed from
    the rest of the entry. The chain always continues from the recomputed hash,
    so entries that follow a mismatching line are reported as well.

    Args:
        log_path: Path to an audit log file.

    Returns:
        Tuple of (is_valid, errors). A missing file is considered valid.
    """
    # EAFP: avoids the race between an exists() check and the read.
    try:
        lines = log_path.read_text(encoding="utf-8").splitlines()
    except FileNotFoundError:
        return True, []

    errors: list[str] = []
    prev_hash = _GENESIS_HASH
    for line_number, raw_line in enumerate(lines, start=1):
        if not raw_line.strip():
            continue

        try:
            entry = json.loads(raw_line)
        except json.JSONDecodeError:
            errors.append(f"line {line_number}: invalid JSON")
            continue

        # Valid JSON that is not an object (list, number, string, ...) cannot
        # be an audit entry; report it instead of failing on .get().
        if not isinstance(entry, dict):
            errors.append(f"line {line_number}: entry is not a JSON object")
            continue

        if entry.get("prev_hash") != prev_hash:
            errors.append(f"line {line_number}: prev_hash mismatch")

        # Recompute the hash over the entry without its stored entry hash.
        unsigned = {key: value for key, value in entry.items() if key != "entry_hash"}
        expected_hash = AuditLogger._compute_entry_hash(prev_hash, unsigned)
        if entry.get("entry_hash") != expected_hash:
            errors.append(f"line {line_number}: entry_hash mismatch")

        # On a match the stored hash equals expected_hash, so the chain always
        # advances with the recomputed value.
        prev_hash = expected_hash

    return not errors, errors
# Module-level AuditLogger singleton; None until get_audit_logger() creates it
# and again after reset_audit_logger() clears it.
_audit_logger: AuditLogger | None = None
def get_audit_logger() -> AuditLogger:
    """Return the shared audit logger, creating it on first use."""
    global _audit_logger
    logger = _audit_logger
    if logger is None:
        # First call: build the logger and remember it for later callers.
        logger = AuditLogger()
        _audit_logger = logger
    return logger
def reset_audit_logger() -> None:
    """Close and discard the shared audit logger (primarily for testing)."""
    global _audit_logger
    logger = _audit_logger
    if logger is None:
        # Nothing to release; the slot is already empty.
        return
    logger.close()
    _audit_logger = None