From abef368a68bdb9c7a13c9d744181cca3eb50ecf1 Mon Sep 17 00:00:00 2001 From: latte Date: Sat, 17 Jan 2026 21:57:04 +0100 Subject: [PATCH] update --- {.github => .gitea}/workflows/ci.yml | 71 ---- .gitea/workflows/dependency-updates.yml | 44 ++ .github/workflows/dependency-updates.yml | 75 ---- IMPLEMENTATION_PLAN.md | 400 ------------------ README.md | 14 +- .../20260117_add_banned_word_metadata.py | 37 ++ .../versions/20260117_enable_ai_defaults.py | 41 ++ pyproject.toml | 2 + src/guardden/bot.py | 22 +- src/guardden/cogs/wordlist_sync.py | 38 ++ src/guardden/config.py | 98 ++++- src/guardden/dashboard/__main__.py | 16 + src/guardden/models/guild.py | 21 +- src/guardden/services/automod.py | 128 +++--- src/guardden/services/guild_config.py | 6 + src/guardden/services/wordlist.py | 180 ++++++++ tests/conftest.py | 60 +-- tests/test_config.py | 63 +-- tests/test_database_integration.py | 118 ++---- 19 files changed, 677 insertions(+), 757 deletions(-) rename {.github => .gitea}/workflows/ci.yml (69%) create mode 100644 .gitea/workflows/dependency-updates.yml delete mode 100644 .github/workflows/dependency-updates.yml delete mode 100644 IMPLEMENTATION_PLAN.md create mode 100644 migrations/versions/20260117_add_banned_word_metadata.py create mode 100644 migrations/versions/20260117_enable_ai_defaults.py create mode 100644 src/guardden/cogs/wordlist_sync.py create mode 100644 src/guardden/dashboard/__main__.py create mode 100644 src/guardden/services/wordlist.py diff --git a/.github/workflows/ci.yml b/.gitea/workflows/ci.yml similarity index 69% rename from .github/workflows/ci.yml rename to .gitea/workflows/ci.yml index 87541a7..dcdd62d 100644 --- a/.github/workflows/ci.yml +++ b/.gitea/workflows/ci.yml @@ -136,7 +136,6 @@ jobs: GUARDDEN_AI_PROVIDER: "none" GUARDDEN_LOG_LEVEL: "DEBUG" run: | - # Run database migrations for tests python -c " import os os.environ['GUARDDEN_DISCORD_TOKEN'] = 'test_token_12345678901234567890123456789012345' @@ -153,15 +152,6 @@ jobs: run: | pytest --cov=src/guardden --cov-report=xml --cov-report=html --cov-report=term-missing - - name: Upload coverage to Codecov - uses: codecov/codecov-action@v3 - if: matrix.python-version == '3.11' - with: - file: ./coverage.xml - flags: unittests - name: codecov-umbrella - fail_ci_if_error: false - - name: Upload coverage reports uses: actions/upload-artifact@v3 if: matrix.python-version == '3.11' @@ -207,64 +197,3 @@ jobs: - name: Test Docker image run: | docker run --rm guardden:${{ github.sha }} python -m guardden --help - - deploy-staging: - name: Deploy to Staging - runs-on: ubuntu-latest - needs: [code-quality, test, build-docker] - if: github.ref == 'refs/heads/develop' && github.event_name == 'push' - environment: staging - steps: - - name: Checkout code - uses: actions/checkout@v4 - - - name: Deploy to staging - run: | - echo "Deploying to staging environment..." - echo "This would typically involve:" - echo "- Pushing Docker image to registry" - echo "- Updating Kubernetes/Docker Compose configs" - echo "- Running database migrations" - echo "- Performing health checks" - - deploy-production: - name: Deploy to Production - runs-on: ubuntu-latest - needs: [code-quality, test, build-docker] - if: github.event_name == 'release' - environment: production - steps: - - name: Checkout code - uses: actions/checkout@v4 - - - name: Deploy to production - run: | - echo "Deploying to production environment..." - echo "This would typically involve:" - echo "- Pushing Docker image to registry with version tag" - echo "- Blue/green deployment or rolling update" - echo "- Running database migrations" - echo "- Performing comprehensive health checks" - echo "- Monitoring deployment success" - - notification: - name: Notification - runs-on: ubuntu-latest - needs: [code-quality, test, build-docker] - if: always() - steps: - - name: Notify on failure - if: contains(needs.*.result, 'failure') - run: | - echo "Pipeline failed. In a real environment, this would:" - echo "- Send notifications to Discord/Slack" - echo "- Create GitHub issue for investigation" - echo "- Alert the development team" - - - name: Notify on success - if: needs.code-quality.result == 'success' && needs.test.result == 'success' && needs.build-docker.result == 'success' - run: | - echo "Pipeline succeeded! In a real environment, this would:" - echo "- Send success notification" - echo "- Update deployment status" - echo "- Trigger downstream processes" \ No newline at end of file diff --git a/.gitea/workflows/dependency-updates.yml b/.gitea/workflows/dependency-updates.yml new file mode 100644 index 0000000..e3b5e68 --- /dev/null +++ b/.gitea/workflows/dependency-updates.yml @@ -0,0 +1,44 @@ +name: Dependency Updates + +on: + schedule: + - cron: '0 9 * * 1' + workflow_dispatch: + +jobs: + update-dependencies: + name: Update Dependencies + runs-on: ubuntu-latest + steps: + - name: Checkout code + uses: actions/checkout@v4 + + - name: Set up Python + uses: actions/setup-python@v4 + with: + python-version: "3.11" + + - name: Install pip-tools + run: | + python -m pip install --upgrade pip + pip install pip-tools + + - name: Update dependencies + run: | + pip-compile --upgrade pyproject.toml --output-file requirements.txt + pip-compile --upgrade --extra dev pyproject.toml --output-file requirements-dev.txt + + - name: Check for security vulnerabilities + run: | + pip install safety + safety check --file requirements.txt --json --output vulnerability-report.json || true + safety check --file requirements-dev.txt --json --output vulnerability-dev-report.json || true + + - name: Upload vulnerability reports + uses: actions/upload-artifact@v3 + if: always() + with: + name: vulnerability-reports + path: | + vulnerability-report.json + vulnerability-dev-report.json diff --git a/.github/workflows/dependency-updates.yml b/.github/workflows/dependency-updates.yml deleted file mode 100644 index 5144c26..0000000 --- a/.github/workflows/dependency-updates.yml +++ /dev/null @@ -1,75 +0,0 @@ -name: Dependency Updates - -on: - schedule: - # Run weekly on Mondays at 9 AM UTC - - cron: '0 9 * * 1' - workflow_dispatch: # Allow manual triggering - -jobs: - update-dependencies: - name: Update Dependencies - runs-on: ubuntu-latest - steps: - - name: Checkout code - uses: actions/checkout@v4 - with: - token: ${{ secrets.GITHUB_TOKEN }} - - - name: Set up Python - uses: actions/setup-python@v4 - with: - python-version: "3.11" - - - name: Install pip-tools - run: | - python -m pip install --upgrade pip - pip install pip-tools - - - name: Update dependencies - run: | - # Generate requirements files from pyproject.toml - pip-compile --upgrade pyproject.toml --output-file requirements.txt - pip-compile --upgrade --extra dev pyproject.toml --output-file requirements-dev.txt - - - name: Check for security vulnerabilities - run: | - pip install safety - safety check --file requirements.txt --json --output vulnerability-report.json || true - safety check --file requirements-dev.txt --json --output vulnerability-dev-report.json || true - - - name: Create Pull Request - uses: peter-evans/create-pull-request@v5 - with: - token: ${{ secrets.GITHUB_TOKEN }} - commit-message: 'chore: update dependencies' - title: 'Automated dependency updates' - body: | - ## Automated Dependency Updates - - This PR contains automated dependency updates generated by the dependency update workflow. - - ### Changes - - Updated all dependencies to latest compatible versions - - Checked for security vulnerabilities - - ### Security Scan Results - Please review the uploaded security scan artifacts for any vulnerabilities. - - ### Testing - - [ ] All tests pass - - [ ] No breaking changes introduced - - [ ] Security scan results reviewed - - **Note**: This is an automated PR. Please review all changes carefully before merging. - branch: automated/dependency-updates - delete-branch: true - - - name: Upload vulnerability reports - uses: actions/upload-artifact@v3 - if: always() - with: - name: vulnerability-reports - path: | - vulnerability-report.json - vulnerability-dev-report.json \ No newline at end of file diff --git a/IMPLEMENTATION_PLAN.md b/IMPLEMENTATION_PLAN.md deleted file mode 100644 index 1dbeaaf..0000000 --- a/IMPLEMENTATION_PLAN.md +++ /dev/null @@ -1,400 +0,0 @@ -# GuardDen Enhancement Implementation Plan - -## 🎯 Executive Summary - -Your GuardDen bot is well-architected with solid fundamentals, but needs: -1. **Critical security and bug fixes** (immediate priority) -2. **Comprehensive testing infrastructure** for reliability -3. **Modern DevOps pipeline** for sustainable development -4. **Enhanced dashboard** with real-time analytics and management capabilities - -## 📋 Implementation Roadmap - -### **Phase 1: Foundation & Security (Week 1-2)** ✅ COMPLETED -*Critical bugs, security fixes, and testing infrastructure* - -#### 1.1 Critical Security Fixes ✅ COMPLETED -- [x] **Fix configuration validation** in `src/guardden/config.py:11-45` - - Added strict Discord ID parsing with regex validation - - Implemented minimum secret key length enforcement - - Added input sanitization and validation for all configuration fields -- [x] **Secure error handling** throughout Discord API calls - - Added proper error handling for kick/ban/timeout operations - - Implemented graceful fallback for Discord API failures -- [x] **Add input sanitization** for URL parsing in automod service - - Enhanced URL validation with length limits and character filtering - - Improved normalize_domain function with security checks - - Updated URL pattern for more restrictive matching -- [x] **Database security audit** and add missing indexes - - Created comprehensive migration with 25+ indexes - - Added indexes for all common query patterns and foreign keys - -#### 1.2 Error Handling Improvements ✅ COMPLETED -- [x] **Refactor exception handling** in `src/guardden/bot.py:119-123` - - Improved cog loading with specific exception types - - Added better error context and logging - - Enhanced guild initialization error handling -- [x] **Add circuit breakers** for problematic regex patterns - - Implemented RegexCircuitBreaker class with timeout protection - - Added pattern validation to prevent catastrophic backtracking - - Integrated safe regex execution throughout automod service -- [x] **Implement graceful degradation** for AI service failures - - Enhanced error handling in existing AI integration -- [x] **Add proper error feedback** for Discord API failures - - Added user-friendly error messages for moderation failures - - Implemented fallback responses when embed sending fails - -#### 1.3 Testing Infrastructure ✅ COMPLETED -- [x] **Set up pytest configuration** with async support and coverage - - Created comprehensive conftest.py with 20+ fixtures - - Added pytest.ini with coverage requirements (75%+ threshold) - - Configured async test support and proper markers -- [x] **Create test fixtures** for database, Discord mocks, AI providers - - Database fixtures with in-memory SQLite - - Complete Discord mock objects (users, guilds, channels, messages) - - Test configuration and environment setup -- [x] **Add integration tests** for all cogs and services - - Created test_config.py for configuration security validation - - Created test_automod_security.py for automod security improvements - - Created test_database_integration.py for database model testing -- [x] **Implement test database** with proper isolation - - In-memory SQLite setup for test isolation - - Automatic table creation and cleanup - - Session management for tests - -### **Phase 2: DevOps & CI/CD (Week 2-3)** ✅ COMPLETED -*Automated testing, deployment, and monitoring* - -#### 2.1 CI/CD Pipeline ✅ COMPLETED -- [x] **GitHub Actions workflow** for automated testing - - Comprehensive CI pipeline with code quality, security scanning, and testing - - Multi-Python version testing (3.11, 3.12) with PostgreSQL service - - Automated dependency updates with security vulnerability scanning - - Deployment pipelines for staging and production environments -- [x] **Multi-stage Docker builds** with optional AI dependencies - - Optimized Dockerfile with builder pattern for reduced image size - - Configurable AI dependency installation with build args - - Development stage with debugging tools and hot reloading - - Proper security practices (non-root user, health checks) -- [x] **Automated security scanning** with dependency checks - - Safety for dependency vulnerability scanning - - Bandit for security linting of Python code - - Integrated into CI pipeline with artifact reporting -- [x] **Code quality gates** with ruff, mypy, and coverage thresholds - - 75%+ test coverage requirement with detailed reporting - - Strict type checking with mypy - - Code formatting and linting with ruff - -#### 2.2 Monitoring & Logging ✅ COMPLETED -- [x] **Structured logging** with JSON formatter - - Optional structlog integration for enhanced structured logging - - Graceful fallback to stdlib logging when structlog unavailable - - Context-aware logging with command tracing and performance metrics - - Configurable log levels and JSON formatting for production -- [x] **Application metrics** with Prometheus/OpenTelemetry - - Comprehensive metrics collection (commands, moderation, AI, database) - - Optional Prometheus integration with graceful degradation - - Grafana dashboards and monitoring stack configuration - - Performance monitoring with request duration and error tracking -- [x] **Health check improvements** for database and AI providers - - Comprehensive health check system with database, AI, and Discord API monitoring - - CLI health check tool with JSON output support - - Docker health checks integrated into container definitions - - System metrics collection (CPU, memory, disk usage) -- [x] **Error tracking and monitoring** infrastructure - - Structured logging with error context and stack traces - - Metrics-based monitoring for error rates and performance - - Health check system for proactive issue detection - -#### 2.3 Development Environment ✅ COMPLETED -- [x] **Docker Compose improvements** with dev overrides - - Comprehensive docker-compose.yml with production-ready configuration - - Development overrides with hot reloading and debugging support - - Integrated monitoring stack (Prometheus, Grafana, Redis, PostgreSQL) - - Development tools (PgAdmin, Redis Commander, MailHog) -- [x] **Development automation and tooling** - - Comprehensive development script (scripts/dev.sh) with 15+ commands - - Automated setup, testing, linting, and deployment workflows - - Database migration management and health checking tools -- [x] **Development documentation and setup guides** - - Complete Docker setup with development and production configurations - - Automated environment setup and dependency management - - Comprehensive development workflow documentation - -### **Phase 3: Dashboard Backend Enhancement (Week 3-4)** ✅ COMPLETED -*Expand API capabilities for comprehensive management* - -#### 3.1 Enhanced API Endpoints ✅ COMPLETED -- [x] **Real-time analytics API** (`/api/analytics/*`) - - Moderation action statistics - - User activity metrics - - AI performance data - - Server health metrics - -#### 3.2 User Management API ✅ COMPLETED -- [x] **User profile endpoints** (`/api/users/*`) -- [x] **Strike and note management** -- [x] **User search and filtering** - -#### 3.3 Configuration Management API ✅ COMPLETED -- [x] **Guild settings management** (`/api/guilds/{id}/settings`) -- [x] **Automod rule configuration** (`/api/guilds/{id}/automod`) -- [x] **AI provider settings** per guild -- [x] **Export/import functionality** for settings - -#### 3.4 WebSocket Support ✅ COMPLETED -- [x] **Real-time event streaming** for live updates -- [x] **Live moderation feed** for active monitoring -- [x] **System alerts and notifications** - -### **Phase 4: React Dashboard Frontend (Week 4-6)** ✅ COMPLETED -*Modern, responsive web interface with real-time capabilities* - -#### 4.1 Frontend Architecture ✅ COMPLETED -``` -dashboard-frontend/ -├── src/ -│ ├── components/ # Reusable UI components (Layout) -│ ├── pages/ # Page components (Dashboard, Analytics, Users, Settings, Moderation) -│ ├── services/ # API clients and WebSocket -│ ├── types/ # TypeScript definitions -│ └── index.css # Tailwind styles -├── public/ # Static assets -└── package.json # Dependencies and scripts -``` - -#### 4.2 Key Features ✅ COMPLETED -- [x] **Authentication Flow**: Dual OAuth with session management -- [x] **Real-time Analytics Dashboard**: - - Live metrics with charts (Recharts) - - Moderation activity timeline - - AI performance monitoring -- [x] **User Management Interface**: - - User search and profiles - - Strike history display -- [x] **Guild Configuration**: - - Settings management forms - - Automod rule builder - - AI sensitivity configuration -- [x] **Export functionality**: JSON configuration export - -#### 4.3 Technical Stack ✅ COMPLETED -- [x] **React 18** with TypeScript and Vite -- [x] **Tailwind CSS** for responsive design -- [x] **React Query** for API state management -- [x] **React Hook Form** for form handling -- [x] **React Router** for navigation -- [x] **WebSocket client** for real-time updates -- [x] **Recharts** for data visualization -- [x] **date-fns** for date formatting - -### **Phase 5: Performance & Scalability (Week 6-7)** ✅ COMPLETED -*Optimize performance and prepare for scaling* - -#### 5.1 Database Optimization ✅ COMPLETED -- [x] **Add strategic indexes** for common query patterns (analytics tables) -- [x] **Database migration for analytics models** with comprehensive indexing - -#### 5.2 Application Performance ✅ COMPLETED -- [x] **Implement Redis caching** for guild configs with in-memory fallback -- [x] **Multi-tier caching system** (memory + Redis) -- [x] **Cache service** with automatic TTL management - -#### 5.3 Architecture Improvements ✅ COMPLETED -- [x] **Analytics tracking system** with dedicated models -- [x] **Caching abstraction layer** for flexible cache backends -- [x] **Performance-optimized guild config service** - -## 🛠 Technical Specifications - -### Enhanced Dashboard Features - -#### Real-time Analytics Dashboard -```typescript -interface AnalyticsData { - moderationStats: { - totalActions: number; - actionsByType: Record; - actionsOverTime: TimeSeriesData[]; - }; - userActivity: { - activeUsers: number; - newJoins: number; - messageVolume: number; - }; - aiPerformance: { - accuracy: number; - falsePositives: number; - responseTime: number; - }; -} -``` - -#### User Management Interface -- **Advanced search** with filters (username, join date, strike count) -- **Bulk actions** (mass ban, mass role assignment) -- **User timeline** showing all interactions with the bot -- **Note system** for moderator communications - -#### Notification System -```typescript -interface Alert { - id: string; - type: 'security' | 'moderation' | 'system'; - severity: 'low' | 'medium' | 'high' | 'critical'; - message: string; - guildId?: string; - timestamp: Date; - acknowledged: boolean; -} -``` - -### API Enhancements - -#### WebSocket Events -```python -# Real-time events -class WebSocketEvent(BaseModel): - type: str # "moderation_action", "user_join", "ai_alert" - guild_id: int - timestamp: datetime - data: dict -``` - -#### New Endpoints -```python -# Analytics endpoints -GET /api/analytics/summary -GET /api/analytics/moderation-stats -GET /api/analytics/user-activity -GET /api/analytics/ai-performance - -# User management -GET /api/users/search -GET /api/users/{user_id}/profile -POST /api/users/{user_id}/note -POST /api/users/bulk-action - -# Configuration -GET /api/guilds/{guild_id}/settings -PUT /api/guilds/{guild_id}/settings -GET /api/guilds/{guild_id}/automod-rules -POST /api/guilds/{guild_id}/automod-rules - -# Real-time updates -WebSocket /ws/events -``` - -## 📊 Success Metrics - -### Code Quality -- **Test Coverage**: 90%+ for all modules -- **Type Coverage**: 95%+ with mypy strict mode -- **Security Score**: Zero critical vulnerabilities -- **Performance**: <100ms API response times - -### Dashboard Functionality -- **Real-time Updates**: <1 second latency for events -- **User Experience**: Mobile-responsive, accessible design -- **Data Export**: Multiple format support (CSV, JSON, PDF) -- **Uptime**: 99.9% availability target - -## 🚀 Implementation Status - -- **Phase 1**: ✅ COMPLETED -- **Phase 2**: ✅ COMPLETED -- **Phase 3**: ✅ COMPLETED -- **Phase 4**: ✅ COMPLETED -- **Phase 5**: ✅ COMPLETED - ---- -*Last Updated: January 17, 2026* - -## 📊 Phase 1 Achievements - -### Security Enhancements -- **Configuration Security**: Implemented strict validation for Discord IDs, API keys, and all configuration parameters -- **Input Sanitization**: Enhanced URL parsing with comprehensive validation and filtering -- **Database Security**: Added 25+ strategic indexes for performance and security -- **Regex Security**: Implemented circuit breaker pattern to prevent catastrophic backtracking - -### Code Quality Improvements -- **Error Handling**: Comprehensive error handling throughout Discord API calls and bot operations -- **Type Safety**: Resolved major type annotation issues and improved code clarity -- **Testing Infrastructure**: Complete test suite setup with 75%+ coverage requirements - -### Performance Optimizations -- **Database Indexing**: Strategic indexes for all common query patterns -- **Regex Optimization**: Safe regex execution with timeout protection -- **Memory Management**: Improved spam tracking with proper cleanup - -### Developer Experience -- **Test Coverage**: Comprehensive test fixtures and integration tests -- **Documentation**: Updated implementation plan and inline documentation -- **Configuration**: Enhanced validation and better error messages - -## 📊 Phase 2 Achievements - -### DevOps Infrastructure -- **CI/CD Pipeline**: Complete GitHub Actions workflow with parallel job execution -- **Docker Optimization**: Multi-stage builds reducing image size by ~40% -- **Security Automation**: Automated vulnerability scanning and dependency management -- **Quality Gates**: 75%+ test coverage requirement with comprehensive type checking - -### Monitoring & Observability -- **Structured Logging**: JSON logging with context-aware tracing -- **Metrics Collection**: 15+ Prometheus metrics for comprehensive monitoring -- **Health Checks**: Multi-service health monitoring with performance tracking -- **Dashboard Integration**: Grafana dashboards for real-time monitoring - -### Development Experience -- **One-Command Setup**: `./scripts/dev.sh setup` for complete environment setup -- **Hot Reloading**: Development containers with live code reloading -- **Database Tools**: Automated migration management and admin interfaces -- **Comprehensive Tooling**: 15+ development commands for testing, linting, and deployment - -## 📊 Phase 3-5 Achievements - -### Phase 3: Dashboard Backend Enhancement -- **Analytics API**: Comprehensive real-time analytics with moderation stats, user activity, and AI performance tracking -- **User Management**: Full CRUD API for user profiles, notes, and search functionality -- **Configuration API**: Guild settings and automod configuration with export/import support -- **WebSocket Support**: Real-time event streaming with automatic reconnection and heartbeat - -**New API Endpoints:** -- `/api/analytics/summary` - Complete analytics overview -- `/api/analytics/moderation-stats` - Detailed moderation statistics -- `/api/analytics/user-activity` - User activity metrics -- `/api/analytics/ai-performance` - AI moderation performance -- `/api/users/search` - User search with filters -- `/api/users/{id}/profile` - User profile details -- `/api/users/{id}/notes` - User notes management -- `/api/guilds/{id}/settings` - Guild settings CRUD -- `/api/guilds/{id}/automod` - Automod configuration -- `/api/guilds/{id}/export` - Configuration export -- `/ws/events` - WebSocket real-time events - -### Phase 4: React Dashboard Frontend -- **Modern UI**: Tailwind CSS-based responsive design with dark mode support -- **Real-time Charts**: Recharts integration for moderation analytics and trends -- **Smart Caching**: React Query for intelligent data fetching and caching -- **Type Safety**: Full TypeScript coverage with comprehensive type definitions - -**Pages Implemented:** -- Dashboard - Overview with key metrics and charts -- Analytics - Detailed statistics and trends -- Users - User search and management -- Moderation - Comprehensive log viewing -- Settings - Guild configuration management - -### Phase 5: Performance & Scalability -- **Multi-tier Caching**: Redis + in-memory caching with automatic fallback -- **Analytics Models**: Dedicated database models for AI checks, user activity, and message stats -- **Optimized Queries**: Strategic indexes on all analytics tables -- **Flexible Architecture**: Cache abstraction supporting multiple backends - -**Performance Improvements:** -- Guild config caching reduces database load by ~80% -- Analytics queries optimized with proper indexing -- WebSocket connections with efficient heartbeat mechanism -- In-memory fallback ensures reliability without Redis \ No newline at end of file diff --git a/README.md b/README.md index 9e65d06..dd3863e 100644 --- a/README.md +++ b/README.md @@ -86,7 +86,7 @@ GuardDen is a comprehensive Discord moderation bot designed to protect your comm 1. Clone the repository: ```bash - git clone https://github.com/yourusername/guardden.git + git clone https://git.hiddenden.cafe/Hiddenden/GuardDen.git cd guardden ``` @@ -155,6 +155,9 @@ GuardDen is a comprehensive Discord moderation bot designed to protect your comm | `GUARDDEN_DASHBOARD_OWNER_DISCORD_ID` | Discord user ID allowed | Required | | `GUARDDEN_DASHBOARD_OWNER_ENTRA_OBJECT_ID` | Entra object ID allowed | Required | | `GUARDDEN_DASHBOARD_CORS_ORIGINS` | Dashboard CORS origins | (empty = none) | +| `GUARDDEN_WORDLIST_ENABLED` | Enable managed wordlist sync | `true` | +| `GUARDDEN_WORDLIST_UPDATE_HOURS` | Managed wordlist sync interval | `168` | +| `GUARDDEN_WORDLIST_SOURCES` | JSON array of wordlist sources | (empty = defaults) | ### Per-Guild Settings @@ -208,6 +211,10 @@ Each server can configure: | `!bannedwords add [action] [is_regex]` | Add a banned word | | `!bannedwords remove ` | Remove a banned word by ID | +Managed wordlists are synced weekly by default. You can override sources with +`GUARDDEN_WORDLIST_SOURCES` (JSON array) or disable syncing entirely with +`GUARDDEN_WORDLIST_ENABLED=false`. + ### Automod | Command | Description | @@ -262,6 +269,11 @@ The dashboard provides read-only visibility into moderation logs across all serv - Entra: `http://localhost:8080/auth/entra/callback` - Discord: `http://localhost:8080/auth/discord/callback` +## CI (Gitea Actions) + +Workflows live under `.gitea/workflows/` and mirror the previous GitHub Actions +pipeline for linting, tests, and Docker builds. + ## Project Structure ``` diff --git a/migrations/versions/20260117_add_banned_word_metadata.py b/migrations/versions/20260117_add_banned_word_metadata.py new file mode 100644 index 0000000..2ff4f75 --- /dev/null +++ b/migrations/versions/20260117_add_banned_word_metadata.py @@ -0,0 +1,37 @@ +"""Add metadata fields for managed banned words. + +Revision ID: 20260117_add_banned_word_metadata +Revises: 20260117_enable_ai_defaults +Create Date: 2026-01-17 21:15:00.000000 +""" + +import sqlalchemy as sa +from alembic import op + +# revision identifiers, used by Alembic. +revision = "20260117_add_banned_word_metadata" +down_revision = "20260117_enable_ai_defaults" +branch_labels = None +depends_on = None + + +def upgrade() -> None: + op.add_column( + "banned_words", + sa.Column("source", sa.String(length=100), nullable=True), + ) + op.add_column( + "banned_words", + sa.Column("category", sa.String(length=20), nullable=True), + ) + op.add_column( + "banned_words", + sa.Column("managed", sa.Boolean(), nullable=False, server_default=sa.text("false")), + ) + op.alter_column("banned_words", "managed", server_default=None) + + +def downgrade() -> None: + op.drop_column("banned_words", "managed") + op.drop_column("banned_words", "category") + op.drop_column("banned_words", "source") diff --git a/migrations/versions/20260117_enable_ai_defaults.py b/migrations/versions/20260117_enable_ai_defaults.py new file mode 100644 index 0000000..3da3f83 --- /dev/null +++ b/migrations/versions/20260117_enable_ai_defaults.py @@ -0,0 +1,41 @@ +"""Enable AI moderation defaults for existing guilds. + +Revision ID: 20260117_enable_ai_defaults +Revises: 20260117_analytics +Create Date: 2026-01-17 21:00:00.000000 +""" + +import sqlalchemy as sa +from alembic import op + +# revision identifiers, used by Alembic. +revision = "20260117_enable_ai_defaults" +down_revision = "20260117_analytics" +branch_labels = None +depends_on = None + + +def upgrade() -> None: + op.execute( + sa.text( + """ + UPDATE guild_settings + SET ai_moderation_enabled = TRUE, + nsfw_detection_enabled = TRUE, + ai_sensitivity = 80 + """ + ) + ) + + +def downgrade() -> None: + op.execute( + sa.text( + """ + UPDATE guild_settings + SET ai_moderation_enabled = FALSE, + nsfw_detection_enabled = FALSE, + ai_sensitivity = 50 + """ + ) + ) diff --git a/pyproject.toml b/pyproject.toml index 8f5be78..59dd2ec 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -32,6 +32,7 @@ dependencies = [ "uvicorn>=0.27.0", "authlib>=1.3.0", "httpx>=0.27.0", + "itsdangerous>=2.1.2", ] [project.optional-dependencies] @@ -39,6 +40,7 @@ dev = [ "pytest>=7.4.0", "pytest-asyncio>=0.23.0", "pytest-cov>=4.1.0", + "aiosqlite>=0.19.0", "ruff>=0.1.0", "mypy>=1.7.0", "pre-commit>=3.6.0", diff --git a/src/guardden/bot.py b/src/guardden/bot.py index e806e0e..36787b9 100644 --- a/src/guardden/bot.py +++ b/src/guardden/bot.py @@ -42,6 +42,7 @@ class GuardDen(commands.Bot): self.database = Database(settings) self.guild_config: "GuildConfigService | None" = None self.ai_provider: AIProvider | None = None + self.wordlist_service = None self.rate_limiter = RateLimiter() async def _get_prefix(self, bot: "GuardDen", message: discord.Message) -> list[str]: @@ -90,6 +91,9 @@ class GuardDen(commands.Bot): from guardden.services.guild_config import GuildConfigService self.guild_config = GuildConfigService(self.database) + from guardden.services.wordlist import WordlistService + + self.wordlist_service = WordlistService(self.database, self.settings) # Initialize AI provider api_key = None @@ -115,6 +119,7 @@ class GuardDen(commands.Bot): "guardden.cogs.ai_moderation", "guardden.cogs.verification", "guardden.cogs.health", + "guardden.cogs.wordlist_sync", ] failed_cogs = [] @@ -131,7 +136,7 @@ class GuardDen(commands.Bot): except Exception as e: logger.error(f"Unexpected error loading cog {cog}: {e}", exc_info=True) failed_cogs.append(cog) - + if failed_cogs: logger.warning(f"Failed to load {len(failed_cogs)} cog(s): {', '.join(failed_cogs)}") # Don't fail startup if some cogs fail to load, but log it prominently @@ -146,7 +151,7 @@ class GuardDen(commands.Bot): if self.guild_config: initialized = 0 failed_guilds = [] - + for guild in self.guilds: try: if not self.is_guild_allowed(guild.id): @@ -162,12 +167,17 @@ class GuardDen(commands.Bot): await self.guild_config.create_guild(guild) initialized += 1 except Exception as e: - logger.error(f"Failed to initialize config for guild {guild.id} ({guild.name}): {e}", exc_info=True) + logger.error( + f"Failed to initialize config for guild {guild.id} ({guild.name}): {e}", + exc_info=True, + ) failed_guilds.append(guild.id) logger.info("Initialized config for %s guild(s)", initialized) if failed_guilds: - logger.warning(f"Failed to initialize {len(failed_guilds)} guild(s): {failed_guilds}") + logger.warning( + f"Failed to initialize {len(failed_guilds)} guild(s): {failed_guilds}" + ) # Set presence activity = discord.Activity( @@ -206,9 +216,7 @@ class GuardDen(commands.Bot): logger.info(f"Joined guild: {guild.name} (ID: {guild.id})") if not self.is_guild_allowed(guild.id): - logger.warning( - "Guild %s (ID: %s) not in allowlist, leaving.", guild.name, guild.id - ) + logger.warning("Guild %s (ID: %s) not in allowlist, leaving.", guild.name, guild.id) await guild.leave() return diff --git a/src/guardden/cogs/wordlist_sync.py b/src/guardden/cogs/wordlist_sync.py new file mode 100644 index 0000000..0cd239b --- /dev/null +++ b/src/guardden/cogs/wordlist_sync.py @@ -0,0 +1,38 @@ +"""Background task for managed wordlist syncing.""" + +import logging + +from discord.ext import commands, tasks + +from guardden.services.wordlist import WordlistService + +logger = logging.getLogger(__name__) + + +class WordlistSync(commands.Cog): + """Periodic sync of managed wordlists into guild bans.""" + + def __init__(self, bot: commands.Bot, service: WordlistService) -> None: + self.bot = bot + self.service = service + self.sync_task.change_interval(hours=service.update_interval.total_seconds() / 3600) + self.sync_task.start() + + def cog_unload(self) -> None: + self.sync_task.cancel() + + @tasks.loop(hours=1) + async def sync_task(self) -> None: + await self.service.sync_all() + + @sync_task.before_loop + async def before_sync_task(self) -> None: + await self.bot.wait_until_ready() + + +async def setup(bot: commands.Bot) -> None: + service = getattr(bot, "wordlist_service", None) + if not service: + logger.warning("Wordlist service not initialized; skipping sync task") + return + await bot.add_cog(WordlistSync(bot, service)) diff --git a/src/guardden/config.py b/src/guardden/config.py index 101ee8d..69c1add 100644 --- a/src/guardden/config.py +++ b/src/guardden/config.py @@ -5,9 +5,9 @@ import re from pathlib import Path from typing import Any, Literal -from pydantic import Field, SecretStr, field_validator, ValidationError +from pydantic import BaseModel, Field, SecretStr, ValidationError, field_validator from pydantic_settings import BaseSettings, SettingsConfigDict - +from pydantic_settings.sources import EnvSettingsSource # Discord snowflake ID validation regex (64-bit integers, 17-19 digits) DISCORD_ID_PATTERN = re.compile(r"^\d{17,19}$") @@ -19,17 +19,17 @@ def _validate_discord_id(value: str | int) -> int: id_str = str(value) else: id_str = str(value).strip() - + # Check format if not DISCORD_ID_PATTERN.match(id_str): raise ValueError(f"Invalid Discord ID format: {id_str}") - + # Convert to int and validate range discord_id = int(id_str) # Discord snowflakes are 64-bit integers, minimum valid ID is around 2010 if discord_id < 100000000000000000 or discord_id > 9999999999999999999: raise ValueError(f"Discord ID out of valid range: {discord_id}") - + return discord_id @@ -65,6 +65,27 @@ def _parse_id_list(value: Any) -> list[int]: return parsed +class GuardDenEnvSettingsSource(EnvSettingsSource): + """Environment settings source with safe list parsing.""" + + def decode_complex_value(self, field_name: str, field, value: Any): + if field_name in {"allowed_guilds", "owner_ids"} and isinstance(value, str): + return value + return super().decode_complex_value(field_name, field, value) + + +class WordlistSourceConfig(BaseModel): + """Configuration for a managed wordlist source.""" + + name: str + url: str + category: Literal["hard", "soft", "context"] + action: Literal["delete", "warn", "strike"] + reason: str + is_regex: bool = False + enabled: bool = True + + class Settings(BaseSettings): """Application settings loaded from environment variables.""" @@ -73,8 +94,25 @@ class Settings(BaseSettings): env_file_encoding="utf-8", case_sensitive=False, env_prefix="GUARDDEN_", + env_parse_none_str="", ) + @classmethod + def settings_customise_sources( + cls, + settings_cls, + init_settings, + env_settings, + dotenv_settings, + file_secret_settings, + ): + return ( + init_settings, + GuardDenEnvSettingsSource(settings_cls), + dotenv_settings, + file_secret_settings, + ) + # Discord settings discord_token: SecretStr = Field(..., description="Discord bot token") discord_prefix: str = Field(default="!", description="Default command prefix") @@ -114,11 +152,43 @@ class Settings(BaseSettings): # Paths data_dir: Path = Field(default=Path("data"), description="Data directory for persistent files") + # Wordlist sync + wordlist_enabled: bool = Field( + default=True, description="Enable automatic managed wordlist syncing" + ) + wordlist_update_hours: int = Field( + default=168, description="Managed wordlist sync interval in hours" + ) + wordlist_sources: list[WordlistSourceConfig] = Field( + default_factory=list, + description="Managed wordlist sources (JSON array via env overrides)", + ) + @field_validator("allowed_guilds", "owner_ids", mode="before") @classmethod def _validate_id_list(cls, value: Any) -> list[int]: return _parse_id_list(value) + @field_validator("wordlist_sources", mode="before") + @classmethod + def _parse_wordlist_sources(cls, value: Any) -> list[WordlistSourceConfig]: + if value is None: + return [] + if isinstance(value, list): + return [WordlistSourceConfig.model_validate(item) for item in value] + if isinstance(value, str): + text = value.strip() + if not text: + return [] + try: + data = json.loads(text) + except json.JSONDecodeError as exc: + raise ValueError("Invalid JSON for wordlist_sources") from exc + if not isinstance(data, list): + raise ValueError("wordlist_sources must be a JSON array") + return [WordlistSourceConfig.model_validate(item) for item in data] + return [] + @field_validator("discord_token") @classmethod def _validate_discord_token(cls, value: SecretStr) -> SecretStr: @@ -126,11 +196,11 @@ class Settings(BaseSettings): token = value.get_secret_value() if not token: raise ValueError("Discord token cannot be empty") - + # Basic Discord token format validation (not perfect but catches common issues) if len(token) < 50 or not re.match(r"^[A-Za-z0-9._-]+$", token): raise ValueError("Invalid Discord token format") - + return value @field_validator("anthropic_api_key", "openai_api_key") @@ -139,15 +209,15 @@ class Settings(BaseSettings): """Validate API key format if provided.""" if value is None: return None - + key = value.get_secret_value() if not key: return None - + # Basic API key validation if len(key) < 20: raise ValueError("API key too short to be valid") - + return value def validate_configuration(self) -> None: @@ -157,17 +227,21 @@ class Settings(BaseSettings): raise ValueError("GUARDDEN_ANTHROPIC_API_KEY is required when AI provider is anthropic") if self.ai_provider == "openai" and not self.openai_api_key: raise ValueError("GUARDDEN_OPENAI_API_KEY is required when AI provider is openai") - + # Database pool validation if self.database_pool_min > self.database_pool_max: raise ValueError("database_pool_min cannot be greater than database_pool_max") if self.database_pool_min < 1: raise ValueError("database_pool_min must be at least 1") - + # Data directory validation if not isinstance(self.data_dir, Path): raise ValueError("data_dir must be a valid path") + # Wordlist validation + if self.wordlist_update_hours < 1: + raise ValueError("wordlist_update_hours must be at least 1") + def get_settings() -> Settings: """Get application settings instance.""" diff --git a/src/guardden/dashboard/__main__.py b/src/guardden/dashboard/__main__.py new file mode 100644 index 0000000..c2a7bd3 --- /dev/null +++ b/src/guardden/dashboard/__main__.py @@ -0,0 +1,16 @@ +"""Dashboard entrypoint for `python -m guardden.dashboard`.""" + +import os + +import uvicorn + + +def main() -> None: + host = os.getenv("GUARDDEN_DASHBOARD_HOST", "0.0.0.0") + port = int(os.getenv("GUARDDEN_DASHBOARD_PORT", "8000")) + log_level = os.getenv("GUARDDEN_LOG_LEVEL", "info").lower() + uvicorn.run("guardden.dashboard.main:app", host=host, port=port, log_level=log_level) + + +if __name__ == "__main__": + main() diff --git a/src/guardden/models/guild.py b/src/guardden/models/guild.py index efbf5ed..878d6d4 100644 --- a/src/guardden/models/guild.py +++ b/src/guardden/models/guild.py @@ -3,7 +3,7 @@ from datetime import datetime from typing import TYPE_CHECKING -from sqlalchemy import Boolean, Float, ForeignKey, Integer, String, Text +from sqlalchemy import JSON, Boolean, Float, ForeignKey, Integer, String, Text from sqlalchemy.dialects.postgresql import JSONB from sqlalchemy.orm import Mapped, mapped_column, relationship @@ -59,7 +59,9 @@ class GuildSettings(Base, TimestampMixin): # Role configuration mute_role_id: Mapped[int | None] = mapped_column(SnowflakeID, nullable=True) verified_role_id: Mapped[int | None] = mapped_column(SnowflakeID, nullable=True) - mod_role_ids: Mapped[dict] = mapped_column(JSONB, default=list, nullable=False) + mod_role_ids: Mapped[dict] = mapped_column( + JSONB().with_variant(JSON(), "sqlite"), default=list, nullable=False + ) # Moderation settings automod_enabled: Mapped[bool] = mapped_column(Boolean, default=True, nullable=False) @@ -73,11 +75,13 @@ class GuildSettings(Base, TimestampMixin): mention_limit: Mapped[int] = mapped_column(Integer, default=5, nullable=False) mention_rate_limit: Mapped[int] = mapped_column(Integer, default=10, nullable=False) mention_rate_window: Mapped[int] = mapped_column(Integer, default=60, nullable=False) - scam_allowlist: Mapped[list[str]] = mapped_column(JSONB, default=list, nullable=False) + scam_allowlist: Mapped[list[str]] = mapped_column( + JSONB().with_variant(JSON(), "sqlite"), default=list, nullable=False + ) # Strike thresholds (actions at each threshold) strike_actions: Mapped[dict] = mapped_column( - JSONB, + JSONB().with_variant(JSON(), "sqlite"), default=lambda: { "1": {"action": "warn"}, "3": {"action": "timeout", "duration": 3600}, @@ -88,11 +92,11 @@ class GuildSettings(Base, TimestampMixin): ) # AI moderation settings - ai_moderation_enabled: Mapped[bool] = mapped_column(Boolean, default=False, nullable=False) - ai_sensitivity: Mapped[int] = mapped_column(Integer, default=50, nullable=False) # 0-100 scale + ai_moderation_enabled: Mapped[bool] = mapped_column(Boolean, default=True, nullable=False) + ai_sensitivity: Mapped[int] = mapped_column(Integer, default=80, nullable=False) # 0-100 scale ai_confidence_threshold: Mapped[float] = mapped_column(Float, default=0.7, nullable=False) ai_log_only: Mapped[bool] = mapped_column(Boolean, default=False, nullable=False) - nsfw_detection_enabled: Mapped[bool] = mapped_column(Boolean, default=False, nullable=False) + nsfw_detection_enabled: Mapped[bool] = mapped_column(Boolean, default=True, nullable=False) # Verification settings verification_enabled: Mapped[bool] = mapped_column(Boolean, default=False, nullable=False) @@ -120,6 +124,9 @@ class BannedWord(Base, TimestampMixin): String(20), default="delete", nullable=False ) # delete, warn, strike reason: Mapped[str | None] = mapped_column(Text, nullable=True) + source: Mapped[str | None] = mapped_column(String(100), nullable=True) + category: Mapped[str | None] = mapped_column(String(20), nullable=True) + managed: Mapped[bool] = mapped_column(Boolean, default=False, nullable=False) # Who added this and when added_by: Mapped[int] = mapped_column(SnowflakeID, nullable=False) diff --git a/src/guardden/services/automod.py b/src/guardden/services/automod.py index 322a626..98eb14e 100644 --- a/src/guardden/services/automod.py +++ b/src/guardden/services/automod.py @@ -7,7 +7,7 @@ import time from collections import defaultdict from dataclasses import dataclass, field from datetime import datetime, timedelta, timezone -from typing import NamedTuple, Sequence, TYPE_CHECKING +from typing import TYPE_CHECKING, NamedTuple, Sequence from urllib.parse import urlparse if TYPE_CHECKING: @@ -16,6 +16,7 @@ else: try: import discord # type: ignore except ModuleNotFoundError: # pragma: no cover + class _DiscordStub: class Message: # minimal stub for type hints pass @@ -26,120 +27,122 @@ from guardden.models.guild import BannedWord logger = logging.getLogger(__name__) + # Circuit breaker for regex safety class RegexTimeoutError(Exception): """Raised when regex execution takes too long.""" + pass class RegexCircuitBreaker: """Circuit breaker to prevent catastrophic backtracking in regex patterns.""" - + def __init__(self, timeout_seconds: float = 0.1): self.timeout_seconds = timeout_seconds self.failed_patterns: dict[str, datetime] = {} self.failure_threshold = timedelta(minutes=5) # Disable pattern for 5 minutes after failure - + def _timeout_handler(self, signum, frame): """Signal handler for regex timeout.""" raise RegexTimeoutError("Regex execution timed out") - + def is_pattern_disabled(self, pattern: str) -> bool: """Check if a pattern is temporarily disabled due to timeouts.""" if pattern not in self.failed_patterns: return False - + failure_time = self.failed_patterns[pattern] if datetime.now(timezone.utc) - failure_time > self.failure_threshold: # Re-enable the pattern after threshold time del self.failed_patterns[pattern] return False - + return True - + def safe_regex_search(self, pattern: str, text: str, flags: int = 0) -> bool: """Safely execute regex search with timeout protection.""" if self.is_pattern_disabled(pattern): logger.warning(f"Regex pattern temporarily disabled due to timeout: {pattern[:50]}...") return False - + # Basic pattern validation to catch obviously problematic patterns if self._is_dangerous_pattern(pattern): logger.warning(f"Potentially dangerous regex pattern rejected: {pattern[:50]}...") return False - + old_handler = None try: # Set up timeout signal (Unix systems only) - if hasattr(signal, 'SIGALRM'): + if hasattr(signal, "SIGALRM"): old_handler = signal.signal(signal.SIGALRM, self._timeout_handler) signal.alarm(int(self.timeout_seconds * 1000)) # Convert to milliseconds - + start_time = time.perf_counter() - + # Compile and execute regex compiled_pattern = re.compile(pattern, flags) result = bool(compiled_pattern.search(text)) - + execution_time = time.perf_counter() - start_time - + # Log slow patterns for monitoring if execution_time > self.timeout_seconds * 0.8: logger.warning( f"Slow regex pattern (took {execution_time:.3f}s): {pattern[:50]}..." ) - + return result - + except RegexTimeoutError: # Pattern took too long, disable it temporarily self.failed_patterns[pattern] = datetime.now(timezone.utc) logger.error(f"Regex pattern timed out and disabled: {pattern[:50]}...") return False - + except re.error as e: logger.warning(f"Invalid regex pattern '{pattern[:50]}...': {e}") return False - + except Exception as e: logger.error(f"Unexpected error in regex execution: {e}") return False - + finally: # Clean up timeout signal - if hasattr(signal, 'SIGALRM') and old_handler is not None: + if hasattr(signal, "SIGALRM") and old_handler is not None: signal.alarm(0) signal.signal(signal.SIGALRM, old_handler) - + def _is_dangerous_pattern(self, pattern: str) -> bool: """Basic heuristic to detect potentially dangerous regex patterns.""" # Check for patterns that are commonly problematic dangerous_indicators = [ - r'(\w+)+', # Nested quantifiers - r'(\d+)+', # Nested quantifiers on digits - r'(.+)+', # Nested quantifiers on anything - r'(.*)+', # Nested quantifiers on anything (greedy) - r'(\w*)+', # Nested quantifiers with * - r'(\S+)+', # Nested quantifiers on non-whitespace + r"(\w+)+", # Nested quantifiers + r"(\d+)+", # Nested quantifiers on digits + r"(.+)+", # Nested quantifiers on anything + r"(.*)+", # Nested quantifiers on anything (greedy) + r"(\w*)+", # Nested quantifiers with * + r"(\S+)+", # Nested quantifiers on non-whitespace ] - + # Check for excessively long patterns if len(pattern) > 500: return True - + # Check for nested quantifiers (simplified detection) - if '+)+' in pattern or '*)+' in pattern or '?)+' in pattern: + if "+)+" in pattern or "*)+" in pattern or "?)+" in pattern: return True - + # Check for excessive repetition operators - if pattern.count('+') > 10 or pattern.count('*') > 10: + if pattern.count("+") > 10 or pattern.count("*") > 10: return True - + # Check for specific dangerous patterns for dangerous in dangerous_indicators: if dangerous in pattern: return True - + return False @@ -240,34 +243,43 @@ def normalize_domain(value: str) -> str: """Normalize a domain or URL for allowlist checks with security validation.""" if not value or not isinstance(value, str): return "" - + + if any(char in value for char in ["\x00", "\n", "\r", "\t"]): + return "" + text = value.strip().lower() if not text or len(text) > 2000: # Prevent excessively long URLs return "" - - # Sanitize input to prevent injection attacks - if any(char in text for char in ['\x00', '\n', '\r', '\t']): - return "" - + try: if "://" not in text: text = f"http://{text}" - + parsed = urlparse(text) hostname = parsed.hostname or "" - + # Additional validation for hostname if not hostname or len(hostname) > 253: # RFC limit return "" - + # Check for malicious patterns - if any(char in hostname for char in [' ', '\x00', '\n', '\r', '\t']): + if any(char in hostname for char in [" ", "\x00", "\n", "\r", "\t"]): return "" - + + if not re.fullmatch(r"[a-z0-9.-]+", hostname): + return "" + if hostname.startswith(".") or hostname.endswith(".") or ".." in hostname: + return "" + for label in hostname.split("."): + if not label: + return "" + if label.startswith("-") or label.endswith("-"): + return "" + # Remove www prefix if hostname.startswith("www."): hostname = hostname[4:] - + return hostname except (ValueError, UnicodeError, Exception): # urlparse can raise various exceptions with malicious input @@ -305,13 +317,13 @@ class AutomodService: # Normalize: lowercase, remove extra spaces, remove special chars # Use simple string operations for basic patterns to avoid regex overhead normalized = content.lower() - + # Remove special characters (simplified approach) - normalized = ''.join(c for c in normalized if c.isalnum() or c.isspace()) - + normalized = "".join(c for c in normalized if c.isalnum() or c.isspace()) + # Normalize whitespace - normalized = ' '.join(normalized.split()) - + normalized = " ".join(normalized.split()) + return normalized def check_banned_words( @@ -369,14 +381,14 @@ class AutomodService: # Limit URL length to prevent processing extremely long URLs if len(url) > 2000: continue - + url_lower = url.lower() hostname = normalize_domain(url) - + # Skip if hostname normalization failed (security check) if not hostname: continue - + if allowlist_set and is_allowed_domain(hostname, allowlist_set): continue @@ -540,3 +552,11 @@ class AutomodService: def cleanup_guild(self, guild_id: int) -> None: """Remove all tracking data for a guild.""" self._spam_trackers.pop(guild_id, None) + + +_automod_service = AutomodService() + + +def detect_scam_links(content: str, allowlist: list[str] | None = None) -> AutomodResult | None: + """Convenience wrapper for scam detection.""" + return _automod_service.check_scam_links(content, allowlist) diff --git a/src/guardden/services/guild_config.py b/src/guardden/services/guild_config.py index 9e911af..92f6249 100644 --- a/src/guardden/services/guild_config.py +++ b/src/guardden/services/guild_config.py @@ -141,6 +141,9 @@ class GuildConfigService: is_regex: bool = False, action: str = "delete", reason: str | None = None, + source: str | None = None, + category: str | None = None, + managed: bool = False, ) -> BannedWord: """Add a banned word to a guild.""" async with self.database.session() as session: @@ -150,6 +153,9 @@ class GuildConfigService: is_regex=is_regex, action=action, reason=reason, + source=source, + category=category, + managed=managed, added_by=added_by, ) session.add(banned_word) diff --git a/src/guardden/services/wordlist.py b/src/guardden/services/wordlist.py new file mode 100644 index 0000000..af6dca2 --- /dev/null +++ b/src/guardden/services/wordlist.py @@ -0,0 +1,180 @@ +"""Managed wordlist sync service.""" + +from __future__ import annotations + +import logging +from dataclasses import dataclass +from datetime import datetime, timedelta, timezone +from typing import Iterable + +import httpx +from sqlalchemy import delete, select + +from guardden.config import Settings, WordlistSourceConfig +from guardden.models import BannedWord, Guild +from guardden.services.database import Database + +logger = logging.getLogger(__name__) + +MAX_WORDLIST_ENTRY_LENGTH = 128 +REQUEST_TIMEOUT = 20.0 + + +@dataclass(frozen=True) +class WordlistSource: + name: str + url: str + category: str + action: str + reason: str + is_regex: bool = False + + +DEFAULT_SOURCES: list[WordlistSource] = [ + WordlistSource( + name="ldnoobw_en", + url="https://raw.githubusercontent.com/LDNOOBW/List-of-Dirty-Naughty-Obscene-and-Otherwise-Bad-Words/master/en", + category="soft", + action="warn", + reason="Auto list: profanity", + is_regex=False, + ), +] + + +def _normalize_entry(line: str) -> str: + text = line.strip().lower() + if not text: + return "" + if len(text) > MAX_WORDLIST_ENTRY_LENGTH: + return "" + return text + + +def _parse_wordlist(text: str) -> list[str]: + entries: list[str] = [] + seen: set[str] = set() + for raw in text.splitlines(): + line = raw.strip() + if not line: + continue + if line.startswith("#") or line.startswith("//") or line.startswith(";"): + continue + normalized = _normalize_entry(line) + if not normalized or normalized in seen: + continue + entries.append(normalized) + seen.add(normalized) + return entries + + +class WordlistService: + """Fetches and syncs managed wordlists into per-guild bans.""" + + def __init__(self, database: Database, settings: Settings) -> None: + self.database = database + self.settings = settings + self.sources = self._load_sources(settings) + self.update_interval = timedelta(hours=settings.wordlist_update_hours) + self.last_sync: datetime | None = None + + @staticmethod + def _load_sources(settings: Settings) -> list[WordlistSource]: + if settings.wordlist_sources: + sources: list[WordlistSource] = [] + for src in settings.wordlist_sources: + if not src.enabled: + continue + sources.append( + WordlistSource( + name=src.name, + url=src.url, + category=src.category, + action=src.action, + reason=src.reason, + is_regex=src.is_regex, + ) + ) + return sources + return list(DEFAULT_SOURCES) + + async def _fetch_source(self, source: WordlistSource) -> list[str]: + async with httpx.AsyncClient(timeout=REQUEST_TIMEOUT) as client: + response = await client.get(source.url) + response.raise_for_status() + return _parse_wordlist(response.text) + + async def sync_all(self) -> None: + if not self.settings.wordlist_enabled: + logger.info("Managed wordlist sync disabled") + return + if not self.sources: + logger.warning("No wordlist sources configured") + return + + logger.info("Starting managed wordlist sync (%d sources)", len(self.sources)) + async with self.database.session() as session: + guild_ids = list((await session.execute(select(Guild.id))).scalars().all()) + + for source in self.sources: + try: + entries = await self._fetch_source(source) + except Exception as exc: + logger.error("Failed to fetch wordlist %s: %s", source.name, exc) + continue + + if not entries: + logger.warning("Wordlist %s returned no entries", source.name) + continue + + await self._sync_source_to_guilds(source, entries, guild_ids) + + self.last_sync = datetime.now(timezone.utc) + logger.info("Managed wordlist sync completed") + + async def _sync_source_to_guilds( + self, source: WordlistSource, entries: Iterable[str], guild_ids: list[int] + ) -> None: + entry_set = set(entries) + async with self.database.session() as session: + for guild_id in guild_ids: + result = await session.execute( + select(BannedWord).where( + BannedWord.guild_id == guild_id, + BannedWord.managed.is_(True), + BannedWord.source == source.name, + ) + ) + existing = list(result.scalars().all()) + existing_set = {word.pattern.lower() for word in existing} + + to_add = entry_set - existing_set + to_remove = existing_set - entry_set + + if to_remove: + await session.execute( + delete(BannedWord).where( + BannedWord.guild_id == guild_id, + BannedWord.managed.is_(True), + BannedWord.source == source.name, + BannedWord.pattern.in_(to_remove), + ) + ) + + if to_add: + session.add_all( + [ + BannedWord( + guild_id=guild_id, + pattern=pattern, + is_regex=source.is_regex, + action=source.action, + reason=source.reason, + source=source.name, + category=source.category, + managed=True, + added_by=0, + ) + for pattern in to_add + ] + ) diff --git a/tests/conftest.py b/tests/conftest.py index 86c36e6..2c833ab 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -7,11 +7,11 @@ import sys import tempfile from datetime import datetime, timezone from pathlib import Path -from unittest.mock import AsyncMock, MagicMock from typing import AsyncGenerator +from unittest.mock import AsyncMock, MagicMock import pytest -from sqlalchemy import create_engine +from sqlalchemy import create_engine, event, text from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker, create_async_engine from sqlalchemy.pool import StaticPool @@ -23,7 +23,7 @@ if str(SRC_DIR) not in sys.path: # Import after path setup from guardden.config import Settings from guardden.models.base import Base -from guardden.models.guild import Guild, GuildSettings, BannedWord +from guardden.models.guild import BannedWord, Guild, GuildSettings from guardden.models.moderation import ModerationLog, Strike, UserNote from guardden.services.database import Database @@ -52,6 +52,7 @@ def pytest_pyfunc_call(pyfuncitem: pytest.Function) -> bool | None: # Basic Test Fixtures # ============================================================================== + @pytest.fixture def sample_guild_id() -> int: """Return a sample Discord guild ID.""" @@ -80,11 +81,12 @@ def sample_owner_id() -> int: # Configuration Fixtures # ============================================================================== + @pytest.fixture def test_settings() -> Settings: """Return test configuration settings.""" return Settings( - discord_token="test_token_12345678901234567890", + discord_token="a" * 60, discord_prefix="!test", database_url="sqlite+aiosqlite:///test.db", database_pool_min=1, @@ -101,6 +103,7 @@ def test_settings() -> Settings: # Database Fixtures # ============================================================================== + @pytest.fixture async def test_database(test_settings: Settings) -> AsyncGenerator[Database, None]: """Create a test database with in-memory SQLite.""" @@ -111,19 +114,26 @@ async def test_database(test_settings: Settings) -> AsyncGenerator[Database, Non poolclass=StaticPool, echo=False, ) - + + @event.listens_for(engine.sync_engine, "connect") + def _enable_sqlite_foreign_keys(dbapi_connection, connection_record): + cursor = dbapi_connection.cursor() + cursor.execute("PRAGMA foreign_keys=ON") + cursor.close() + # Create all tables async with engine.begin() as conn: + await conn.execute(text("PRAGMA foreign_keys=ON")) await conn.run_sync(Base.metadata.create_all) - + database = Database(test_settings) database._engine = engine database._session_factory = async_sessionmaker( engine, class_=AsyncSession, expire_on_commit=False ) - + yield database - + await engine.dispose() @@ -138,10 +148,9 @@ async def db_session(test_database: Database) -> AsyncGenerator[AsyncSession, No # Model Fixtures # ============================================================================== + @pytest.fixture -async def test_guild( - db_session: AsyncSession, sample_guild_id: int, sample_owner_id: int -) -> Guild: +async def test_guild(db_session: AsyncSession, sample_guild_id: int, sample_owner_id: int) -> Guild: """Create a test guild with settings.""" guild = Guild( id=sample_guild_id, @@ -150,7 +159,7 @@ async def test_guild( premium=False, ) db_session.add(guild) - + # Create associated settings settings = GuildSettings( guild_id=sample_guild_id, @@ -160,7 +169,7 @@ async def test_guild( verification_enabled=False, ) db_session.add(settings) - + await db_session.commit() await db_session.refresh(guild) return guild @@ -187,10 +196,7 @@ async def test_banned_word( @pytest.fixture async def test_moderation_log( - db_session: AsyncSession, - test_guild: Guild, - sample_user_id: int, - sample_moderator_id: int + db_session: AsyncSession, test_guild: Guild, sample_user_id: int, sample_moderator_id: int ) -> ModerationLog: """Create a test moderation log entry.""" mod_log = ModerationLog( @@ -211,10 +217,7 @@ async def test_moderation_log( @pytest.fixture async def test_strike( - db_session: AsyncSession, - test_guild: Guild, - sample_user_id: int, - sample_moderator_id: int + db_session: AsyncSession, test_guild: Guild, sample_user_id: int, sample_moderator_id: int ) -> Strike: """Create a test strike.""" strike = Strike( @@ -236,6 +239,7 @@ async def test_strike( # Discord Mock Fixtures # ============================================================================== + @pytest.fixture def mock_discord_user(sample_user_id: int) -> MagicMock: """Create a mock Discord user.""" @@ -261,7 +265,7 @@ def mock_discord_member(mock_discord_user: MagicMock) -> MagicMock: member.avatar = mock_discord_user.avatar member.bot = mock_discord_user.bot member.send = mock_discord_user.send - + # Member-specific attributes member.guild = MagicMock() member.top_role = MagicMock() @@ -271,7 +275,7 @@ def mock_discord_member(mock_discord_user: MagicMock) -> MagicMock: member.kick = AsyncMock() member.ban = AsyncMock() member.timeout = AsyncMock() - + return member @@ -284,14 +288,14 @@ def mock_discord_guild(sample_guild_id: int, sample_owner_id: int) -> MagicMock: guild.owner_id = sample_owner_id guild.member_count = 100 guild.premium_tier = 0 - + # Methods guild.get_member = MagicMock(return_value=None) guild.get_channel = MagicMock(return_value=None) guild.leave = AsyncMock() guild.ban = AsyncMock() guild.unban = AsyncMock() - + return guild @@ -327,9 +331,7 @@ def mock_discord_message( @pytest.fixture def mock_discord_context( - mock_discord_member: MagicMock, - mock_discord_guild: MagicMock, - mock_discord_channel: MagicMock + mock_discord_member: MagicMock, mock_discord_guild: MagicMock, mock_discord_channel: MagicMock ) -> MagicMock: """Create a mock Discord command context.""" ctx = MagicMock() @@ -345,6 +347,7 @@ def mock_discord_context( # Bot and Service Fixtures # ============================================================================== + @pytest.fixture def mock_bot(test_database: Database) -> MagicMock: """Create a mock GuardDen bot.""" @@ -363,6 +366,7 @@ def mock_bot(test_database: Database) -> MagicMock: # Test Environment Setup # ============================================================================== + @pytest.fixture(autouse=True) def setup_test_environment() -> None: """Set up test environment variables.""" diff --git a/tests/test_config.py b/tests/test_config.py index 348e8b8..6c0850b 100644 --- a/tests/test_config.py +++ b/tests/test_config.py @@ -3,7 +3,8 @@ import pytest from pydantic import ValidationError -from guardden.config import Settings, _parse_id_list, _validate_discord_id, normalize_domain +from guardden.config import Settings, _parse_id_list, _validate_discord_id +from guardden.services.automod import normalize_domain class TestDiscordIdValidation: @@ -17,7 +18,7 @@ class TestDiscordIdValidation: "1234567890123456789", # 19 digits 123456789012345678, # int format ] - + for valid_id in valid_ids: result = _validate_discord_id(valid_id) assert isinstance(result, int) @@ -35,7 +36,7 @@ class TestDiscordIdValidation: "0", # zero "-123456789012345678", # negative ] - + for invalid_id in invalid_ids: with pytest.raises(ValueError): _validate_discord_id(invalid_id) @@ -45,7 +46,7 @@ class TestDiscordIdValidation: # Too small (before Discord existed) with pytest.raises(ValueError): _validate_discord_id("99999999999999999") - + # Too large (exceeds 64-bit limit) with pytest.raises(ValueError): _validate_discord_id("99999999999999999999") @@ -64,7 +65,7 @@ class TestIdListParsing: ("", []), (None, []), ] - + for input_value, expected in test_cases: result = _parse_id_list(input_value) assert result == expected @@ -89,7 +90,7 @@ class TestIdListParsing: "123456789012345678\n234567890123456789", # newline "123456789012345678\r234567890123456789", # carriage return ] - + for malicious_input in malicious_inputs: result = _parse_id_list(malicious_input) # Should filter out malicious entries @@ -106,7 +107,7 @@ class TestSettingsValidation: "Bot.MTIzNDU2Nzg5MDEyMzQ1Njc4.some_long_token_string_here", "a" * 60, # minimum reasonable length ] - + for token in valid_tokens: settings = Settings(discord_token=token) assert settings.discord_token.get_secret_value() == token @@ -119,7 +120,7 @@ class TestSettingsValidation: "token with spaces", # contains spaces "token\nwith\nnewlines", # contains newlines ] - + for token in invalid_tokens: with pytest.raises(ValidationError): Settings(discord_token=token) @@ -131,7 +132,7 @@ class TestSettingsValidation: settings = Settings( discord_token="valid_token_" + "a" * 50, ai_provider="anthropic", - anthropic_api_key=valid_key + anthropic_api_key=valid_key, ) assert settings.anthropic_api_key.get_secret_value() == valid_key @@ -140,23 +141,23 @@ class TestSettingsValidation: Settings( discord_token="valid_token_" + "a" * 50, ai_provider="anthropic", - anthropic_api_key="short" + anthropic_api_key="short", ) def test_configuration_validation_ai_provider(self): """Test AI provider configuration validation.""" settings = Settings(discord_token="valid_token_" + "a" * 50) - + # Should pass with no AI provider settings.ai_provider = "none" settings.validate_configuration() - + # Should fail with anthropic but no key settings.ai_provider = "anthropic" settings.anthropic_api_key = None with pytest.raises(ValueError, match="GUARDDEN_ANTHROPIC_API_KEY is required"): settings.validate_configuration() - + # Should pass with anthropic and key settings.anthropic_api_key = "sk-" + "a" * 50 settings.validate_configuration() @@ -164,13 +165,13 @@ class TestSettingsValidation: def test_configuration_validation_database_pool(self): """Test database pool configuration validation.""" settings = Settings(discord_token="valid_token_" + "a" * 50) - + # Should fail with min > max settings.database_pool_min = 10 settings.database_pool_max = 5 with pytest.raises(ValueError, match="database_pool_min cannot be greater"): settings.validate_configuration() - + # Should fail with min < 1 settings.database_pool_min = 0 settings.database_pool_max = 5 @@ -190,7 +191,7 @@ class TestSecurityImprovements: "123456789012345678\x00\x01\x02", "123456789012345678", ] - + for attempt in injection_attempts: # Should either raise an error or filter out the malicious input try: @@ -205,33 +206,41 @@ class TestSecurityImprovements: def test_settings_with_malicious_env_vars(self): """Test that settings handle malicious environment variables.""" import os - + # Save original values original_guilds = os.environ.get("GUARDDEN_ALLOWED_GUILDS") original_owners = os.environ.get("GUARDDEN_OWNER_IDS") - + try: # Set malicious environment variables - os.environ["GUARDDEN_ALLOWED_GUILDS"] = "123456789012345678\x00,malicious" - os.environ["GUARDDEN_OWNER_IDS"] = "234567890123456789\n567890123456789012" - + try: + os.environ["GUARDDEN_ALLOWED_GUILDS"] = "123456789012345678\x00,malicious" + except ValueError: + os.environ["GUARDDEN_ALLOWED_GUILDS"] = "123456789012345678,malicious" + try: + os.environ["GUARDDEN_OWNER_IDS"] = "234567890123456789\n567890123456789012" + except ValueError: + os.environ["GUARDDEN_OWNER_IDS"] = "234567890123456789,567890123456789012" + settings = Settings(discord_token="valid_token_" + "a" * 50) - + # Should filter out malicious entries assert len(settings.allowed_guilds) <= 1 assert len(settings.owner_ids) <= 1 - + # Valid IDs should be preserved - assert 123456789012345678 in settings.allowed_guilds or len(settings.allowed_guilds) == 0 - + assert ( + 123456789012345678 in settings.allowed_guilds or len(settings.allowed_guilds) == 0 + ) + finally: # Restore original values if original_guilds is not None: os.environ["GUARDDEN_ALLOWED_GUILDS"] = original_guilds else: os.environ.pop("GUARDDEN_ALLOWED_GUILDS", None) - + if original_owners is not None: os.environ["GUARDDEN_OWNER_IDS"] = original_owners else: - os.environ.pop("GUARDDEN_OWNER_IDS", None) \ No newline at end of file + os.environ.pop("GUARDDEN_OWNER_IDS", None) diff --git a/tests/test_database_integration.py b/tests/test_database_integration.py index 4f3ba44..848ea12 100644 --- a/tests/test_database_integration.py +++ b/tests/test_database_integration.py @@ -1,10 +1,11 @@ """Tests for database integration and models.""" -import pytest from datetime import datetime, timezone + +import pytest from sqlalchemy import select -from guardden.models.guild import Guild, GuildSettings, BannedWord +from guardden.models.guild import BannedWord, Guild, GuildSettings from guardden.models.moderation import ModerationLog, Strike, UserNote from guardden.services.database import Database @@ -21,7 +22,7 @@ class TestDatabaseModels: premium=False, ) db_session.add(guild) - + settings = GuildSettings( guild_id=sample_guild_id, prefix="!", @@ -29,13 +30,13 @@ class TestDatabaseModels: ai_moderation_enabled=False, ) db_session.add(settings) - + await db_session.commit() - + # Test guild was created result = await db_session.execute(select(Guild).where(Guild.id == sample_guild_id)) created_guild = result.scalar_one() - + assert created_guild.id == sample_guild_id assert created_guild.name == "Test Guild" assert created_guild.owner_id == sample_owner_id @@ -44,11 +45,9 @@ class TestDatabaseModels: async def test_guild_settings_relationship(self, test_guild, db_session): """Test guild-settings relationship.""" # Load guild with settings - result = await db_session.execute( - select(Guild).where(Guild.id == test_guild.id) - ) + result = await db_session.execute(select(Guild).where(Guild.id == test_guild.id)) guild_with_settings = result.scalar_one() - + # Test relationship loading await db_session.refresh(guild_with_settings, ["settings"]) assert guild_with_settings.settings is not None @@ -67,24 +66,20 @@ class TestDatabaseModels: ) db_session.add(banned_word) await db_session.commit() - + # Verify creation result = await db_session.execute( select(BannedWord).where(BannedWord.guild_id == test_guild.id) ) created_word = result.scalar_one() - + assert created_word.pattern == "testbadword" assert not created_word.is_regex assert created_word.action == "delete" assert created_word.added_by == sample_moderator_id async def test_moderation_log_creation( - self, - test_guild, - db_session, - sample_user_id, - sample_moderator_id + self, test_guild, db_session, sample_user_id, sample_moderator_id ): """Test moderation log creation.""" mod_log = ModerationLog( @@ -99,24 +94,20 @@ class TestDatabaseModels: ) db_session.add(mod_log) await db_session.commit() - + # Verify creation result = await db_session.execute( select(ModerationLog).where(ModerationLog.guild_id == test_guild.id) ) created_log = result.scalar_one() - + assert created_log.action == "ban" assert created_log.target_id == sample_user_id assert created_log.moderator_id == sample_moderator_id assert not created_log.is_automatic async def test_strike_creation( - self, - test_guild, - db_session, - sample_user_id, - sample_moderator_id + self, test_guild, db_session, sample_user_id, sample_moderator_id ): """Test strike creation and tracking.""" strike = Strike( @@ -130,26 +121,19 @@ class TestDatabaseModels: ) db_session.add(strike) await db_session.commit() - + # Verify creation result = await db_session.execute( - select(Strike).where( - Strike.guild_id == test_guild.id, - Strike.user_id == sample_user_id - ) + select(Strike).where(Strike.guild_id == test_guild.id, Strike.user_id == sample_user_id) ) created_strike = result.scalar_one() - + assert created_strike.points == 1 assert created_strike.is_active assert created_strike.user_id == sample_user_id async def test_cascade_deletion( - self, - test_guild, - db_session, - sample_user_id, - sample_moderator_id + self, test_guild, db_session, sample_user_id, sample_moderator_id ): """Test that deleting a guild cascades to related records.""" # Add some related records @@ -160,7 +144,7 @@ class TestDatabaseModels: action="delete", added_by=sample_moderator_id, ) - + mod_log = ModerationLog( guild_id=test_guild.id, target_id=sample_user_id, @@ -171,7 +155,7 @@ class TestDatabaseModels: reason="Test warning", is_automatic=False, ) - + strike = Strike( guild_id=test_guild.id, user_id=sample_user_id, @@ -181,28 +165,26 @@ class TestDatabaseModels: points=1, is_active=True, ) - + db_session.add_all([banned_word, mod_log, strike]) await db_session.commit() - + # Delete the guild await db_session.delete(test_guild) await db_session.commit() - + # Verify related records were deleted banned_words = await db_session.execute( select(BannedWord).where(BannedWord.guild_id == test_guild.id) ) assert len(banned_words.scalars().all()) == 0 - + mod_logs = await db_session.execute( select(ModerationLog).where(ModerationLog.guild_id == test_guild.id) ) assert len(mod_logs.scalars().all()) == 0 - - strikes = await db_session.execute( - select(Strike).where(Strike.guild_id == test_guild.id) - ) + + strikes = await db_session.execute(select(Strike).where(Strike.guild_id == test_guild.id)) assert len(strikes.scalars().all()) == 0 @@ -210,11 +192,7 @@ class TestDatabaseIndexes: """Test that database indexes work as expected.""" async def test_moderation_log_indexes( - self, - test_guild, - db_session, - sample_user_id, - sample_moderator_id + self, test_guild, db_session, sample_user_id, sample_moderator_id ): """Test moderation log indexing for performance.""" # Create multiple moderation logs @@ -231,23 +209,23 @@ class TestDatabaseIndexes: is_automatic=bool(i % 2), ) logs.append(log) - + db_session.add_all(logs) await db_session.commit() - + # Test queries that should use indexes # Query by guild_id guild_logs = await db_session.execute( select(ModerationLog).where(ModerationLog.guild_id == test_guild.id) ) assert len(guild_logs.scalars().all()) == 10 - + # Query by target_id target_logs = await db_session.execute( select(ModerationLog).where(ModerationLog.target_id == sample_user_id) ) assert len(target_logs.scalars().all()) == 1 - + # Query by is_automatic auto_logs = await db_session.execute( select(ModerationLog).where(ModerationLog.is_automatic == True) @@ -255,11 +233,7 @@ class TestDatabaseIndexes: assert len(auto_logs.scalars().all()) == 5 async def test_strike_indexes( - self, - test_guild, - db_session, - sample_user_id, - sample_moderator_id + self, test_guild, db_session, sample_user_id, sample_moderator_id ): """Test strike indexing for performance.""" # Create multiple strikes @@ -275,18 +249,15 @@ class TestDatabaseIndexes: is_active=bool(i % 2), ) strikes.append(strike) - + db_session.add_all(strikes) await db_session.commit() - + # Test active strikes query active_strikes = await db_session.execute( - select(Strike).where( - Strike.guild_id == test_guild.id, - Strike.is_active == True - ) + select(Strike).where(Strike.guild_id == test_guild.id, Strike.is_active == True) ) - assert len(active_strikes.scalars().all()) == 3 # indices 1, 3 + assert len(active_strikes.scalars().all()) == 2 # indices 1, 3 class TestDatabaseSecurity: @@ -304,11 +275,9 @@ class TestDatabaseSecurity: ) db_session.add(guild) await db_session.commit() - + # Verify it was stored correctly - result = await db_session.execute( - select(Guild).where(Guild.id == valid_guild_id) - ) + result = await db_session.execute(select(Guild).where(Guild.id == valid_guild_id)) stored_guild = result.scalar_one() assert stored_guild.id == valid_guild_id @@ -321,13 +290,11 @@ class TestDatabaseSecurity: "' OR '1'='1", "", ] - + for malicious_input in malicious_inputs: # Try to use malicious input in a query # SQLAlchemy should prevent injection through parameterized queries - result = await db_session.execute( - select(Guild).where(Guild.name == malicious_input) - ) + result = await db_session.execute(select(Guild).where(Guild.name == malicious_input)) # Should not find anything (and not crash) assert result.scalar_one_or_none() is None @@ -343,4 +310,5 @@ class TestDatabaseSecurity: added_by=123456789012345678, ) db_session.add(banned_word) - await db_session.commit() \ No newline at end of file + await db_session.commit() + await db_session.rollback()