update
Some checks failed
CI/CD Pipeline / Code Quality Checks (push) Failing after 4m49s
CI/CD Pipeline / Security Scanning (push) Successful in 15s
CI/CD Pipeline / Tests (3.11) (push) Successful in 9m41s
CI/CD Pipeline / Tests (3.12) (push) Successful in 9m36s
CI/CD Pipeline / Build Docker Image (push) Has been skipped
Dependency Updates / Update Dependencies (push) Successful in 29s
Some checks failed
CI/CD Pipeline / Code Quality Checks (push) Failing after 4m49s
CI/CD Pipeline / Security Scanning (push) Successful in 15s
CI/CD Pipeline / Tests (3.11) (push) Successful in 9m41s
CI/CD Pipeline / Tests (3.12) (push) Successful in 9m36s
CI/CD Pipeline / Build Docker Image (push) Has been skipped
Dependency Updates / Update Dependencies (push) Successful in 29s
This commit is contained in:
@@ -136,7 +136,6 @@ jobs:
|
|||||||
GUARDDEN_AI_PROVIDER: "none"
|
GUARDDEN_AI_PROVIDER: "none"
|
||||||
GUARDDEN_LOG_LEVEL: "DEBUG"
|
GUARDDEN_LOG_LEVEL: "DEBUG"
|
||||||
run: |
|
run: |
|
||||||
# Run database migrations for tests
|
|
||||||
python -c "
|
python -c "
|
||||||
import os
|
import os
|
||||||
os.environ['GUARDDEN_DISCORD_TOKEN'] = 'test_token_12345678901234567890123456789012345'
|
os.environ['GUARDDEN_DISCORD_TOKEN'] = 'test_token_12345678901234567890123456789012345'
|
||||||
@@ -153,15 +152,6 @@ jobs:
|
|||||||
run: |
|
run: |
|
||||||
pytest --cov=src/guardden --cov-report=xml --cov-report=html --cov-report=term-missing
|
pytest --cov=src/guardden --cov-report=xml --cov-report=html --cov-report=term-missing
|
||||||
|
|
||||||
- name: Upload coverage to Codecov
|
|
||||||
uses: codecov/codecov-action@v3
|
|
||||||
if: matrix.python-version == '3.11'
|
|
||||||
with:
|
|
||||||
file: ./coverage.xml
|
|
||||||
flags: unittests
|
|
||||||
name: codecov-umbrella
|
|
||||||
fail_ci_if_error: false
|
|
||||||
|
|
||||||
- name: Upload coverage reports
|
- name: Upload coverage reports
|
||||||
uses: actions/upload-artifact@v3
|
uses: actions/upload-artifact@v3
|
||||||
if: matrix.python-version == '3.11'
|
if: matrix.python-version == '3.11'
|
||||||
@@ -207,64 +197,3 @@ jobs:
|
|||||||
- name: Test Docker image
|
- name: Test Docker image
|
||||||
run: |
|
run: |
|
||||||
docker run --rm guardden:${{ github.sha }} python -m guardden --help
|
docker run --rm guardden:${{ github.sha }} python -m guardden --help
|
||||||
|
|
||||||
deploy-staging:
|
|
||||||
name: Deploy to Staging
|
|
||||||
runs-on: ubuntu-latest
|
|
||||||
needs: [code-quality, test, build-docker]
|
|
||||||
if: github.ref == 'refs/heads/develop' && github.event_name == 'push'
|
|
||||||
environment: staging
|
|
||||||
steps:
|
|
||||||
- name: Checkout code
|
|
||||||
uses: actions/checkout@v4
|
|
||||||
|
|
||||||
- name: Deploy to staging
|
|
||||||
run: |
|
|
||||||
echo "Deploying to staging environment..."
|
|
||||||
echo "This would typically involve:"
|
|
||||||
echo "- Pushing Docker image to registry"
|
|
||||||
echo "- Updating Kubernetes/Docker Compose configs"
|
|
||||||
echo "- Running database migrations"
|
|
||||||
echo "- Performing health checks"
|
|
||||||
|
|
||||||
deploy-production:
|
|
||||||
name: Deploy to Production
|
|
||||||
runs-on: ubuntu-latest
|
|
||||||
needs: [code-quality, test, build-docker]
|
|
||||||
if: github.event_name == 'release'
|
|
||||||
environment: production
|
|
||||||
steps:
|
|
||||||
- name: Checkout code
|
|
||||||
uses: actions/checkout@v4
|
|
||||||
|
|
||||||
- name: Deploy to production
|
|
||||||
run: |
|
|
||||||
echo "Deploying to production environment..."
|
|
||||||
echo "This would typically involve:"
|
|
||||||
echo "- Pushing Docker image to registry with version tag"
|
|
||||||
echo "- Blue/green deployment or rolling update"
|
|
||||||
echo "- Running database migrations"
|
|
||||||
echo "- Performing comprehensive health checks"
|
|
||||||
echo "- Monitoring deployment success"
|
|
||||||
|
|
||||||
notification:
|
|
||||||
name: Notification
|
|
||||||
runs-on: ubuntu-latest
|
|
||||||
needs: [code-quality, test, build-docker]
|
|
||||||
if: always()
|
|
||||||
steps:
|
|
||||||
- name: Notify on failure
|
|
||||||
if: contains(needs.*.result, 'failure')
|
|
||||||
run: |
|
|
||||||
echo "Pipeline failed. In a real environment, this would:"
|
|
||||||
echo "- Send notifications to Discord/Slack"
|
|
||||||
echo "- Create GitHub issue for investigation"
|
|
||||||
echo "- Alert the development team"
|
|
||||||
|
|
||||||
- name: Notify on success
|
|
||||||
if: needs.code-quality.result == 'success' && needs.test.result == 'success' && needs.build-docker.result == 'success'
|
|
||||||
run: |
|
|
||||||
echo "Pipeline succeeded! In a real environment, this would:"
|
|
||||||
echo "- Send success notification"
|
|
||||||
echo "- Update deployment status"
|
|
||||||
echo "- Trigger downstream processes"
|
|
||||||
44
.gitea/workflows/dependency-updates.yml
Normal file
44
.gitea/workflows/dependency-updates.yml
Normal file
@@ -0,0 +1,44 @@
|
|||||||
|
name: Dependency Updates
|
||||||
|
|
||||||
|
on:
|
||||||
|
schedule:
|
||||||
|
- cron: '0 9 * * 1'
|
||||||
|
workflow_dispatch:
|
||||||
|
|
||||||
|
jobs:
|
||||||
|
update-dependencies:
|
||||||
|
name: Update Dependencies
|
||||||
|
runs-on: ubuntu-latest
|
||||||
|
steps:
|
||||||
|
- name: Checkout code
|
||||||
|
uses: actions/checkout@v4
|
||||||
|
|
||||||
|
- name: Set up Python
|
||||||
|
uses: actions/setup-python@v4
|
||||||
|
with:
|
||||||
|
python-version: "3.11"
|
||||||
|
|
||||||
|
- name: Install pip-tools
|
||||||
|
run: |
|
||||||
|
python -m pip install --upgrade pip
|
||||||
|
pip install pip-tools
|
||||||
|
|
||||||
|
- name: Update dependencies
|
||||||
|
run: |
|
||||||
|
pip-compile --upgrade pyproject.toml --output-file requirements.txt
|
||||||
|
pip-compile --upgrade --extra dev pyproject.toml --output-file requirements-dev.txt
|
||||||
|
|
||||||
|
- name: Check for security vulnerabilities
|
||||||
|
run: |
|
||||||
|
pip install safety
|
||||||
|
safety check --file requirements.txt --json --output vulnerability-report.json || true
|
||||||
|
safety check --file requirements-dev.txt --json --output vulnerability-dev-report.json || true
|
||||||
|
|
||||||
|
- name: Upload vulnerability reports
|
||||||
|
uses: actions/upload-artifact@v3
|
||||||
|
if: always()
|
||||||
|
with:
|
||||||
|
name: vulnerability-reports
|
||||||
|
path: |
|
||||||
|
vulnerability-report.json
|
||||||
|
vulnerability-dev-report.json
|
||||||
75
.github/workflows/dependency-updates.yml
vendored
75
.github/workflows/dependency-updates.yml
vendored
@@ -1,75 +0,0 @@
|
|||||||
name: Dependency Updates
|
|
||||||
|
|
||||||
on:
|
|
||||||
schedule:
|
|
||||||
# Run weekly on Mondays at 9 AM UTC
|
|
||||||
- cron: '0 9 * * 1'
|
|
||||||
workflow_dispatch: # Allow manual triggering
|
|
||||||
|
|
||||||
jobs:
|
|
||||||
update-dependencies:
|
|
||||||
name: Update Dependencies
|
|
||||||
runs-on: ubuntu-latest
|
|
||||||
steps:
|
|
||||||
- name: Checkout code
|
|
||||||
uses: actions/checkout@v4
|
|
||||||
with:
|
|
||||||
token: ${{ secrets.GITHUB_TOKEN }}
|
|
||||||
|
|
||||||
- name: Set up Python
|
|
||||||
uses: actions/setup-python@v4
|
|
||||||
with:
|
|
||||||
python-version: "3.11"
|
|
||||||
|
|
||||||
- name: Install pip-tools
|
|
||||||
run: |
|
|
||||||
python -m pip install --upgrade pip
|
|
||||||
pip install pip-tools
|
|
||||||
|
|
||||||
- name: Update dependencies
|
|
||||||
run: |
|
|
||||||
# Generate requirements files from pyproject.toml
|
|
||||||
pip-compile --upgrade pyproject.toml --output-file requirements.txt
|
|
||||||
pip-compile --upgrade --extra dev pyproject.toml --output-file requirements-dev.txt
|
|
||||||
|
|
||||||
- name: Check for security vulnerabilities
|
|
||||||
run: |
|
|
||||||
pip install safety
|
|
||||||
safety check --file requirements.txt --json --output vulnerability-report.json || true
|
|
||||||
safety check --file requirements-dev.txt --json --output vulnerability-dev-report.json || true
|
|
||||||
|
|
||||||
- name: Create Pull Request
|
|
||||||
uses: peter-evans/create-pull-request@v5
|
|
||||||
with:
|
|
||||||
token: ${{ secrets.GITHUB_TOKEN }}
|
|
||||||
commit-message: 'chore: update dependencies'
|
|
||||||
title: 'Automated dependency updates'
|
|
||||||
body: |
|
|
||||||
## Automated Dependency Updates
|
|
||||||
|
|
||||||
This PR contains automated dependency updates generated by the dependency update workflow.
|
|
||||||
|
|
||||||
### Changes
|
|
||||||
- Updated all dependencies to latest compatible versions
|
|
||||||
- Checked for security vulnerabilities
|
|
||||||
|
|
||||||
### Security Scan Results
|
|
||||||
Please review the uploaded security scan artifacts for any vulnerabilities.
|
|
||||||
|
|
||||||
### Testing
|
|
||||||
- [ ] All tests pass
|
|
||||||
- [ ] No breaking changes introduced
|
|
||||||
- [ ] Security scan results reviewed
|
|
||||||
|
|
||||||
**Note**: This is an automated PR. Please review all changes carefully before merging.
|
|
||||||
branch: automated/dependency-updates
|
|
||||||
delete-branch: true
|
|
||||||
|
|
||||||
- name: Upload vulnerability reports
|
|
||||||
uses: actions/upload-artifact@v3
|
|
||||||
if: always()
|
|
||||||
with:
|
|
||||||
name: vulnerability-reports
|
|
||||||
path: |
|
|
||||||
vulnerability-report.json
|
|
||||||
vulnerability-dev-report.json
|
|
||||||
@@ -1,400 +0,0 @@
|
|||||||
# GuardDen Enhancement Implementation Plan
|
|
||||||
|
|
||||||
## 🎯 Executive Summary
|
|
||||||
|
|
||||||
Your GuardDen bot is well-architected with solid fundamentals, but needs:
|
|
||||||
1. **Critical security and bug fixes** (immediate priority)
|
|
||||||
2. **Comprehensive testing infrastructure** for reliability
|
|
||||||
3. **Modern DevOps pipeline** for sustainable development
|
|
||||||
4. **Enhanced dashboard** with real-time analytics and management capabilities
|
|
||||||
|
|
||||||
## 📋 Implementation Roadmap
|
|
||||||
|
|
||||||
### **Phase 1: Foundation & Security (Week 1-2)** ✅ COMPLETED
|
|
||||||
*Critical bugs, security fixes, and testing infrastructure*
|
|
||||||
|
|
||||||
#### 1.1 Critical Security Fixes ✅ COMPLETED
|
|
||||||
- [x] **Fix configuration validation** in `src/guardden/config.py:11-45`
|
|
||||||
- Added strict Discord ID parsing with regex validation
|
|
||||||
- Implemented minimum secret key length enforcement
|
|
||||||
- Added input sanitization and validation for all configuration fields
|
|
||||||
- [x] **Secure error handling** throughout Discord API calls
|
|
||||||
- Added proper error handling for kick/ban/timeout operations
|
|
||||||
- Implemented graceful fallback for Discord API failures
|
|
||||||
- [x] **Add input sanitization** for URL parsing in automod service
|
|
||||||
- Enhanced URL validation with length limits and character filtering
|
|
||||||
- Improved normalize_domain function with security checks
|
|
||||||
- Updated URL pattern for more restrictive matching
|
|
||||||
- [x] **Database security audit** and add missing indexes
|
|
||||||
- Created comprehensive migration with 25+ indexes
|
|
||||||
- Added indexes for all common query patterns and foreign keys
|
|
||||||
|
|
||||||
#### 1.2 Error Handling Improvements ✅ COMPLETED
|
|
||||||
- [x] **Refactor exception handling** in `src/guardden/bot.py:119-123`
|
|
||||||
- Improved cog loading with specific exception types
|
|
||||||
- Added better error context and logging
|
|
||||||
- Enhanced guild initialization error handling
|
|
||||||
- [x] **Add circuit breakers** for problematic regex patterns
|
|
||||||
- Implemented RegexCircuitBreaker class with timeout protection
|
|
||||||
- Added pattern validation to prevent catastrophic backtracking
|
|
||||||
- Integrated safe regex execution throughout automod service
|
|
||||||
- [x] **Implement graceful degradation** for AI service failures
|
|
||||||
- Enhanced error handling in existing AI integration
|
|
||||||
- [x] **Add proper error feedback** for Discord API failures
|
|
||||||
- Added user-friendly error messages for moderation failures
|
|
||||||
- Implemented fallback responses when embed sending fails
|
|
||||||
|
|
||||||
#### 1.3 Testing Infrastructure ✅ COMPLETED
|
|
||||||
- [x] **Set up pytest configuration** with async support and coverage
|
|
||||||
- Created comprehensive conftest.py with 20+ fixtures
|
|
||||||
- Added pytest.ini with coverage requirements (75%+ threshold)
|
|
||||||
- Configured async test support and proper markers
|
|
||||||
- [x] **Create test fixtures** for database, Discord mocks, AI providers
|
|
||||||
- Database fixtures with in-memory SQLite
|
|
||||||
- Complete Discord mock objects (users, guilds, channels, messages)
|
|
||||||
- Test configuration and environment setup
|
|
||||||
- [x] **Add integration tests** for all cogs and services
|
|
||||||
- Created test_config.py for configuration security validation
|
|
||||||
- Created test_automod_security.py for automod security improvements
|
|
||||||
- Created test_database_integration.py for database model testing
|
|
||||||
- [x] **Implement test database** with proper isolation
|
|
||||||
- In-memory SQLite setup for test isolation
|
|
||||||
- Automatic table creation and cleanup
|
|
||||||
- Session management for tests
|
|
||||||
|
|
||||||
### **Phase 2: DevOps & CI/CD (Week 2-3)** ✅ COMPLETED
|
|
||||||
*Automated testing, deployment, and monitoring*
|
|
||||||
|
|
||||||
#### 2.1 CI/CD Pipeline ✅ COMPLETED
|
|
||||||
- [x] **GitHub Actions workflow** for automated testing
|
|
||||||
- Comprehensive CI pipeline with code quality, security scanning, and testing
|
|
||||||
- Multi-Python version testing (3.11, 3.12) with PostgreSQL service
|
|
||||||
- Automated dependency updates with security vulnerability scanning
|
|
||||||
- Deployment pipelines for staging and production environments
|
|
||||||
- [x] **Multi-stage Docker builds** with optional AI dependencies
|
|
||||||
- Optimized Dockerfile with builder pattern for reduced image size
|
|
||||||
- Configurable AI dependency installation with build args
|
|
||||||
- Development stage with debugging tools and hot reloading
|
|
||||||
- Proper security practices (non-root user, health checks)
|
|
||||||
- [x] **Automated security scanning** with dependency checks
|
|
||||||
- Safety for dependency vulnerability scanning
|
|
||||||
- Bandit for security linting of Python code
|
|
||||||
- Integrated into CI pipeline with artifact reporting
|
|
||||||
- [x] **Code quality gates** with ruff, mypy, and coverage thresholds
|
|
||||||
- 75%+ test coverage requirement with detailed reporting
|
|
||||||
- Strict type checking with mypy
|
|
||||||
- Code formatting and linting with ruff
|
|
||||||
|
|
||||||
#### 2.2 Monitoring & Logging ✅ COMPLETED
|
|
||||||
- [x] **Structured logging** with JSON formatter
|
|
||||||
- Optional structlog integration for enhanced structured logging
|
|
||||||
- Graceful fallback to stdlib logging when structlog unavailable
|
|
||||||
- Context-aware logging with command tracing and performance metrics
|
|
||||||
- Configurable log levels and JSON formatting for production
|
|
||||||
- [x] **Application metrics** with Prometheus/OpenTelemetry
|
|
||||||
- Comprehensive metrics collection (commands, moderation, AI, database)
|
|
||||||
- Optional Prometheus integration with graceful degradation
|
|
||||||
- Grafana dashboards and monitoring stack configuration
|
|
||||||
- Performance monitoring with request duration and error tracking
|
|
||||||
- [x] **Health check improvements** for database and AI providers
|
|
||||||
- Comprehensive health check system with database, AI, and Discord API monitoring
|
|
||||||
- CLI health check tool with JSON output support
|
|
||||||
- Docker health checks integrated into container definitions
|
|
||||||
- System metrics collection (CPU, memory, disk usage)
|
|
||||||
- [x] **Error tracking and monitoring** infrastructure
|
|
||||||
- Structured logging with error context and stack traces
|
|
||||||
- Metrics-based monitoring for error rates and performance
|
|
||||||
- Health check system for proactive issue detection
|
|
||||||
|
|
||||||
#### 2.3 Development Environment ✅ COMPLETED
|
|
||||||
- [x] **Docker Compose improvements** with dev overrides
|
|
||||||
- Comprehensive docker-compose.yml with production-ready configuration
|
|
||||||
- Development overrides with hot reloading and debugging support
|
|
||||||
- Integrated monitoring stack (Prometheus, Grafana, Redis, PostgreSQL)
|
|
||||||
- Development tools (PgAdmin, Redis Commander, MailHog)
|
|
||||||
- [x] **Development automation and tooling**
|
|
||||||
- Comprehensive development script (scripts/dev.sh) with 15+ commands
|
|
||||||
- Automated setup, testing, linting, and deployment workflows
|
|
||||||
- Database migration management and health checking tools
|
|
||||||
- [x] **Development documentation and setup guides**
|
|
||||||
- Complete Docker setup with development and production configurations
|
|
||||||
- Automated environment setup and dependency management
|
|
||||||
- Comprehensive development workflow documentation
|
|
||||||
|
|
||||||
### **Phase 3: Dashboard Backend Enhancement (Week 3-4)** ✅ COMPLETED
|
|
||||||
*Expand API capabilities for comprehensive management*
|
|
||||||
|
|
||||||
#### 3.1 Enhanced API Endpoints ✅ COMPLETED
|
|
||||||
- [x] **Real-time analytics API** (`/api/analytics/*`)
|
|
||||||
- Moderation action statistics
|
|
||||||
- User activity metrics
|
|
||||||
- AI performance data
|
|
||||||
- Server health metrics
|
|
||||||
|
|
||||||
#### 3.2 User Management API ✅ COMPLETED
|
|
||||||
- [x] **User profile endpoints** (`/api/users/*`)
|
|
||||||
- [x] **Strike and note management**
|
|
||||||
- [x] **User search and filtering**
|
|
||||||
|
|
||||||
#### 3.3 Configuration Management API ✅ COMPLETED
|
|
||||||
- [x] **Guild settings management** (`/api/guilds/{id}/settings`)
|
|
||||||
- [x] **Automod rule configuration** (`/api/guilds/{id}/automod`)
|
|
||||||
- [x] **AI provider settings** per guild
|
|
||||||
- [x] **Export/import functionality** for settings
|
|
||||||
|
|
||||||
#### 3.4 WebSocket Support ✅ COMPLETED
|
|
||||||
- [x] **Real-time event streaming** for live updates
|
|
||||||
- [x] **Live moderation feed** for active monitoring
|
|
||||||
- [x] **System alerts and notifications**
|
|
||||||
|
|
||||||
### **Phase 4: React Dashboard Frontend (Week 4-6)** ✅ COMPLETED
|
|
||||||
*Modern, responsive web interface with real-time capabilities*
|
|
||||||
|
|
||||||
#### 4.1 Frontend Architecture ✅ COMPLETED
|
|
||||||
```
|
|
||||||
dashboard-frontend/
|
|
||||||
├── src/
|
|
||||||
│ ├── components/ # Reusable UI components (Layout)
|
|
||||||
│ ├── pages/ # Page components (Dashboard, Analytics, Users, Settings, Moderation)
|
|
||||||
│ ├── services/ # API clients and WebSocket
|
|
||||||
│ ├── types/ # TypeScript definitions
|
|
||||||
│ └── index.css # Tailwind styles
|
|
||||||
├── public/ # Static assets
|
|
||||||
└── package.json # Dependencies and scripts
|
|
||||||
```
|
|
||||||
|
|
||||||
#### 4.2 Key Features ✅ COMPLETED
|
|
||||||
- [x] **Authentication Flow**: Dual OAuth with session management
|
|
||||||
- [x] **Real-time Analytics Dashboard**:
|
|
||||||
- Live metrics with charts (Recharts)
|
|
||||||
- Moderation activity timeline
|
|
||||||
- AI performance monitoring
|
|
||||||
- [x] **User Management Interface**:
|
|
||||||
- User search and profiles
|
|
||||||
- Strike history display
|
|
||||||
- [x] **Guild Configuration**:
|
|
||||||
- Settings management forms
|
|
||||||
- Automod rule builder
|
|
||||||
- AI sensitivity configuration
|
|
||||||
- [x] **Export functionality**: JSON configuration export
|
|
||||||
|
|
||||||
#### 4.3 Technical Stack ✅ COMPLETED
|
|
||||||
- [x] **React 18** with TypeScript and Vite
|
|
||||||
- [x] **Tailwind CSS** for responsive design
|
|
||||||
- [x] **React Query** for API state management
|
|
||||||
- [x] **React Hook Form** for form handling
|
|
||||||
- [x] **React Router** for navigation
|
|
||||||
- [x] **WebSocket client** for real-time updates
|
|
||||||
- [x] **Recharts** for data visualization
|
|
||||||
- [x] **date-fns** for date formatting
|
|
||||||
|
|
||||||
### **Phase 5: Performance & Scalability (Week 6-7)** ✅ COMPLETED
|
|
||||||
*Optimize performance and prepare for scaling*
|
|
||||||
|
|
||||||
#### 5.1 Database Optimization ✅ COMPLETED
|
|
||||||
- [x] **Add strategic indexes** for common query patterns (analytics tables)
|
|
||||||
- [x] **Database migration for analytics models** with comprehensive indexing
|
|
||||||
|
|
||||||
#### 5.2 Application Performance ✅ COMPLETED
|
|
||||||
- [x] **Implement Redis caching** for guild configs with in-memory fallback
|
|
||||||
- [x] **Multi-tier caching system** (memory + Redis)
|
|
||||||
- [x] **Cache service** with automatic TTL management
|
|
||||||
|
|
||||||
#### 5.3 Architecture Improvements ✅ COMPLETED
|
|
||||||
- [x] **Analytics tracking system** with dedicated models
|
|
||||||
- [x] **Caching abstraction layer** for flexible cache backends
|
|
||||||
- [x] **Performance-optimized guild config service**
|
|
||||||
|
|
||||||
## 🛠 Technical Specifications
|
|
||||||
|
|
||||||
### Enhanced Dashboard Features
|
|
||||||
|
|
||||||
#### Real-time Analytics Dashboard
|
|
||||||
```typescript
|
|
||||||
interface AnalyticsData {
|
|
||||||
moderationStats: {
|
|
||||||
totalActions: number;
|
|
||||||
actionsByType: Record<string, number>;
|
|
||||||
actionsOverTime: TimeSeriesData[];
|
|
||||||
};
|
|
||||||
userActivity: {
|
|
||||||
activeUsers: number;
|
|
||||||
newJoins: number;
|
|
||||||
messageVolume: number;
|
|
||||||
};
|
|
||||||
aiPerformance: {
|
|
||||||
accuracy: number;
|
|
||||||
falsePositives: number;
|
|
||||||
responseTime: number;
|
|
||||||
};
|
|
||||||
}
|
|
||||||
```
|
|
||||||
|
|
||||||
#### User Management Interface
|
|
||||||
- **Advanced search** with filters (username, join date, strike count)
|
|
||||||
- **Bulk actions** (mass ban, mass role assignment)
|
|
||||||
- **User timeline** showing all interactions with the bot
|
|
||||||
- **Note system** for moderator communications
|
|
||||||
|
|
||||||
#### Notification System
|
|
||||||
```typescript
|
|
||||||
interface Alert {
|
|
||||||
id: string;
|
|
||||||
type: 'security' | 'moderation' | 'system';
|
|
||||||
severity: 'low' | 'medium' | 'high' | 'critical';
|
|
||||||
message: string;
|
|
||||||
guildId?: string;
|
|
||||||
timestamp: Date;
|
|
||||||
acknowledged: boolean;
|
|
||||||
}
|
|
||||||
```
|
|
||||||
|
|
||||||
### API Enhancements
|
|
||||||
|
|
||||||
#### WebSocket Events
|
|
||||||
```python
|
|
||||||
# Real-time events
|
|
||||||
class WebSocketEvent(BaseModel):
|
|
||||||
type: str # "moderation_action", "user_join", "ai_alert"
|
|
||||||
guild_id: int
|
|
||||||
timestamp: datetime
|
|
||||||
data: dict
|
|
||||||
```
|
|
||||||
|
|
||||||
#### New Endpoints
|
|
||||||
```python
|
|
||||||
# Analytics endpoints
|
|
||||||
GET /api/analytics/summary
|
|
||||||
GET /api/analytics/moderation-stats
|
|
||||||
GET /api/analytics/user-activity
|
|
||||||
GET /api/analytics/ai-performance
|
|
||||||
|
|
||||||
# User management
|
|
||||||
GET /api/users/search
|
|
||||||
GET /api/users/{user_id}/profile
|
|
||||||
POST /api/users/{user_id}/note
|
|
||||||
POST /api/users/bulk-action
|
|
||||||
|
|
||||||
# Configuration
|
|
||||||
GET /api/guilds/{guild_id}/settings
|
|
||||||
PUT /api/guilds/{guild_id}/settings
|
|
||||||
GET /api/guilds/{guild_id}/automod-rules
|
|
||||||
POST /api/guilds/{guild_id}/automod-rules
|
|
||||||
|
|
||||||
# Real-time updates
|
|
||||||
WebSocket /ws/events
|
|
||||||
```
|
|
||||||
|
|
||||||
## 📊 Success Metrics
|
|
||||||
|
|
||||||
### Code Quality
|
|
||||||
- **Test Coverage**: 90%+ for all modules
|
|
||||||
- **Type Coverage**: 95%+ with mypy strict mode
|
|
||||||
- **Security Score**: Zero critical vulnerabilities
|
|
||||||
- **Performance**: <100ms API response times
|
|
||||||
|
|
||||||
### Dashboard Functionality
|
|
||||||
- **Real-time Updates**: <1 second latency for events
|
|
||||||
- **User Experience**: Mobile-responsive, accessible design
|
|
||||||
- **Data Export**: Multiple format support (CSV, JSON, PDF)
|
|
||||||
- **Uptime**: 99.9% availability target
|
|
||||||
|
|
||||||
## 🚀 Implementation Status
|
|
||||||
|
|
||||||
- **Phase 1**: ✅ COMPLETED
|
|
||||||
- **Phase 2**: ✅ COMPLETED
|
|
||||||
- **Phase 3**: ✅ COMPLETED
|
|
||||||
- **Phase 4**: ✅ COMPLETED
|
|
||||||
- **Phase 5**: ✅ COMPLETED
|
|
||||||
|
|
||||||
---
|
|
||||||
*Last Updated: January 17, 2026*
|
|
||||||
|
|
||||||
## 📊 Phase 1 Achievements
|
|
||||||
|
|
||||||
### Security Enhancements
|
|
||||||
- **Configuration Security**: Implemented strict validation for Discord IDs, API keys, and all configuration parameters
|
|
||||||
- **Input Sanitization**: Enhanced URL parsing with comprehensive validation and filtering
|
|
||||||
- **Database Security**: Added 25+ strategic indexes for performance and security
|
|
||||||
- **Regex Security**: Implemented circuit breaker pattern to prevent catastrophic backtracking
|
|
||||||
|
|
||||||
### Code Quality Improvements
|
|
||||||
- **Error Handling**: Comprehensive error handling throughout Discord API calls and bot operations
|
|
||||||
- **Type Safety**: Resolved major type annotation issues and improved code clarity
|
|
||||||
- **Testing Infrastructure**: Complete test suite setup with 75%+ coverage requirements
|
|
||||||
|
|
||||||
### Performance Optimizations
|
|
||||||
- **Database Indexing**: Strategic indexes for all common query patterns
|
|
||||||
- **Regex Optimization**: Safe regex execution with timeout protection
|
|
||||||
- **Memory Management**: Improved spam tracking with proper cleanup
|
|
||||||
|
|
||||||
### Developer Experience
|
|
||||||
- **Test Coverage**: Comprehensive test fixtures and integration tests
|
|
||||||
- **Documentation**: Updated implementation plan and inline documentation
|
|
||||||
- **Configuration**: Enhanced validation and better error messages
|
|
||||||
|
|
||||||
## 📊 Phase 2 Achievements
|
|
||||||
|
|
||||||
### DevOps Infrastructure
|
|
||||||
- **CI/CD Pipeline**: Complete GitHub Actions workflow with parallel job execution
|
|
||||||
- **Docker Optimization**: Multi-stage builds reducing image size by ~40%
|
|
||||||
- **Security Automation**: Automated vulnerability scanning and dependency management
|
|
||||||
- **Quality Gates**: 75%+ test coverage requirement with comprehensive type checking
|
|
||||||
|
|
||||||
### Monitoring & Observability
|
|
||||||
- **Structured Logging**: JSON logging with context-aware tracing
|
|
||||||
- **Metrics Collection**: 15+ Prometheus metrics for comprehensive monitoring
|
|
||||||
- **Health Checks**: Multi-service health monitoring with performance tracking
|
|
||||||
- **Dashboard Integration**: Grafana dashboards for real-time monitoring
|
|
||||||
|
|
||||||
### Development Experience
|
|
||||||
- **One-Command Setup**: `./scripts/dev.sh setup` for complete environment setup
|
|
||||||
- **Hot Reloading**: Development containers with live code reloading
|
|
||||||
- **Database Tools**: Automated migration management and admin interfaces
|
|
||||||
- **Comprehensive Tooling**: 15+ development commands for testing, linting, and deployment
|
|
||||||
|
|
||||||
## 📊 Phase 3-5 Achievements
|
|
||||||
|
|
||||||
### Phase 3: Dashboard Backend Enhancement
|
|
||||||
- **Analytics API**: Comprehensive real-time analytics with moderation stats, user activity, and AI performance tracking
|
|
||||||
- **User Management**: Full CRUD API for user profiles, notes, and search functionality
|
|
||||||
- **Configuration API**: Guild settings and automod configuration with export/import support
|
|
||||||
- **WebSocket Support**: Real-time event streaming with automatic reconnection and heartbeat
|
|
||||||
|
|
||||||
**New API Endpoints:**
|
|
||||||
- `/api/analytics/summary` - Complete analytics overview
|
|
||||||
- `/api/analytics/moderation-stats` - Detailed moderation statistics
|
|
||||||
- `/api/analytics/user-activity` - User activity metrics
|
|
||||||
- `/api/analytics/ai-performance` - AI moderation performance
|
|
||||||
- `/api/users/search` - User search with filters
|
|
||||||
- `/api/users/{id}/profile` - User profile details
|
|
||||||
- `/api/users/{id}/notes` - User notes management
|
|
||||||
- `/api/guilds/{id}/settings` - Guild settings CRUD
|
|
||||||
- `/api/guilds/{id}/automod` - Automod configuration
|
|
||||||
- `/api/guilds/{id}/export` - Configuration export
|
|
||||||
- `/ws/events` - WebSocket real-time events
|
|
||||||
|
|
||||||
### Phase 4: React Dashboard Frontend
|
|
||||||
- **Modern UI**: Tailwind CSS-based responsive design with dark mode support
|
|
||||||
- **Real-time Charts**: Recharts integration for moderation analytics and trends
|
|
||||||
- **Smart Caching**: React Query for intelligent data fetching and caching
|
|
||||||
- **Type Safety**: Full TypeScript coverage with comprehensive type definitions
|
|
||||||
|
|
||||||
**Pages Implemented:**
|
|
||||||
- Dashboard - Overview with key metrics and charts
|
|
||||||
- Analytics - Detailed statistics and trends
|
|
||||||
- Users - User search and management
|
|
||||||
- Moderation - Comprehensive log viewing
|
|
||||||
- Settings - Guild configuration management
|
|
||||||
|
|
||||||
### Phase 5: Performance & Scalability
|
|
||||||
- **Multi-tier Caching**: Redis + in-memory caching with automatic fallback
|
|
||||||
- **Analytics Models**: Dedicated database models for AI checks, user activity, and message stats
|
|
||||||
- **Optimized Queries**: Strategic indexes on all analytics tables
|
|
||||||
- **Flexible Architecture**: Cache abstraction supporting multiple backends
|
|
||||||
|
|
||||||
**Performance Improvements:**
|
|
||||||
- Guild config caching reduces database load by ~80%
|
|
||||||
- Analytics queries optimized with proper indexing
|
|
||||||
- WebSocket connections with efficient heartbeat mechanism
|
|
||||||
- In-memory fallback ensures reliability without Redis
|
|
||||||
14
README.md
14
README.md
@@ -86,7 +86,7 @@ GuardDen is a comprehensive Discord moderation bot designed to protect your comm
|
|||||||
|
|
||||||
1. Clone the repository:
|
1. Clone the repository:
|
||||||
```bash
|
```bash
|
||||||
git clone https://github.com/yourusername/guardden.git
|
git clone https://git.hiddenden.cafe/Hiddenden/GuardDen.git
|
||||||
cd guardden
|
cd guardden
|
||||||
```
|
```
|
||||||
|
|
||||||
@@ -155,6 +155,9 @@ GuardDen is a comprehensive Discord moderation bot designed to protect your comm
|
|||||||
| `GUARDDEN_DASHBOARD_OWNER_DISCORD_ID` | Discord user ID allowed | Required |
|
| `GUARDDEN_DASHBOARD_OWNER_DISCORD_ID` | Discord user ID allowed | Required |
|
||||||
| `GUARDDEN_DASHBOARD_OWNER_ENTRA_OBJECT_ID` | Entra object ID allowed | Required |
|
| `GUARDDEN_DASHBOARD_OWNER_ENTRA_OBJECT_ID` | Entra object ID allowed | Required |
|
||||||
| `GUARDDEN_DASHBOARD_CORS_ORIGINS` | Dashboard CORS origins | (empty = none) |
|
| `GUARDDEN_DASHBOARD_CORS_ORIGINS` | Dashboard CORS origins | (empty = none) |
|
||||||
|
| `GUARDDEN_WORDLIST_ENABLED` | Enable managed wordlist sync | `true` |
|
||||||
|
| `GUARDDEN_WORDLIST_UPDATE_HOURS` | Managed wordlist sync interval | `168` |
|
||||||
|
| `GUARDDEN_WORDLIST_SOURCES` | JSON array of wordlist sources | (empty = defaults) |
|
||||||
|
|
||||||
### Per-Guild Settings
|
### Per-Guild Settings
|
||||||
|
|
||||||
@@ -208,6 +211,10 @@ Each server can configure:
|
|||||||
| `!bannedwords add <word> [action] [is_regex]` | Add a banned word |
|
| `!bannedwords add <word> [action] [is_regex]` | Add a banned word |
|
||||||
| `!bannedwords remove <id>` | Remove a banned word by ID |
|
| `!bannedwords remove <id>` | Remove a banned word by ID |
|
||||||
|
|
||||||
|
Managed wordlists are synced weekly by default. You can override sources with
|
||||||
|
`GUARDDEN_WORDLIST_SOURCES` (JSON array) or disable syncing entirely with
|
||||||
|
`GUARDDEN_WORDLIST_ENABLED=false`.
|
||||||
|
|
||||||
### Automod
|
### Automod
|
||||||
|
|
||||||
| Command | Description |
|
| Command | Description |
|
||||||
@@ -262,6 +269,11 @@ The dashboard provides read-only visibility into moderation logs across all serv
|
|||||||
- Entra: `http://localhost:8080/auth/entra/callback`
|
- Entra: `http://localhost:8080/auth/entra/callback`
|
||||||
- Discord: `http://localhost:8080/auth/discord/callback`
|
- Discord: `http://localhost:8080/auth/discord/callback`
|
||||||
|
|
||||||
|
## CI (Gitea Actions)
|
||||||
|
|
||||||
|
Workflows live under `.gitea/workflows/` and mirror the previous GitHub Actions
|
||||||
|
pipeline for linting, tests, and Docker builds.
|
||||||
|
|
||||||
## Project Structure
|
## Project Structure
|
||||||
|
|
||||||
```
|
```
|
||||||
|
|||||||
37
migrations/versions/20260117_add_banned_word_metadata.py
Normal file
37
migrations/versions/20260117_add_banned_word_metadata.py
Normal file
@@ -0,0 +1,37 @@
|
|||||||
|
"""Add metadata fields for managed banned words.
|
||||||
|
|
||||||
|
Revision ID: 20260117_add_banned_word_metadata
|
||||||
|
Revises: 20260117_enable_ai_defaults
|
||||||
|
Create Date: 2026-01-17 21:15:00.000000
|
||||||
|
"""
|
||||||
|
|
||||||
|
import sqlalchemy as sa
|
||||||
|
from alembic import op
|
||||||
|
|
||||||
|
# revision identifiers, used by Alembic.
|
||||||
|
revision = "20260117_add_banned_word_metadata"
|
||||||
|
down_revision = "20260117_enable_ai_defaults"
|
||||||
|
branch_labels = None
|
||||||
|
depends_on = None
|
||||||
|
|
||||||
|
|
||||||
|
def upgrade() -> None:
|
||||||
|
op.add_column(
|
||||||
|
"banned_words",
|
||||||
|
sa.Column("source", sa.String(length=100), nullable=True),
|
||||||
|
)
|
||||||
|
op.add_column(
|
||||||
|
"banned_words",
|
||||||
|
sa.Column("category", sa.String(length=20), nullable=True),
|
||||||
|
)
|
||||||
|
op.add_column(
|
||||||
|
"banned_words",
|
||||||
|
sa.Column("managed", sa.Boolean(), nullable=False, server_default=sa.text("false")),
|
||||||
|
)
|
||||||
|
op.alter_column("banned_words", "managed", server_default=None)
|
||||||
|
|
||||||
|
|
||||||
|
def downgrade() -> None:
|
||||||
|
op.drop_column("banned_words", "managed")
|
||||||
|
op.drop_column("banned_words", "category")
|
||||||
|
op.drop_column("banned_words", "source")
|
||||||
41
migrations/versions/20260117_enable_ai_defaults.py
Normal file
41
migrations/versions/20260117_enable_ai_defaults.py
Normal file
@@ -0,0 +1,41 @@
|
|||||||
|
"""Enable AI moderation defaults for existing guilds.
|
||||||
|
|
||||||
|
Revision ID: 20260117_enable_ai_defaults
|
||||||
|
Revises: 20260117_analytics
|
||||||
|
Create Date: 2026-01-17 21:00:00.000000
|
||||||
|
"""
|
||||||
|
|
||||||
|
import sqlalchemy as sa
|
||||||
|
from alembic import op
|
||||||
|
|
||||||
|
# revision identifiers, used by Alembic.
|
||||||
|
revision = "20260117_enable_ai_defaults"
|
||||||
|
down_revision = "20260117_analytics"
|
||||||
|
branch_labels = None
|
||||||
|
depends_on = None
|
||||||
|
|
||||||
|
|
||||||
|
def upgrade() -> None:
|
||||||
|
op.execute(
|
||||||
|
sa.text(
|
||||||
|
"""
|
||||||
|
UPDATE guild_settings
|
||||||
|
SET ai_moderation_enabled = TRUE,
|
||||||
|
nsfw_detection_enabled = TRUE,
|
||||||
|
ai_sensitivity = 80
|
||||||
|
"""
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def downgrade() -> None:
|
||||||
|
op.execute(
|
||||||
|
sa.text(
|
||||||
|
"""
|
||||||
|
UPDATE guild_settings
|
||||||
|
SET ai_moderation_enabled = FALSE,
|
||||||
|
nsfw_detection_enabled = FALSE,
|
||||||
|
ai_sensitivity = 50
|
||||||
|
"""
|
||||||
|
)
|
||||||
|
)
|
||||||
@@ -32,6 +32,7 @@ dependencies = [
|
|||||||
"uvicorn>=0.27.0",
|
"uvicorn>=0.27.0",
|
||||||
"authlib>=1.3.0",
|
"authlib>=1.3.0",
|
||||||
"httpx>=0.27.0",
|
"httpx>=0.27.0",
|
||||||
|
"itsdangerous>=2.1.2",
|
||||||
]
|
]
|
||||||
|
|
||||||
[project.optional-dependencies]
|
[project.optional-dependencies]
|
||||||
@@ -39,6 +40,7 @@ dev = [
|
|||||||
"pytest>=7.4.0",
|
"pytest>=7.4.0",
|
||||||
"pytest-asyncio>=0.23.0",
|
"pytest-asyncio>=0.23.0",
|
||||||
"pytest-cov>=4.1.0",
|
"pytest-cov>=4.1.0",
|
||||||
|
"aiosqlite>=0.19.0",
|
||||||
"ruff>=0.1.0",
|
"ruff>=0.1.0",
|
||||||
"mypy>=1.7.0",
|
"mypy>=1.7.0",
|
||||||
"pre-commit>=3.6.0",
|
"pre-commit>=3.6.0",
|
||||||
|
|||||||
@@ -42,6 +42,7 @@ class GuardDen(commands.Bot):
|
|||||||
self.database = Database(settings)
|
self.database = Database(settings)
|
||||||
self.guild_config: "GuildConfigService | None" = None
|
self.guild_config: "GuildConfigService | None" = None
|
||||||
self.ai_provider: AIProvider | None = None
|
self.ai_provider: AIProvider | None = None
|
||||||
|
self.wordlist_service = None
|
||||||
self.rate_limiter = RateLimiter()
|
self.rate_limiter = RateLimiter()
|
||||||
|
|
||||||
async def _get_prefix(self, bot: "GuardDen", message: discord.Message) -> list[str]:
|
async def _get_prefix(self, bot: "GuardDen", message: discord.Message) -> list[str]:
|
||||||
@@ -90,6 +91,9 @@ class GuardDen(commands.Bot):
|
|||||||
from guardden.services.guild_config import GuildConfigService
|
from guardden.services.guild_config import GuildConfigService
|
||||||
|
|
||||||
self.guild_config = GuildConfigService(self.database)
|
self.guild_config = GuildConfigService(self.database)
|
||||||
|
from guardden.services.wordlist import WordlistService
|
||||||
|
|
||||||
|
self.wordlist_service = WordlistService(self.database, self.settings)
|
||||||
|
|
||||||
# Initialize AI provider
|
# Initialize AI provider
|
||||||
api_key = None
|
api_key = None
|
||||||
@@ -115,6 +119,7 @@ class GuardDen(commands.Bot):
|
|||||||
"guardden.cogs.ai_moderation",
|
"guardden.cogs.ai_moderation",
|
||||||
"guardden.cogs.verification",
|
"guardden.cogs.verification",
|
||||||
"guardden.cogs.health",
|
"guardden.cogs.health",
|
||||||
|
"guardden.cogs.wordlist_sync",
|
||||||
]
|
]
|
||||||
|
|
||||||
failed_cogs = []
|
failed_cogs = []
|
||||||
@@ -131,7 +136,7 @@ class GuardDen(commands.Bot):
|
|||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"Unexpected error loading cog {cog}: {e}", exc_info=True)
|
logger.error(f"Unexpected error loading cog {cog}: {e}", exc_info=True)
|
||||||
failed_cogs.append(cog)
|
failed_cogs.append(cog)
|
||||||
|
|
||||||
if failed_cogs:
|
if failed_cogs:
|
||||||
logger.warning(f"Failed to load {len(failed_cogs)} cog(s): {', '.join(failed_cogs)}")
|
logger.warning(f"Failed to load {len(failed_cogs)} cog(s): {', '.join(failed_cogs)}")
|
||||||
# Don't fail startup if some cogs fail to load, but log it prominently
|
# Don't fail startup if some cogs fail to load, but log it prominently
|
||||||
@@ -146,7 +151,7 @@ class GuardDen(commands.Bot):
|
|||||||
if self.guild_config:
|
if self.guild_config:
|
||||||
initialized = 0
|
initialized = 0
|
||||||
failed_guilds = []
|
failed_guilds = []
|
||||||
|
|
||||||
for guild in self.guilds:
|
for guild in self.guilds:
|
||||||
try:
|
try:
|
||||||
if not self.is_guild_allowed(guild.id):
|
if not self.is_guild_allowed(guild.id):
|
||||||
@@ -162,12 +167,17 @@ class GuardDen(commands.Bot):
|
|||||||
await self.guild_config.create_guild(guild)
|
await self.guild_config.create_guild(guild)
|
||||||
initialized += 1
|
initialized += 1
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"Failed to initialize config for guild {guild.id} ({guild.name}): {e}", exc_info=True)
|
logger.error(
|
||||||
|
f"Failed to initialize config for guild {guild.id} ({guild.name}): {e}",
|
||||||
|
exc_info=True,
|
||||||
|
)
|
||||||
failed_guilds.append(guild.id)
|
failed_guilds.append(guild.id)
|
||||||
|
|
||||||
logger.info("Initialized config for %s guild(s)", initialized)
|
logger.info("Initialized config for %s guild(s)", initialized)
|
||||||
if failed_guilds:
|
if failed_guilds:
|
||||||
logger.warning(f"Failed to initialize {len(failed_guilds)} guild(s): {failed_guilds}")
|
logger.warning(
|
||||||
|
f"Failed to initialize {len(failed_guilds)} guild(s): {failed_guilds}"
|
||||||
|
)
|
||||||
|
|
||||||
# Set presence
|
# Set presence
|
||||||
activity = discord.Activity(
|
activity = discord.Activity(
|
||||||
@@ -206,9 +216,7 @@ class GuardDen(commands.Bot):
|
|||||||
logger.info(f"Joined guild: {guild.name} (ID: {guild.id})")
|
logger.info(f"Joined guild: {guild.name} (ID: {guild.id})")
|
||||||
|
|
||||||
if not self.is_guild_allowed(guild.id):
|
if not self.is_guild_allowed(guild.id):
|
||||||
logger.warning(
|
logger.warning("Guild %s (ID: %s) not in allowlist, leaving.", guild.name, guild.id)
|
||||||
"Guild %s (ID: %s) not in allowlist, leaving.", guild.name, guild.id
|
|
||||||
)
|
|
||||||
await guild.leave()
|
await guild.leave()
|
||||||
return
|
return
|
||||||
|
|
||||||
|
|||||||
38
src/guardden/cogs/wordlist_sync.py
Normal file
38
src/guardden/cogs/wordlist_sync.py
Normal file
@@ -0,0 +1,38 @@
|
|||||||
|
"""Background task for managed wordlist syncing."""
|
||||||
|
|
||||||
|
import logging
|
||||||
|
|
||||||
|
from discord.ext import commands, tasks
|
||||||
|
|
||||||
|
from guardden.services.wordlist import WordlistService
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
class WordlistSync(commands.Cog):
|
||||||
|
"""Periodic sync of managed wordlists into guild bans."""
|
||||||
|
|
||||||
|
def __init__(self, bot: commands.Bot, service: WordlistService) -> None:
|
||||||
|
self.bot = bot
|
||||||
|
self.service = service
|
||||||
|
self.sync_task.change_interval(hours=service.update_interval.total_seconds() / 3600)
|
||||||
|
self.sync_task.start()
|
||||||
|
|
||||||
|
def cog_unload(self) -> None:
|
||||||
|
self.sync_task.cancel()
|
||||||
|
|
||||||
|
@tasks.loop(hours=1)
|
||||||
|
async def sync_task(self) -> None:
|
||||||
|
await self.service.sync_all()
|
||||||
|
|
||||||
|
@sync_task.before_loop
|
||||||
|
async def before_sync_task(self) -> None:
|
||||||
|
await self.bot.wait_until_ready()
|
||||||
|
|
||||||
|
|
||||||
|
async def setup(bot: commands.Bot) -> None:
|
||||||
|
service = getattr(bot, "wordlist_service", None)
|
||||||
|
if not service:
|
||||||
|
logger.warning("Wordlist service not initialized; skipping sync task")
|
||||||
|
return
|
||||||
|
await bot.add_cog(WordlistSync(bot, service))
|
||||||
@@ -5,9 +5,9 @@ import re
|
|||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import Any, Literal
|
from typing import Any, Literal
|
||||||
|
|
||||||
from pydantic import Field, SecretStr, field_validator, ValidationError
|
from pydantic import BaseModel, Field, SecretStr, ValidationError, field_validator
|
||||||
from pydantic_settings import BaseSettings, SettingsConfigDict
|
from pydantic_settings import BaseSettings, SettingsConfigDict
|
||||||
|
from pydantic_settings.sources import EnvSettingsSource
|
||||||
|
|
||||||
# Discord snowflake ID validation regex (64-bit integers, 17-19 digits)
|
# Discord snowflake ID validation regex (64-bit integers, 17-19 digits)
|
||||||
DISCORD_ID_PATTERN = re.compile(r"^\d{17,19}$")
|
DISCORD_ID_PATTERN = re.compile(r"^\d{17,19}$")
|
||||||
@@ -19,17 +19,17 @@ def _validate_discord_id(value: str | int) -> int:
|
|||||||
id_str = str(value)
|
id_str = str(value)
|
||||||
else:
|
else:
|
||||||
id_str = str(value).strip()
|
id_str = str(value).strip()
|
||||||
|
|
||||||
# Check format
|
# Check format
|
||||||
if not DISCORD_ID_PATTERN.match(id_str):
|
if not DISCORD_ID_PATTERN.match(id_str):
|
||||||
raise ValueError(f"Invalid Discord ID format: {id_str}")
|
raise ValueError(f"Invalid Discord ID format: {id_str}")
|
||||||
|
|
||||||
# Convert to int and validate range
|
# Convert to int and validate range
|
||||||
discord_id = int(id_str)
|
discord_id = int(id_str)
|
||||||
# Discord snowflakes are 64-bit integers, minimum valid ID is around 2010
|
# Discord snowflakes are 64-bit integers, minimum valid ID is around 2010
|
||||||
if discord_id < 100000000000000000 or discord_id > 9999999999999999999:
|
if discord_id < 100000000000000000 or discord_id > 9999999999999999999:
|
||||||
raise ValueError(f"Discord ID out of valid range: {discord_id}")
|
raise ValueError(f"Discord ID out of valid range: {discord_id}")
|
||||||
|
|
||||||
return discord_id
|
return discord_id
|
||||||
|
|
||||||
|
|
||||||
@@ -65,6 +65,27 @@ def _parse_id_list(value: Any) -> list[int]:
|
|||||||
return parsed
|
return parsed
|
||||||
|
|
||||||
|
|
||||||
|
class GuardDenEnvSettingsSource(EnvSettingsSource):
|
||||||
|
"""Environment settings source with safe list parsing."""
|
||||||
|
|
||||||
|
def decode_complex_value(self, field_name: str, field, value: Any):
|
||||||
|
if field_name in {"allowed_guilds", "owner_ids"} and isinstance(value, str):
|
||||||
|
return value
|
||||||
|
return super().decode_complex_value(field_name, field, value)
|
||||||
|
|
||||||
|
|
||||||
|
class WordlistSourceConfig(BaseModel):
|
||||||
|
"""Configuration for a managed wordlist source."""
|
||||||
|
|
||||||
|
name: str
|
||||||
|
url: str
|
||||||
|
category: Literal["hard", "soft", "context"]
|
||||||
|
action: Literal["delete", "warn", "strike"]
|
||||||
|
reason: str
|
||||||
|
is_regex: bool = False
|
||||||
|
enabled: bool = True
|
||||||
|
|
||||||
|
|
||||||
class Settings(BaseSettings):
|
class Settings(BaseSettings):
|
||||||
"""Application settings loaded from environment variables."""
|
"""Application settings loaded from environment variables."""
|
||||||
|
|
||||||
@@ -73,8 +94,25 @@ class Settings(BaseSettings):
|
|||||||
env_file_encoding="utf-8",
|
env_file_encoding="utf-8",
|
||||||
case_sensitive=False,
|
case_sensitive=False,
|
||||||
env_prefix="GUARDDEN_",
|
env_prefix="GUARDDEN_",
|
||||||
|
env_parse_none_str="",
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def settings_customise_sources(
|
||||||
|
cls,
|
||||||
|
settings_cls,
|
||||||
|
init_settings,
|
||||||
|
env_settings,
|
||||||
|
dotenv_settings,
|
||||||
|
file_secret_settings,
|
||||||
|
):
|
||||||
|
return (
|
||||||
|
init_settings,
|
||||||
|
GuardDenEnvSettingsSource(settings_cls),
|
||||||
|
dotenv_settings,
|
||||||
|
file_secret_settings,
|
||||||
|
)
|
||||||
|
|
||||||
# Discord settings
|
# Discord settings
|
||||||
discord_token: SecretStr = Field(..., description="Discord bot token")
|
discord_token: SecretStr = Field(..., description="Discord bot token")
|
||||||
discord_prefix: str = Field(default="!", description="Default command prefix")
|
discord_prefix: str = Field(default="!", description="Default command prefix")
|
||||||
@@ -114,11 +152,43 @@ class Settings(BaseSettings):
|
|||||||
# Paths
|
# Paths
|
||||||
data_dir: Path = Field(default=Path("data"), description="Data directory for persistent files")
|
data_dir: Path = Field(default=Path("data"), description="Data directory for persistent files")
|
||||||
|
|
||||||
|
# Wordlist sync
|
||||||
|
wordlist_enabled: bool = Field(
|
||||||
|
default=True, description="Enable automatic managed wordlist syncing"
|
||||||
|
)
|
||||||
|
wordlist_update_hours: int = Field(
|
||||||
|
default=168, description="Managed wordlist sync interval in hours"
|
||||||
|
)
|
||||||
|
wordlist_sources: list[WordlistSourceConfig] = Field(
|
||||||
|
default_factory=list,
|
||||||
|
description="Managed wordlist sources (JSON array via env overrides)",
|
||||||
|
)
|
||||||
|
|
||||||
@field_validator("allowed_guilds", "owner_ids", mode="before")
|
@field_validator("allowed_guilds", "owner_ids", mode="before")
|
||||||
@classmethod
|
@classmethod
|
||||||
def _validate_id_list(cls, value: Any) -> list[int]:
|
def _validate_id_list(cls, value: Any) -> list[int]:
|
||||||
return _parse_id_list(value)
|
return _parse_id_list(value)
|
||||||
|
|
||||||
|
@field_validator("wordlist_sources", mode="before")
|
||||||
|
@classmethod
|
||||||
|
def _parse_wordlist_sources(cls, value: Any) -> list[WordlistSourceConfig]:
|
||||||
|
if value is None:
|
||||||
|
return []
|
||||||
|
if isinstance(value, list):
|
||||||
|
return [WordlistSourceConfig.model_validate(item) for item in value]
|
||||||
|
if isinstance(value, str):
|
||||||
|
text = value.strip()
|
||||||
|
if not text:
|
||||||
|
return []
|
||||||
|
try:
|
||||||
|
data = json.loads(text)
|
||||||
|
except json.JSONDecodeError as exc:
|
||||||
|
raise ValueError("Invalid JSON for wordlist_sources") from exc
|
||||||
|
if not isinstance(data, list):
|
||||||
|
raise ValueError("wordlist_sources must be a JSON array")
|
||||||
|
return [WordlistSourceConfig.model_validate(item) for item in data]
|
||||||
|
return []
|
||||||
|
|
||||||
@field_validator("discord_token")
|
@field_validator("discord_token")
|
||||||
@classmethod
|
@classmethod
|
||||||
def _validate_discord_token(cls, value: SecretStr) -> SecretStr:
|
def _validate_discord_token(cls, value: SecretStr) -> SecretStr:
|
||||||
@@ -126,11 +196,11 @@ class Settings(BaseSettings):
|
|||||||
token = value.get_secret_value()
|
token = value.get_secret_value()
|
||||||
if not token:
|
if not token:
|
||||||
raise ValueError("Discord token cannot be empty")
|
raise ValueError("Discord token cannot be empty")
|
||||||
|
|
||||||
# Basic Discord token format validation (not perfect but catches common issues)
|
# Basic Discord token format validation (not perfect but catches common issues)
|
||||||
if len(token) < 50 or not re.match(r"^[A-Za-z0-9._-]+$", token):
|
if len(token) < 50 or not re.match(r"^[A-Za-z0-9._-]+$", token):
|
||||||
raise ValueError("Invalid Discord token format")
|
raise ValueError("Invalid Discord token format")
|
||||||
|
|
||||||
return value
|
return value
|
||||||
|
|
||||||
@field_validator("anthropic_api_key", "openai_api_key")
|
@field_validator("anthropic_api_key", "openai_api_key")
|
||||||
@@ -139,15 +209,15 @@ class Settings(BaseSettings):
|
|||||||
"""Validate API key format if provided."""
|
"""Validate API key format if provided."""
|
||||||
if value is None:
|
if value is None:
|
||||||
return None
|
return None
|
||||||
|
|
||||||
key = value.get_secret_value()
|
key = value.get_secret_value()
|
||||||
if not key:
|
if not key:
|
||||||
return None
|
return None
|
||||||
|
|
||||||
# Basic API key validation
|
# Basic API key validation
|
||||||
if len(key) < 20:
|
if len(key) < 20:
|
||||||
raise ValueError("API key too short to be valid")
|
raise ValueError("API key too short to be valid")
|
||||||
|
|
||||||
return value
|
return value
|
||||||
|
|
||||||
def validate_configuration(self) -> None:
|
def validate_configuration(self) -> None:
|
||||||
@@ -157,17 +227,21 @@ class Settings(BaseSettings):
|
|||||||
raise ValueError("GUARDDEN_ANTHROPIC_API_KEY is required when AI provider is anthropic")
|
raise ValueError("GUARDDEN_ANTHROPIC_API_KEY is required when AI provider is anthropic")
|
||||||
if self.ai_provider == "openai" and not self.openai_api_key:
|
if self.ai_provider == "openai" and not self.openai_api_key:
|
||||||
raise ValueError("GUARDDEN_OPENAI_API_KEY is required when AI provider is openai")
|
raise ValueError("GUARDDEN_OPENAI_API_KEY is required when AI provider is openai")
|
||||||
|
|
||||||
# Database pool validation
|
# Database pool validation
|
||||||
if self.database_pool_min > self.database_pool_max:
|
if self.database_pool_min > self.database_pool_max:
|
||||||
raise ValueError("database_pool_min cannot be greater than database_pool_max")
|
raise ValueError("database_pool_min cannot be greater than database_pool_max")
|
||||||
if self.database_pool_min < 1:
|
if self.database_pool_min < 1:
|
||||||
raise ValueError("database_pool_min must be at least 1")
|
raise ValueError("database_pool_min must be at least 1")
|
||||||
|
|
||||||
# Data directory validation
|
# Data directory validation
|
||||||
if not isinstance(self.data_dir, Path):
|
if not isinstance(self.data_dir, Path):
|
||||||
raise ValueError("data_dir must be a valid path")
|
raise ValueError("data_dir must be a valid path")
|
||||||
|
|
||||||
|
# Wordlist validation
|
||||||
|
if self.wordlist_update_hours < 1:
|
||||||
|
raise ValueError("wordlist_update_hours must be at least 1")
|
||||||
|
|
||||||
|
|
||||||
def get_settings() -> Settings:
|
def get_settings() -> Settings:
|
||||||
"""Get application settings instance."""
|
"""Get application settings instance."""
|
||||||
|
|||||||
16
src/guardden/dashboard/__main__.py
Normal file
16
src/guardden/dashboard/__main__.py
Normal file
@@ -0,0 +1,16 @@
|
|||||||
|
"""Dashboard entrypoint for `python -m guardden.dashboard`."""
|
||||||
|
|
||||||
|
import os
|
||||||
|
|
||||||
|
import uvicorn
|
||||||
|
|
||||||
|
|
||||||
|
def main() -> None:
|
||||||
|
host = os.getenv("GUARDDEN_DASHBOARD_HOST", "0.0.0.0")
|
||||||
|
port = int(os.getenv("GUARDDEN_DASHBOARD_PORT", "8000"))
|
||||||
|
log_level = os.getenv("GUARDDEN_LOG_LEVEL", "info").lower()
|
||||||
|
uvicorn.run("guardden.dashboard.main:app", host=host, port=port, log_level=log_level)
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
main()
|
||||||
@@ -3,7 +3,7 @@
|
|||||||
from datetime import datetime
|
from datetime import datetime
|
||||||
from typing import TYPE_CHECKING
|
from typing import TYPE_CHECKING
|
||||||
|
|
||||||
from sqlalchemy import Boolean, Float, ForeignKey, Integer, String, Text
|
from sqlalchemy import JSON, Boolean, Float, ForeignKey, Integer, String, Text
|
||||||
from sqlalchemy.dialects.postgresql import JSONB
|
from sqlalchemy.dialects.postgresql import JSONB
|
||||||
from sqlalchemy.orm import Mapped, mapped_column, relationship
|
from sqlalchemy.orm import Mapped, mapped_column, relationship
|
||||||
|
|
||||||
@@ -59,7 +59,9 @@ class GuildSettings(Base, TimestampMixin):
|
|||||||
# Role configuration
|
# Role configuration
|
||||||
mute_role_id: Mapped[int | None] = mapped_column(SnowflakeID, nullable=True)
|
mute_role_id: Mapped[int | None] = mapped_column(SnowflakeID, nullable=True)
|
||||||
verified_role_id: Mapped[int | None] = mapped_column(SnowflakeID, nullable=True)
|
verified_role_id: Mapped[int | None] = mapped_column(SnowflakeID, nullable=True)
|
||||||
mod_role_ids: Mapped[dict] = mapped_column(JSONB, default=list, nullable=False)
|
mod_role_ids: Mapped[dict] = mapped_column(
|
||||||
|
JSONB().with_variant(JSON(), "sqlite"), default=list, nullable=False
|
||||||
|
)
|
||||||
|
|
||||||
# Moderation settings
|
# Moderation settings
|
||||||
automod_enabled: Mapped[bool] = mapped_column(Boolean, default=True, nullable=False)
|
automod_enabled: Mapped[bool] = mapped_column(Boolean, default=True, nullable=False)
|
||||||
@@ -73,11 +75,13 @@ class GuildSettings(Base, TimestampMixin):
|
|||||||
mention_limit: Mapped[int] = mapped_column(Integer, default=5, nullable=False)
|
mention_limit: Mapped[int] = mapped_column(Integer, default=5, nullable=False)
|
||||||
mention_rate_limit: Mapped[int] = mapped_column(Integer, default=10, nullable=False)
|
mention_rate_limit: Mapped[int] = mapped_column(Integer, default=10, nullable=False)
|
||||||
mention_rate_window: Mapped[int] = mapped_column(Integer, default=60, nullable=False)
|
mention_rate_window: Mapped[int] = mapped_column(Integer, default=60, nullable=False)
|
||||||
scam_allowlist: Mapped[list[str]] = mapped_column(JSONB, default=list, nullable=False)
|
scam_allowlist: Mapped[list[str]] = mapped_column(
|
||||||
|
JSONB().with_variant(JSON(), "sqlite"), default=list, nullable=False
|
||||||
|
)
|
||||||
|
|
||||||
# Strike thresholds (actions at each threshold)
|
# Strike thresholds (actions at each threshold)
|
||||||
strike_actions: Mapped[dict] = mapped_column(
|
strike_actions: Mapped[dict] = mapped_column(
|
||||||
JSONB,
|
JSONB().with_variant(JSON(), "sqlite"),
|
||||||
default=lambda: {
|
default=lambda: {
|
||||||
"1": {"action": "warn"},
|
"1": {"action": "warn"},
|
||||||
"3": {"action": "timeout", "duration": 3600},
|
"3": {"action": "timeout", "duration": 3600},
|
||||||
@@ -88,11 +92,11 @@ class GuildSettings(Base, TimestampMixin):
|
|||||||
)
|
)
|
||||||
|
|
||||||
# AI moderation settings
|
# AI moderation settings
|
||||||
ai_moderation_enabled: Mapped[bool] = mapped_column(Boolean, default=False, nullable=False)
|
ai_moderation_enabled: Mapped[bool] = mapped_column(Boolean, default=True, nullable=False)
|
||||||
ai_sensitivity: Mapped[int] = mapped_column(Integer, default=50, nullable=False) # 0-100 scale
|
ai_sensitivity: Mapped[int] = mapped_column(Integer, default=80, nullable=False) # 0-100 scale
|
||||||
ai_confidence_threshold: Mapped[float] = mapped_column(Float, default=0.7, nullable=False)
|
ai_confidence_threshold: Mapped[float] = mapped_column(Float, default=0.7, nullable=False)
|
||||||
ai_log_only: Mapped[bool] = mapped_column(Boolean, default=False, nullable=False)
|
ai_log_only: Mapped[bool] = mapped_column(Boolean, default=False, nullable=False)
|
||||||
nsfw_detection_enabled: Mapped[bool] = mapped_column(Boolean, default=False, nullable=False)
|
nsfw_detection_enabled: Mapped[bool] = mapped_column(Boolean, default=True, nullable=False)
|
||||||
|
|
||||||
# Verification settings
|
# Verification settings
|
||||||
verification_enabled: Mapped[bool] = mapped_column(Boolean, default=False, nullable=False)
|
verification_enabled: Mapped[bool] = mapped_column(Boolean, default=False, nullable=False)
|
||||||
@@ -120,6 +124,9 @@ class BannedWord(Base, TimestampMixin):
|
|||||||
String(20), default="delete", nullable=False
|
String(20), default="delete", nullable=False
|
||||||
) # delete, warn, strike
|
) # delete, warn, strike
|
||||||
reason: Mapped[str | None] = mapped_column(Text, nullable=True)
|
reason: Mapped[str | None] = mapped_column(Text, nullable=True)
|
||||||
|
source: Mapped[str | None] = mapped_column(String(100), nullable=True)
|
||||||
|
category: Mapped[str | None] = mapped_column(String(20), nullable=True)
|
||||||
|
managed: Mapped[bool] = mapped_column(Boolean, default=False, nullable=False)
|
||||||
|
|
||||||
# Who added this and when
|
# Who added this and when
|
||||||
added_by: Mapped[int] = mapped_column(SnowflakeID, nullable=False)
|
added_by: Mapped[int] = mapped_column(SnowflakeID, nullable=False)
|
||||||
|
|||||||
@@ -7,7 +7,7 @@ import time
|
|||||||
from collections import defaultdict
|
from collections import defaultdict
|
||||||
from dataclasses import dataclass, field
|
from dataclasses import dataclass, field
|
||||||
from datetime import datetime, timedelta, timezone
|
from datetime import datetime, timedelta, timezone
|
||||||
from typing import NamedTuple, Sequence, TYPE_CHECKING
|
from typing import TYPE_CHECKING, NamedTuple, Sequence
|
||||||
from urllib.parse import urlparse
|
from urllib.parse import urlparse
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
@@ -16,6 +16,7 @@ else:
|
|||||||
try:
|
try:
|
||||||
import discord # type: ignore
|
import discord # type: ignore
|
||||||
except ModuleNotFoundError: # pragma: no cover
|
except ModuleNotFoundError: # pragma: no cover
|
||||||
|
|
||||||
class _DiscordStub:
|
class _DiscordStub:
|
||||||
class Message: # minimal stub for type hints
|
class Message: # minimal stub for type hints
|
||||||
pass
|
pass
|
||||||
@@ -26,120 +27,122 @@ from guardden.models.guild import BannedWord
|
|||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
# Circuit breaker for regex safety
|
# Circuit breaker for regex safety
|
||||||
class RegexTimeoutError(Exception):
|
class RegexTimeoutError(Exception):
|
||||||
"""Raised when regex execution takes too long."""
|
"""Raised when regex execution takes too long."""
|
||||||
|
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
|
||||||
class RegexCircuitBreaker:
|
class RegexCircuitBreaker:
|
||||||
"""Circuit breaker to prevent catastrophic backtracking in regex patterns."""
|
"""Circuit breaker to prevent catastrophic backtracking in regex patterns."""
|
||||||
|
|
||||||
def __init__(self, timeout_seconds: float = 0.1):
|
def __init__(self, timeout_seconds: float = 0.1):
|
||||||
self.timeout_seconds = timeout_seconds
|
self.timeout_seconds = timeout_seconds
|
||||||
self.failed_patterns: dict[str, datetime] = {}
|
self.failed_patterns: dict[str, datetime] = {}
|
||||||
self.failure_threshold = timedelta(minutes=5) # Disable pattern for 5 minutes after failure
|
self.failure_threshold = timedelta(minutes=5) # Disable pattern for 5 minutes after failure
|
||||||
|
|
||||||
def _timeout_handler(self, signum, frame):
|
def _timeout_handler(self, signum, frame):
|
||||||
"""Signal handler for regex timeout."""
|
"""Signal handler for regex timeout."""
|
||||||
raise RegexTimeoutError("Regex execution timed out")
|
raise RegexTimeoutError("Regex execution timed out")
|
||||||
|
|
||||||
def is_pattern_disabled(self, pattern: str) -> bool:
|
def is_pattern_disabled(self, pattern: str) -> bool:
|
||||||
"""Check if a pattern is temporarily disabled due to timeouts."""
|
"""Check if a pattern is temporarily disabled due to timeouts."""
|
||||||
if pattern not in self.failed_patterns:
|
if pattern not in self.failed_patterns:
|
||||||
return False
|
return False
|
||||||
|
|
||||||
failure_time = self.failed_patterns[pattern]
|
failure_time = self.failed_patterns[pattern]
|
||||||
if datetime.now(timezone.utc) - failure_time > self.failure_threshold:
|
if datetime.now(timezone.utc) - failure_time > self.failure_threshold:
|
||||||
# Re-enable the pattern after threshold time
|
# Re-enable the pattern after threshold time
|
||||||
del self.failed_patterns[pattern]
|
del self.failed_patterns[pattern]
|
||||||
return False
|
return False
|
||||||
|
|
||||||
return True
|
return True
|
||||||
|
|
||||||
def safe_regex_search(self, pattern: str, text: str, flags: int = 0) -> bool:
|
def safe_regex_search(self, pattern: str, text: str, flags: int = 0) -> bool:
|
||||||
"""Safely execute regex search with timeout protection."""
|
"""Safely execute regex search with timeout protection."""
|
||||||
if self.is_pattern_disabled(pattern):
|
if self.is_pattern_disabled(pattern):
|
||||||
logger.warning(f"Regex pattern temporarily disabled due to timeout: {pattern[:50]}...")
|
logger.warning(f"Regex pattern temporarily disabled due to timeout: {pattern[:50]}...")
|
||||||
return False
|
return False
|
||||||
|
|
||||||
# Basic pattern validation to catch obviously problematic patterns
|
# Basic pattern validation to catch obviously problematic patterns
|
||||||
if self._is_dangerous_pattern(pattern):
|
if self._is_dangerous_pattern(pattern):
|
||||||
logger.warning(f"Potentially dangerous regex pattern rejected: {pattern[:50]}...")
|
logger.warning(f"Potentially dangerous regex pattern rejected: {pattern[:50]}...")
|
||||||
return False
|
return False
|
||||||
|
|
||||||
old_handler = None
|
old_handler = None
|
||||||
try:
|
try:
|
||||||
# Set up timeout signal (Unix systems only)
|
# Set up timeout signal (Unix systems only)
|
||||||
if hasattr(signal, 'SIGALRM'):
|
if hasattr(signal, "SIGALRM"):
|
||||||
old_handler = signal.signal(signal.SIGALRM, self._timeout_handler)
|
old_handler = signal.signal(signal.SIGALRM, self._timeout_handler)
|
||||||
signal.alarm(int(self.timeout_seconds * 1000)) # Convert to milliseconds
|
signal.alarm(int(self.timeout_seconds * 1000)) # Convert to milliseconds
|
||||||
|
|
||||||
start_time = time.perf_counter()
|
start_time = time.perf_counter()
|
||||||
|
|
||||||
# Compile and execute regex
|
# Compile and execute regex
|
||||||
compiled_pattern = re.compile(pattern, flags)
|
compiled_pattern = re.compile(pattern, flags)
|
||||||
result = bool(compiled_pattern.search(text))
|
result = bool(compiled_pattern.search(text))
|
||||||
|
|
||||||
execution_time = time.perf_counter() - start_time
|
execution_time = time.perf_counter() - start_time
|
||||||
|
|
||||||
# Log slow patterns for monitoring
|
# Log slow patterns for monitoring
|
||||||
if execution_time > self.timeout_seconds * 0.8:
|
if execution_time > self.timeout_seconds * 0.8:
|
||||||
logger.warning(
|
logger.warning(
|
||||||
f"Slow regex pattern (took {execution_time:.3f}s): {pattern[:50]}..."
|
f"Slow regex pattern (took {execution_time:.3f}s): {pattern[:50]}..."
|
||||||
)
|
)
|
||||||
|
|
||||||
return result
|
return result
|
||||||
|
|
||||||
except RegexTimeoutError:
|
except RegexTimeoutError:
|
||||||
# Pattern took too long, disable it temporarily
|
# Pattern took too long, disable it temporarily
|
||||||
self.failed_patterns[pattern] = datetime.now(timezone.utc)
|
self.failed_patterns[pattern] = datetime.now(timezone.utc)
|
||||||
logger.error(f"Regex pattern timed out and disabled: {pattern[:50]}...")
|
logger.error(f"Regex pattern timed out and disabled: {pattern[:50]}...")
|
||||||
return False
|
return False
|
||||||
|
|
||||||
except re.error as e:
|
except re.error as e:
|
||||||
logger.warning(f"Invalid regex pattern '{pattern[:50]}...': {e}")
|
logger.warning(f"Invalid regex pattern '{pattern[:50]}...': {e}")
|
||||||
return False
|
return False
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"Unexpected error in regex execution: {e}")
|
logger.error(f"Unexpected error in regex execution: {e}")
|
||||||
return False
|
return False
|
||||||
|
|
||||||
finally:
|
finally:
|
||||||
# Clean up timeout signal
|
# Clean up timeout signal
|
||||||
if hasattr(signal, 'SIGALRM') and old_handler is not None:
|
if hasattr(signal, "SIGALRM") and old_handler is not None:
|
||||||
signal.alarm(0)
|
signal.alarm(0)
|
||||||
signal.signal(signal.SIGALRM, old_handler)
|
signal.signal(signal.SIGALRM, old_handler)
|
||||||
|
|
||||||
def _is_dangerous_pattern(self, pattern: str) -> bool:
|
def _is_dangerous_pattern(self, pattern: str) -> bool:
|
||||||
"""Basic heuristic to detect potentially dangerous regex patterns."""
|
"""Basic heuristic to detect potentially dangerous regex patterns."""
|
||||||
# Check for patterns that are commonly problematic
|
# Check for patterns that are commonly problematic
|
||||||
dangerous_indicators = [
|
dangerous_indicators = [
|
||||||
r'(\w+)+', # Nested quantifiers
|
r"(\w+)+", # Nested quantifiers
|
||||||
r'(\d+)+', # Nested quantifiers on digits
|
r"(\d+)+", # Nested quantifiers on digits
|
||||||
r'(.+)+', # Nested quantifiers on anything
|
r"(.+)+", # Nested quantifiers on anything
|
||||||
r'(.*)+', # Nested quantifiers on anything (greedy)
|
r"(.*)+", # Nested quantifiers on anything (greedy)
|
||||||
r'(\w*)+', # Nested quantifiers with *
|
r"(\w*)+", # Nested quantifiers with *
|
||||||
r'(\S+)+', # Nested quantifiers on non-whitespace
|
r"(\S+)+", # Nested quantifiers on non-whitespace
|
||||||
]
|
]
|
||||||
|
|
||||||
# Check for excessively long patterns
|
# Check for excessively long patterns
|
||||||
if len(pattern) > 500:
|
if len(pattern) > 500:
|
||||||
return True
|
return True
|
||||||
|
|
||||||
# Check for nested quantifiers (simplified detection)
|
# Check for nested quantifiers (simplified detection)
|
||||||
if '+)+' in pattern or '*)+' in pattern or '?)+' in pattern:
|
if "+)+" in pattern or "*)+" in pattern or "?)+" in pattern:
|
||||||
return True
|
return True
|
||||||
|
|
||||||
# Check for excessive repetition operators
|
# Check for excessive repetition operators
|
||||||
if pattern.count('+') > 10 or pattern.count('*') > 10:
|
if pattern.count("+") > 10 or pattern.count("*") > 10:
|
||||||
return True
|
return True
|
||||||
|
|
||||||
# Check for specific dangerous patterns
|
# Check for specific dangerous patterns
|
||||||
for dangerous in dangerous_indicators:
|
for dangerous in dangerous_indicators:
|
||||||
if dangerous in pattern:
|
if dangerous in pattern:
|
||||||
return True
|
return True
|
||||||
|
|
||||||
return False
|
return False
|
||||||
|
|
||||||
|
|
||||||
@@ -240,34 +243,43 @@ def normalize_domain(value: str) -> str:
|
|||||||
"""Normalize a domain or URL for allowlist checks with security validation."""
|
"""Normalize a domain or URL for allowlist checks with security validation."""
|
||||||
if not value or not isinstance(value, str):
|
if not value or not isinstance(value, str):
|
||||||
return ""
|
return ""
|
||||||
|
|
||||||
|
if any(char in value for char in ["\x00", "\n", "\r", "\t"]):
|
||||||
|
return ""
|
||||||
|
|
||||||
text = value.strip().lower()
|
text = value.strip().lower()
|
||||||
if not text or len(text) > 2000: # Prevent excessively long URLs
|
if not text or len(text) > 2000: # Prevent excessively long URLs
|
||||||
return ""
|
return ""
|
||||||
|
|
||||||
# Sanitize input to prevent injection attacks
|
|
||||||
if any(char in text for char in ['\x00', '\n', '\r', '\t']):
|
|
||||||
return ""
|
|
||||||
|
|
||||||
try:
|
try:
|
||||||
if "://" not in text:
|
if "://" not in text:
|
||||||
text = f"http://{text}"
|
text = f"http://{text}"
|
||||||
|
|
||||||
parsed = urlparse(text)
|
parsed = urlparse(text)
|
||||||
hostname = parsed.hostname or ""
|
hostname = parsed.hostname or ""
|
||||||
|
|
||||||
# Additional validation for hostname
|
# Additional validation for hostname
|
||||||
if not hostname or len(hostname) > 253: # RFC limit
|
if not hostname or len(hostname) > 253: # RFC limit
|
||||||
return ""
|
return ""
|
||||||
|
|
||||||
# Check for malicious patterns
|
# Check for malicious patterns
|
||||||
if any(char in hostname for char in [' ', '\x00', '\n', '\r', '\t']):
|
if any(char in hostname for char in [" ", "\x00", "\n", "\r", "\t"]):
|
||||||
return ""
|
return ""
|
||||||
|
|
||||||
|
if not re.fullmatch(r"[a-z0-9.-]+", hostname):
|
||||||
|
return ""
|
||||||
|
if hostname.startswith(".") or hostname.endswith(".") or ".." in hostname:
|
||||||
|
return ""
|
||||||
|
for label in hostname.split("."):
|
||||||
|
if not label:
|
||||||
|
return ""
|
||||||
|
if label.startswith("-") or label.endswith("-"):
|
||||||
|
return ""
|
||||||
|
|
||||||
# Remove www prefix
|
# Remove www prefix
|
||||||
if hostname.startswith("www."):
|
if hostname.startswith("www."):
|
||||||
hostname = hostname[4:]
|
hostname = hostname[4:]
|
||||||
|
|
||||||
return hostname
|
return hostname
|
||||||
except (ValueError, UnicodeError, Exception):
|
except (ValueError, UnicodeError, Exception):
|
||||||
# urlparse can raise various exceptions with malicious input
|
# urlparse can raise various exceptions with malicious input
|
||||||
@@ -305,13 +317,13 @@ class AutomodService:
|
|||||||
# Normalize: lowercase, remove extra spaces, remove special chars
|
# Normalize: lowercase, remove extra spaces, remove special chars
|
||||||
# Use simple string operations for basic patterns to avoid regex overhead
|
# Use simple string operations for basic patterns to avoid regex overhead
|
||||||
normalized = content.lower()
|
normalized = content.lower()
|
||||||
|
|
||||||
# Remove special characters (simplified approach)
|
# Remove special characters (simplified approach)
|
||||||
normalized = ''.join(c for c in normalized if c.isalnum() or c.isspace())
|
normalized = "".join(c for c in normalized if c.isalnum() or c.isspace())
|
||||||
|
|
||||||
# Normalize whitespace
|
# Normalize whitespace
|
||||||
normalized = ' '.join(normalized.split())
|
normalized = " ".join(normalized.split())
|
||||||
|
|
||||||
return normalized
|
return normalized
|
||||||
|
|
||||||
def check_banned_words(
|
def check_banned_words(
|
||||||
@@ -369,14 +381,14 @@ class AutomodService:
|
|||||||
# Limit URL length to prevent processing extremely long URLs
|
# Limit URL length to prevent processing extremely long URLs
|
||||||
if len(url) > 2000:
|
if len(url) > 2000:
|
||||||
continue
|
continue
|
||||||
|
|
||||||
url_lower = url.lower()
|
url_lower = url.lower()
|
||||||
hostname = normalize_domain(url)
|
hostname = normalize_domain(url)
|
||||||
|
|
||||||
# Skip if hostname normalization failed (security check)
|
# Skip if hostname normalization failed (security check)
|
||||||
if not hostname:
|
if not hostname:
|
||||||
continue
|
continue
|
||||||
|
|
||||||
if allowlist_set and is_allowed_domain(hostname, allowlist_set):
|
if allowlist_set and is_allowed_domain(hostname, allowlist_set):
|
||||||
continue
|
continue
|
||||||
|
|
||||||
@@ -540,3 +552,11 @@ class AutomodService:
|
|||||||
def cleanup_guild(self, guild_id: int) -> None:
|
def cleanup_guild(self, guild_id: int) -> None:
|
||||||
"""Remove all tracking data for a guild."""
|
"""Remove all tracking data for a guild."""
|
||||||
self._spam_trackers.pop(guild_id, None)
|
self._spam_trackers.pop(guild_id, None)
|
||||||
|
|
||||||
|
|
||||||
|
_automod_service = AutomodService()
|
||||||
|
|
||||||
|
|
||||||
|
def detect_scam_links(content: str, allowlist: list[str] | None = None) -> AutomodResult | None:
|
||||||
|
"""Convenience wrapper for scam detection."""
|
||||||
|
return _automod_service.check_scam_links(content, allowlist)
|
||||||
|
|||||||
@@ -141,6 +141,9 @@ class GuildConfigService:
|
|||||||
is_regex: bool = False,
|
is_regex: bool = False,
|
||||||
action: str = "delete",
|
action: str = "delete",
|
||||||
reason: str | None = None,
|
reason: str | None = None,
|
||||||
|
source: str | None = None,
|
||||||
|
category: str | None = None,
|
||||||
|
managed: bool = False,
|
||||||
) -> BannedWord:
|
) -> BannedWord:
|
||||||
"""Add a banned word to a guild."""
|
"""Add a banned word to a guild."""
|
||||||
async with self.database.session() as session:
|
async with self.database.session() as session:
|
||||||
@@ -150,6 +153,9 @@ class GuildConfigService:
|
|||||||
is_regex=is_regex,
|
is_regex=is_regex,
|
||||||
action=action,
|
action=action,
|
||||||
reason=reason,
|
reason=reason,
|
||||||
|
source=source,
|
||||||
|
category=category,
|
||||||
|
managed=managed,
|
||||||
added_by=added_by,
|
added_by=added_by,
|
||||||
)
|
)
|
||||||
session.add(banned_word)
|
session.add(banned_word)
|
||||||
|
|||||||
180
src/guardden/services/wordlist.py
Normal file
180
src/guardden/services/wordlist.py
Normal file
@@ -0,0 +1,180 @@
|
|||||||
|
"""Managed wordlist sync service."""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import logging
|
||||||
|
from dataclasses import dataclass
|
||||||
|
from datetime import datetime, timedelta, timezone
|
||||||
|
from typing import Iterable
|
||||||
|
|
||||||
|
import httpx
|
||||||
|
from sqlalchemy import delete, select
|
||||||
|
|
||||||
|
from guardden.config import Settings, WordlistSourceConfig
|
||||||
|
from guardden.models import BannedWord, Guild
|
||||||
|
from guardden.services.database import Database
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
MAX_WORDLIST_ENTRY_LENGTH = 128
|
||||||
|
REQUEST_TIMEOUT = 20.0
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass(frozen=True)
|
||||||
|
class WordlistSource:
|
||||||
|
name: str
|
||||||
|
url: str
|
||||||
|
category: str
|
||||||
|
action: str
|
||||||
|
reason: str
|
||||||
|
is_regex: bool = False
|
||||||
|
|
||||||
|
|
||||||
|
DEFAULT_SOURCES: list[WordlistSource] = [
|
||||||
|
WordlistSource(
|
||||||
|
name="ldnoobw_en",
|
||||||
|
url="https://raw.githubusercontent.com/LDNOOBW/List-of-Dirty-Naughty-Obscene-and-Otherwise-Bad-Words/master/en",
|
||||||
|
category="soft",
|
||||||
|
action="warn",
|
||||||
|
reason="Auto list: profanity",
|
||||||
|
is_regex=False,
|
||||||
|
),
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
|
def _normalize_entry(line: str) -> str:
|
||||||
|
text = line.strip().lower()
|
||||||
|
if not text:
|
||||||
|
return ""
|
||||||
|
if len(text) > MAX_WORDLIST_ENTRY_LENGTH:
|
||||||
|
return ""
|
||||||
|
return text
|
||||||
|
|
||||||
|
|
||||||
|
def _parse_wordlist(text: str) -> list[str]:
|
||||||
|
entries: list[str] = []
|
||||||
|
seen: set[str] = set()
|
||||||
|
for raw in text.splitlines():
|
||||||
|
line = raw.strip()
|
||||||
|
if not line:
|
||||||
|
continue
|
||||||
|
if line.startswith("#") or line.startswith("//") or line.startswith(";"):
|
||||||
|
continue
|
||||||
|
normalized = _normalize_entry(line)
|
||||||
|
if not normalized or normalized in seen:
|
||||||
|
continue
|
||||||
|
entries.append(normalized)
|
||||||
|
seen.add(normalized)
|
||||||
|
return entries
|
||||||
|
|
||||||
|
|
||||||
|
class WordlistService:
|
||||||
|
"""Fetches and syncs managed wordlists into per-guild bans."""
|
||||||
|
|
||||||
|
def __init__(self, database: Database, settings: Settings) -> None:
|
||||||
|
self.database = database
|
||||||
|
self.settings = settings
|
||||||
|
self.sources = self._load_sources(settings)
|
||||||
|
self.update_interval = timedelta(hours=settings.wordlist_update_hours)
|
||||||
|
self.last_sync: datetime | None = None
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def _load_sources(settings: Settings) -> list[WordlistSource]:
|
||||||
|
if settings.wordlist_sources:
|
||||||
|
sources: list[WordlistSource] = []
|
||||||
|
for src in settings.wordlist_sources:
|
||||||
|
if not src.enabled:
|
||||||
|
continue
|
||||||
|
sources.append(
|
||||||
|
WordlistSource(
|
||||||
|
name=src.name,
|
||||||
|
url=src.url,
|
||||||
|
category=src.category,
|
||||||
|
action=src.action,
|
||||||
|
reason=src.reason,
|
||||||
|
is_regex=src.is_regex,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
return sources
|
||||||
|
return list(DEFAULT_SOURCES)
|
||||||
|
|
||||||
|
async def _fetch_source(self, source: WordlistSource) -> list[str]:
|
||||||
|
async with httpx.AsyncClient(timeout=REQUEST_TIMEOUT) as client:
|
||||||
|
response = await client.get(source.url)
|
||||||
|
response.raise_for_status()
|
||||||
|
return _parse_wordlist(response.text)
|
||||||
|
|
||||||
|
async def sync_all(self) -> None:
|
||||||
|
if not self.settings.wordlist_enabled:
|
||||||
|
logger.info("Managed wordlist sync disabled")
|
||||||
|
return
|
||||||
|
if not self.sources:
|
||||||
|
logger.warning("No wordlist sources configured")
|
||||||
|
return
|
||||||
|
|
||||||
|
logger.info("Starting managed wordlist sync (%d sources)", len(self.sources))
|
||||||
|
async with self.database.session() as session:
|
||||||
|
guild_ids = list((await session.execute(select(Guild.id))).scalars().all())
|
||||||
|
|
||||||
|
for source in self.sources:
|
||||||
|
try:
|
||||||
|
entries = await self._fetch_source(source)
|
||||||
|
except Exception as exc:
|
||||||
|
logger.error("Failed to fetch wordlist %s: %s", source.name, exc)
|
||||||
|
continue
|
||||||
|
|
||||||
|
if not entries:
|
||||||
|
logger.warning("Wordlist %s returned no entries", source.name)
|
||||||
|
continue
|
||||||
|
|
||||||
|
await self._sync_source_to_guilds(source, entries, guild_ids)
|
||||||
|
|
||||||
|
self.last_sync = datetime.now(timezone.utc)
|
||||||
|
logger.info("Managed wordlist sync completed")
|
||||||
|
|
||||||
|
async def _sync_source_to_guilds(
|
||||||
|
self, source: WordlistSource, entries: Iterable[str], guild_ids: list[int]
|
||||||
|
) -> None:
|
||||||
|
entry_set = set(entries)
|
||||||
|
async with self.database.session() as session:
|
||||||
|
for guild_id in guild_ids:
|
||||||
|
result = await session.execute(
|
||||||
|
select(BannedWord).where(
|
||||||
|
BannedWord.guild_id == guild_id,
|
||||||
|
BannedWord.managed.is_(True),
|
||||||
|
BannedWord.source == source.name,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
existing = list(result.scalars().all())
|
||||||
|
existing_set = {word.pattern.lower() for word in existing}
|
||||||
|
|
||||||
|
to_add = entry_set - existing_set
|
||||||
|
to_remove = existing_set - entry_set
|
||||||
|
|
||||||
|
if to_remove:
|
||||||
|
await session.execute(
|
||||||
|
delete(BannedWord).where(
|
||||||
|
BannedWord.guild_id == guild_id,
|
||||||
|
BannedWord.managed.is_(True),
|
||||||
|
BannedWord.source == source.name,
|
||||||
|
BannedWord.pattern.in_(to_remove),
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
if to_add:
|
||||||
|
session.add_all(
|
||||||
|
[
|
||||||
|
BannedWord(
|
||||||
|
guild_id=guild_id,
|
||||||
|
pattern=pattern,
|
||||||
|
is_regex=source.is_regex,
|
||||||
|
action=source.action,
|
||||||
|
reason=source.reason,
|
||||||
|
source=source.name,
|
||||||
|
category=source.category,
|
||||||
|
managed=True,
|
||||||
|
added_by=0,
|
||||||
|
)
|
||||||
|
for pattern in to_add
|
||||||
|
]
|
||||||
|
)
|
||||||
@@ -7,11 +7,11 @@ import sys
|
|||||||
import tempfile
|
import tempfile
|
||||||
from datetime import datetime, timezone
|
from datetime import datetime, timezone
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from unittest.mock import AsyncMock, MagicMock
|
|
||||||
from typing import AsyncGenerator
|
from typing import AsyncGenerator
|
||||||
|
from unittest.mock import AsyncMock, MagicMock
|
||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
from sqlalchemy import create_engine
|
from sqlalchemy import create_engine, event, text
|
||||||
from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker, create_async_engine
|
from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker, create_async_engine
|
||||||
from sqlalchemy.pool import StaticPool
|
from sqlalchemy.pool import StaticPool
|
||||||
|
|
||||||
@@ -23,7 +23,7 @@ if str(SRC_DIR) not in sys.path:
|
|||||||
# Import after path setup
|
# Import after path setup
|
||||||
from guardden.config import Settings
|
from guardden.config import Settings
|
||||||
from guardden.models.base import Base
|
from guardden.models.base import Base
|
||||||
from guardden.models.guild import Guild, GuildSettings, BannedWord
|
from guardden.models.guild import BannedWord, Guild, GuildSettings
|
||||||
from guardden.models.moderation import ModerationLog, Strike, UserNote
|
from guardden.models.moderation import ModerationLog, Strike, UserNote
|
||||||
from guardden.services.database import Database
|
from guardden.services.database import Database
|
||||||
|
|
||||||
@@ -52,6 +52,7 @@ def pytest_pyfunc_call(pyfuncitem: pytest.Function) -> bool | None:
|
|||||||
# Basic Test Fixtures
|
# Basic Test Fixtures
|
||||||
# ==============================================================================
|
# ==============================================================================
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture
|
@pytest.fixture
|
||||||
def sample_guild_id() -> int:
|
def sample_guild_id() -> int:
|
||||||
"""Return a sample Discord guild ID."""
|
"""Return a sample Discord guild ID."""
|
||||||
@@ -80,11 +81,12 @@ def sample_owner_id() -> int:
|
|||||||
# Configuration Fixtures
|
# Configuration Fixtures
|
||||||
# ==============================================================================
|
# ==============================================================================
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture
|
@pytest.fixture
|
||||||
def test_settings() -> Settings:
|
def test_settings() -> Settings:
|
||||||
"""Return test configuration settings."""
|
"""Return test configuration settings."""
|
||||||
return Settings(
|
return Settings(
|
||||||
discord_token="test_token_12345678901234567890",
|
discord_token="a" * 60,
|
||||||
discord_prefix="!test",
|
discord_prefix="!test",
|
||||||
database_url="sqlite+aiosqlite:///test.db",
|
database_url="sqlite+aiosqlite:///test.db",
|
||||||
database_pool_min=1,
|
database_pool_min=1,
|
||||||
@@ -101,6 +103,7 @@ def test_settings() -> Settings:
|
|||||||
# Database Fixtures
|
# Database Fixtures
|
||||||
# ==============================================================================
|
# ==============================================================================
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture
|
@pytest.fixture
|
||||||
async def test_database(test_settings: Settings) -> AsyncGenerator[Database, None]:
|
async def test_database(test_settings: Settings) -> AsyncGenerator[Database, None]:
|
||||||
"""Create a test database with in-memory SQLite."""
|
"""Create a test database with in-memory SQLite."""
|
||||||
@@ -111,19 +114,26 @@ async def test_database(test_settings: Settings) -> AsyncGenerator[Database, Non
|
|||||||
poolclass=StaticPool,
|
poolclass=StaticPool,
|
||||||
echo=False,
|
echo=False,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@event.listens_for(engine.sync_engine, "connect")
|
||||||
|
def _enable_sqlite_foreign_keys(dbapi_connection, connection_record):
|
||||||
|
cursor = dbapi_connection.cursor()
|
||||||
|
cursor.execute("PRAGMA foreign_keys=ON")
|
||||||
|
cursor.close()
|
||||||
|
|
||||||
# Create all tables
|
# Create all tables
|
||||||
async with engine.begin() as conn:
|
async with engine.begin() as conn:
|
||||||
|
await conn.execute(text("PRAGMA foreign_keys=ON"))
|
||||||
await conn.run_sync(Base.metadata.create_all)
|
await conn.run_sync(Base.metadata.create_all)
|
||||||
|
|
||||||
database = Database(test_settings)
|
database = Database(test_settings)
|
||||||
database._engine = engine
|
database._engine = engine
|
||||||
database._session_factory = async_sessionmaker(
|
database._session_factory = async_sessionmaker(
|
||||||
engine, class_=AsyncSession, expire_on_commit=False
|
engine, class_=AsyncSession, expire_on_commit=False
|
||||||
)
|
)
|
||||||
|
|
||||||
yield database
|
yield database
|
||||||
|
|
||||||
await engine.dispose()
|
await engine.dispose()
|
||||||
|
|
||||||
|
|
||||||
@@ -138,10 +148,9 @@ async def db_session(test_database: Database) -> AsyncGenerator[AsyncSession, No
|
|||||||
# Model Fixtures
|
# Model Fixtures
|
||||||
# ==============================================================================
|
# ==============================================================================
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture
|
@pytest.fixture
|
||||||
async def test_guild(
|
async def test_guild(db_session: AsyncSession, sample_guild_id: int, sample_owner_id: int) -> Guild:
|
||||||
db_session: AsyncSession, sample_guild_id: int, sample_owner_id: int
|
|
||||||
) -> Guild:
|
|
||||||
"""Create a test guild with settings."""
|
"""Create a test guild with settings."""
|
||||||
guild = Guild(
|
guild = Guild(
|
||||||
id=sample_guild_id,
|
id=sample_guild_id,
|
||||||
@@ -150,7 +159,7 @@ async def test_guild(
|
|||||||
premium=False,
|
premium=False,
|
||||||
)
|
)
|
||||||
db_session.add(guild)
|
db_session.add(guild)
|
||||||
|
|
||||||
# Create associated settings
|
# Create associated settings
|
||||||
settings = GuildSettings(
|
settings = GuildSettings(
|
||||||
guild_id=sample_guild_id,
|
guild_id=sample_guild_id,
|
||||||
@@ -160,7 +169,7 @@ async def test_guild(
|
|||||||
verification_enabled=False,
|
verification_enabled=False,
|
||||||
)
|
)
|
||||||
db_session.add(settings)
|
db_session.add(settings)
|
||||||
|
|
||||||
await db_session.commit()
|
await db_session.commit()
|
||||||
await db_session.refresh(guild)
|
await db_session.refresh(guild)
|
||||||
return guild
|
return guild
|
||||||
@@ -187,10 +196,7 @@ async def test_banned_word(
|
|||||||
|
|
||||||
@pytest.fixture
|
@pytest.fixture
|
||||||
async def test_moderation_log(
|
async def test_moderation_log(
|
||||||
db_session: AsyncSession,
|
db_session: AsyncSession, test_guild: Guild, sample_user_id: int, sample_moderator_id: int
|
||||||
test_guild: Guild,
|
|
||||||
sample_user_id: int,
|
|
||||||
sample_moderator_id: int
|
|
||||||
) -> ModerationLog:
|
) -> ModerationLog:
|
||||||
"""Create a test moderation log entry."""
|
"""Create a test moderation log entry."""
|
||||||
mod_log = ModerationLog(
|
mod_log = ModerationLog(
|
||||||
@@ -211,10 +217,7 @@ async def test_moderation_log(
|
|||||||
|
|
||||||
@pytest.fixture
|
@pytest.fixture
|
||||||
async def test_strike(
|
async def test_strike(
|
||||||
db_session: AsyncSession,
|
db_session: AsyncSession, test_guild: Guild, sample_user_id: int, sample_moderator_id: int
|
||||||
test_guild: Guild,
|
|
||||||
sample_user_id: int,
|
|
||||||
sample_moderator_id: int
|
|
||||||
) -> Strike:
|
) -> Strike:
|
||||||
"""Create a test strike."""
|
"""Create a test strike."""
|
||||||
strike = Strike(
|
strike = Strike(
|
||||||
@@ -236,6 +239,7 @@ async def test_strike(
|
|||||||
# Discord Mock Fixtures
|
# Discord Mock Fixtures
|
||||||
# ==============================================================================
|
# ==============================================================================
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture
|
@pytest.fixture
|
||||||
def mock_discord_user(sample_user_id: int) -> MagicMock:
|
def mock_discord_user(sample_user_id: int) -> MagicMock:
|
||||||
"""Create a mock Discord user."""
|
"""Create a mock Discord user."""
|
||||||
@@ -261,7 +265,7 @@ def mock_discord_member(mock_discord_user: MagicMock) -> MagicMock:
|
|||||||
member.avatar = mock_discord_user.avatar
|
member.avatar = mock_discord_user.avatar
|
||||||
member.bot = mock_discord_user.bot
|
member.bot = mock_discord_user.bot
|
||||||
member.send = mock_discord_user.send
|
member.send = mock_discord_user.send
|
||||||
|
|
||||||
# Member-specific attributes
|
# Member-specific attributes
|
||||||
member.guild = MagicMock()
|
member.guild = MagicMock()
|
||||||
member.top_role = MagicMock()
|
member.top_role = MagicMock()
|
||||||
@@ -271,7 +275,7 @@ def mock_discord_member(mock_discord_user: MagicMock) -> MagicMock:
|
|||||||
member.kick = AsyncMock()
|
member.kick = AsyncMock()
|
||||||
member.ban = AsyncMock()
|
member.ban = AsyncMock()
|
||||||
member.timeout = AsyncMock()
|
member.timeout = AsyncMock()
|
||||||
|
|
||||||
return member
|
return member
|
||||||
|
|
||||||
|
|
||||||
@@ -284,14 +288,14 @@ def mock_discord_guild(sample_guild_id: int, sample_owner_id: int) -> MagicMock:
|
|||||||
guild.owner_id = sample_owner_id
|
guild.owner_id = sample_owner_id
|
||||||
guild.member_count = 100
|
guild.member_count = 100
|
||||||
guild.premium_tier = 0
|
guild.premium_tier = 0
|
||||||
|
|
||||||
# Methods
|
# Methods
|
||||||
guild.get_member = MagicMock(return_value=None)
|
guild.get_member = MagicMock(return_value=None)
|
||||||
guild.get_channel = MagicMock(return_value=None)
|
guild.get_channel = MagicMock(return_value=None)
|
||||||
guild.leave = AsyncMock()
|
guild.leave = AsyncMock()
|
||||||
guild.ban = AsyncMock()
|
guild.ban = AsyncMock()
|
||||||
guild.unban = AsyncMock()
|
guild.unban = AsyncMock()
|
||||||
|
|
||||||
return guild
|
return guild
|
||||||
|
|
||||||
|
|
||||||
@@ -327,9 +331,7 @@ def mock_discord_message(
|
|||||||
|
|
||||||
@pytest.fixture
|
@pytest.fixture
|
||||||
def mock_discord_context(
|
def mock_discord_context(
|
||||||
mock_discord_member: MagicMock,
|
mock_discord_member: MagicMock, mock_discord_guild: MagicMock, mock_discord_channel: MagicMock
|
||||||
mock_discord_guild: MagicMock,
|
|
||||||
mock_discord_channel: MagicMock
|
|
||||||
) -> MagicMock:
|
) -> MagicMock:
|
||||||
"""Create a mock Discord command context."""
|
"""Create a mock Discord command context."""
|
||||||
ctx = MagicMock()
|
ctx = MagicMock()
|
||||||
@@ -345,6 +347,7 @@ def mock_discord_context(
|
|||||||
# Bot and Service Fixtures
|
# Bot and Service Fixtures
|
||||||
# ==============================================================================
|
# ==============================================================================
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture
|
@pytest.fixture
|
||||||
def mock_bot(test_database: Database) -> MagicMock:
|
def mock_bot(test_database: Database) -> MagicMock:
|
||||||
"""Create a mock GuardDen bot."""
|
"""Create a mock GuardDen bot."""
|
||||||
@@ -363,6 +366,7 @@ def mock_bot(test_database: Database) -> MagicMock:
|
|||||||
# Test Environment Setup
|
# Test Environment Setup
|
||||||
# ==============================================================================
|
# ==============================================================================
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture(autouse=True)
|
@pytest.fixture(autouse=True)
|
||||||
def setup_test_environment() -> None:
|
def setup_test_environment() -> None:
|
||||||
"""Set up test environment variables."""
|
"""Set up test environment variables."""
|
||||||
|
|||||||
@@ -3,7 +3,8 @@
|
|||||||
import pytest
|
import pytest
|
||||||
from pydantic import ValidationError
|
from pydantic import ValidationError
|
||||||
|
|
||||||
from guardden.config import Settings, _parse_id_list, _validate_discord_id, normalize_domain
|
from guardden.config import Settings, _parse_id_list, _validate_discord_id
|
||||||
|
from guardden.services.automod import normalize_domain
|
||||||
|
|
||||||
|
|
||||||
class TestDiscordIdValidation:
|
class TestDiscordIdValidation:
|
||||||
@@ -17,7 +18,7 @@ class TestDiscordIdValidation:
|
|||||||
"1234567890123456789", # 19 digits
|
"1234567890123456789", # 19 digits
|
||||||
123456789012345678, # int format
|
123456789012345678, # int format
|
||||||
]
|
]
|
||||||
|
|
||||||
for valid_id in valid_ids:
|
for valid_id in valid_ids:
|
||||||
result = _validate_discord_id(valid_id)
|
result = _validate_discord_id(valid_id)
|
||||||
assert isinstance(result, int)
|
assert isinstance(result, int)
|
||||||
@@ -35,7 +36,7 @@ class TestDiscordIdValidation:
|
|||||||
"0", # zero
|
"0", # zero
|
||||||
"-123456789012345678", # negative
|
"-123456789012345678", # negative
|
||||||
]
|
]
|
||||||
|
|
||||||
for invalid_id in invalid_ids:
|
for invalid_id in invalid_ids:
|
||||||
with pytest.raises(ValueError):
|
with pytest.raises(ValueError):
|
||||||
_validate_discord_id(invalid_id)
|
_validate_discord_id(invalid_id)
|
||||||
@@ -45,7 +46,7 @@ class TestDiscordIdValidation:
|
|||||||
# Too small (before Discord existed)
|
# Too small (before Discord existed)
|
||||||
with pytest.raises(ValueError):
|
with pytest.raises(ValueError):
|
||||||
_validate_discord_id("99999999999999999")
|
_validate_discord_id("99999999999999999")
|
||||||
|
|
||||||
# Too large (exceeds 64-bit limit)
|
# Too large (exceeds 64-bit limit)
|
||||||
with pytest.raises(ValueError):
|
with pytest.raises(ValueError):
|
||||||
_validate_discord_id("99999999999999999999")
|
_validate_discord_id("99999999999999999999")
|
||||||
@@ -64,7 +65,7 @@ class TestIdListParsing:
|
|||||||
("", []),
|
("", []),
|
||||||
(None, []),
|
(None, []),
|
||||||
]
|
]
|
||||||
|
|
||||||
for input_value, expected in test_cases:
|
for input_value, expected in test_cases:
|
||||||
result = _parse_id_list(input_value)
|
result = _parse_id_list(input_value)
|
||||||
assert result == expected
|
assert result == expected
|
||||||
@@ -89,7 +90,7 @@ class TestIdListParsing:
|
|||||||
"123456789012345678\n234567890123456789", # newline
|
"123456789012345678\n234567890123456789", # newline
|
||||||
"123456789012345678\r234567890123456789", # carriage return
|
"123456789012345678\r234567890123456789", # carriage return
|
||||||
]
|
]
|
||||||
|
|
||||||
for malicious_input in malicious_inputs:
|
for malicious_input in malicious_inputs:
|
||||||
result = _parse_id_list(malicious_input)
|
result = _parse_id_list(malicious_input)
|
||||||
# Should filter out malicious entries
|
# Should filter out malicious entries
|
||||||
@@ -106,7 +107,7 @@ class TestSettingsValidation:
|
|||||||
"Bot.MTIzNDU2Nzg5MDEyMzQ1Njc4.some_long_token_string_here",
|
"Bot.MTIzNDU2Nzg5MDEyMzQ1Njc4.some_long_token_string_here",
|
||||||
"a" * 60, # minimum reasonable length
|
"a" * 60, # minimum reasonable length
|
||||||
]
|
]
|
||||||
|
|
||||||
for token in valid_tokens:
|
for token in valid_tokens:
|
||||||
settings = Settings(discord_token=token)
|
settings = Settings(discord_token=token)
|
||||||
assert settings.discord_token.get_secret_value() == token
|
assert settings.discord_token.get_secret_value() == token
|
||||||
@@ -119,7 +120,7 @@ class TestSettingsValidation:
|
|||||||
"token with spaces", # contains spaces
|
"token with spaces", # contains spaces
|
||||||
"token\nwith\nnewlines", # contains newlines
|
"token\nwith\nnewlines", # contains newlines
|
||||||
]
|
]
|
||||||
|
|
||||||
for token in invalid_tokens:
|
for token in invalid_tokens:
|
||||||
with pytest.raises(ValidationError):
|
with pytest.raises(ValidationError):
|
||||||
Settings(discord_token=token)
|
Settings(discord_token=token)
|
||||||
@@ -131,7 +132,7 @@ class TestSettingsValidation:
|
|||||||
settings = Settings(
|
settings = Settings(
|
||||||
discord_token="valid_token_" + "a" * 50,
|
discord_token="valid_token_" + "a" * 50,
|
||||||
ai_provider="anthropic",
|
ai_provider="anthropic",
|
||||||
anthropic_api_key=valid_key
|
anthropic_api_key=valid_key,
|
||||||
)
|
)
|
||||||
assert settings.anthropic_api_key.get_secret_value() == valid_key
|
assert settings.anthropic_api_key.get_secret_value() == valid_key
|
||||||
|
|
||||||
@@ -140,23 +141,23 @@ class TestSettingsValidation:
|
|||||||
Settings(
|
Settings(
|
||||||
discord_token="valid_token_" + "a" * 50,
|
discord_token="valid_token_" + "a" * 50,
|
||||||
ai_provider="anthropic",
|
ai_provider="anthropic",
|
||||||
anthropic_api_key="short"
|
anthropic_api_key="short",
|
||||||
)
|
)
|
||||||
|
|
||||||
def test_configuration_validation_ai_provider(self):
|
def test_configuration_validation_ai_provider(self):
|
||||||
"""Test AI provider configuration validation."""
|
"""Test AI provider configuration validation."""
|
||||||
settings = Settings(discord_token="valid_token_" + "a" * 50)
|
settings = Settings(discord_token="valid_token_" + "a" * 50)
|
||||||
|
|
||||||
# Should pass with no AI provider
|
# Should pass with no AI provider
|
||||||
settings.ai_provider = "none"
|
settings.ai_provider = "none"
|
||||||
settings.validate_configuration()
|
settings.validate_configuration()
|
||||||
|
|
||||||
# Should fail with anthropic but no key
|
# Should fail with anthropic but no key
|
||||||
settings.ai_provider = "anthropic"
|
settings.ai_provider = "anthropic"
|
||||||
settings.anthropic_api_key = None
|
settings.anthropic_api_key = None
|
||||||
with pytest.raises(ValueError, match="GUARDDEN_ANTHROPIC_API_KEY is required"):
|
with pytest.raises(ValueError, match="GUARDDEN_ANTHROPIC_API_KEY is required"):
|
||||||
settings.validate_configuration()
|
settings.validate_configuration()
|
||||||
|
|
||||||
# Should pass with anthropic and key
|
# Should pass with anthropic and key
|
||||||
settings.anthropic_api_key = "sk-" + "a" * 50
|
settings.anthropic_api_key = "sk-" + "a" * 50
|
||||||
settings.validate_configuration()
|
settings.validate_configuration()
|
||||||
@@ -164,13 +165,13 @@ class TestSettingsValidation:
|
|||||||
def test_configuration_validation_database_pool(self):
|
def test_configuration_validation_database_pool(self):
|
||||||
"""Test database pool configuration validation."""
|
"""Test database pool configuration validation."""
|
||||||
settings = Settings(discord_token="valid_token_" + "a" * 50)
|
settings = Settings(discord_token="valid_token_" + "a" * 50)
|
||||||
|
|
||||||
# Should fail with min > max
|
# Should fail with min > max
|
||||||
settings.database_pool_min = 10
|
settings.database_pool_min = 10
|
||||||
settings.database_pool_max = 5
|
settings.database_pool_max = 5
|
||||||
with pytest.raises(ValueError, match="database_pool_min cannot be greater"):
|
with pytest.raises(ValueError, match="database_pool_min cannot be greater"):
|
||||||
settings.validate_configuration()
|
settings.validate_configuration()
|
||||||
|
|
||||||
# Should fail with min < 1
|
# Should fail with min < 1
|
||||||
settings.database_pool_min = 0
|
settings.database_pool_min = 0
|
||||||
settings.database_pool_max = 5
|
settings.database_pool_max = 5
|
||||||
@@ -190,7 +191,7 @@ class TestSecurityImprovements:
|
|||||||
"123456789012345678\x00\x01\x02",
|
"123456789012345678\x00\x01\x02",
|
||||||
"123456789012345678<script>alert('xss')</script>",
|
"123456789012345678<script>alert('xss')</script>",
|
||||||
]
|
]
|
||||||
|
|
||||||
for attempt in injection_attempts:
|
for attempt in injection_attempts:
|
||||||
# Should either raise an error or filter out the malicious input
|
# Should either raise an error or filter out the malicious input
|
||||||
try:
|
try:
|
||||||
@@ -205,33 +206,41 @@ class TestSecurityImprovements:
|
|||||||
def test_settings_with_malicious_env_vars(self):
|
def test_settings_with_malicious_env_vars(self):
|
||||||
"""Test that settings handle malicious environment variables."""
|
"""Test that settings handle malicious environment variables."""
|
||||||
import os
|
import os
|
||||||
|
|
||||||
# Save original values
|
# Save original values
|
||||||
original_guilds = os.environ.get("GUARDDEN_ALLOWED_GUILDS")
|
original_guilds = os.environ.get("GUARDDEN_ALLOWED_GUILDS")
|
||||||
original_owners = os.environ.get("GUARDDEN_OWNER_IDS")
|
original_owners = os.environ.get("GUARDDEN_OWNER_IDS")
|
||||||
|
|
||||||
try:
|
try:
|
||||||
# Set malicious environment variables
|
# Set malicious environment variables
|
||||||
os.environ["GUARDDEN_ALLOWED_GUILDS"] = "123456789012345678\x00,malicious"
|
try:
|
||||||
os.environ["GUARDDEN_OWNER_IDS"] = "234567890123456789\n567890123456789012"
|
os.environ["GUARDDEN_ALLOWED_GUILDS"] = "123456789012345678\x00,malicious"
|
||||||
|
except ValueError:
|
||||||
|
os.environ["GUARDDEN_ALLOWED_GUILDS"] = "123456789012345678,malicious"
|
||||||
|
try:
|
||||||
|
os.environ["GUARDDEN_OWNER_IDS"] = "234567890123456789\n567890123456789012"
|
||||||
|
except ValueError:
|
||||||
|
os.environ["GUARDDEN_OWNER_IDS"] = "234567890123456789,567890123456789012"
|
||||||
|
|
||||||
settings = Settings(discord_token="valid_token_" + "a" * 50)
|
settings = Settings(discord_token="valid_token_" + "a" * 50)
|
||||||
|
|
||||||
# Should filter out malicious entries
|
# Should filter out malicious entries
|
||||||
assert len(settings.allowed_guilds) <= 1
|
assert len(settings.allowed_guilds) <= 1
|
||||||
assert len(settings.owner_ids) <= 1
|
assert len(settings.owner_ids) <= 1
|
||||||
|
|
||||||
# Valid IDs should be preserved
|
# Valid IDs should be preserved
|
||||||
assert 123456789012345678 in settings.allowed_guilds or len(settings.allowed_guilds) == 0
|
assert (
|
||||||
|
123456789012345678 in settings.allowed_guilds or len(settings.allowed_guilds) == 0
|
||||||
|
)
|
||||||
|
|
||||||
finally:
|
finally:
|
||||||
# Restore original values
|
# Restore original values
|
||||||
if original_guilds is not None:
|
if original_guilds is not None:
|
||||||
os.environ["GUARDDEN_ALLOWED_GUILDS"] = original_guilds
|
os.environ["GUARDDEN_ALLOWED_GUILDS"] = original_guilds
|
||||||
else:
|
else:
|
||||||
os.environ.pop("GUARDDEN_ALLOWED_GUILDS", None)
|
os.environ.pop("GUARDDEN_ALLOWED_GUILDS", None)
|
||||||
|
|
||||||
if original_owners is not None:
|
if original_owners is not None:
|
||||||
os.environ["GUARDDEN_OWNER_IDS"] = original_owners
|
os.environ["GUARDDEN_OWNER_IDS"] = original_owners
|
||||||
else:
|
else:
|
||||||
os.environ.pop("GUARDDEN_OWNER_IDS", None)
|
os.environ.pop("GUARDDEN_OWNER_IDS", None)
|
||||||
|
|||||||
@@ -1,10 +1,11 @@
|
|||||||
"""Tests for database integration and models."""
|
"""Tests for database integration and models."""
|
||||||
|
|
||||||
import pytest
|
|
||||||
from datetime import datetime, timezone
|
from datetime import datetime, timezone
|
||||||
|
|
||||||
|
import pytest
|
||||||
from sqlalchemy import select
|
from sqlalchemy import select
|
||||||
|
|
||||||
from guardden.models.guild import Guild, GuildSettings, BannedWord
|
from guardden.models.guild import BannedWord, Guild, GuildSettings
|
||||||
from guardden.models.moderation import ModerationLog, Strike, UserNote
|
from guardden.models.moderation import ModerationLog, Strike, UserNote
|
||||||
from guardden.services.database import Database
|
from guardden.services.database import Database
|
||||||
|
|
||||||
@@ -21,7 +22,7 @@ class TestDatabaseModels:
|
|||||||
premium=False,
|
premium=False,
|
||||||
)
|
)
|
||||||
db_session.add(guild)
|
db_session.add(guild)
|
||||||
|
|
||||||
settings = GuildSettings(
|
settings = GuildSettings(
|
||||||
guild_id=sample_guild_id,
|
guild_id=sample_guild_id,
|
||||||
prefix="!",
|
prefix="!",
|
||||||
@@ -29,13 +30,13 @@ class TestDatabaseModels:
|
|||||||
ai_moderation_enabled=False,
|
ai_moderation_enabled=False,
|
||||||
)
|
)
|
||||||
db_session.add(settings)
|
db_session.add(settings)
|
||||||
|
|
||||||
await db_session.commit()
|
await db_session.commit()
|
||||||
|
|
||||||
# Test guild was created
|
# Test guild was created
|
||||||
result = await db_session.execute(select(Guild).where(Guild.id == sample_guild_id))
|
result = await db_session.execute(select(Guild).where(Guild.id == sample_guild_id))
|
||||||
created_guild = result.scalar_one()
|
created_guild = result.scalar_one()
|
||||||
|
|
||||||
assert created_guild.id == sample_guild_id
|
assert created_guild.id == sample_guild_id
|
||||||
assert created_guild.name == "Test Guild"
|
assert created_guild.name == "Test Guild"
|
||||||
assert created_guild.owner_id == sample_owner_id
|
assert created_guild.owner_id == sample_owner_id
|
||||||
@@ -44,11 +45,9 @@ class TestDatabaseModels:
|
|||||||
async def test_guild_settings_relationship(self, test_guild, db_session):
|
async def test_guild_settings_relationship(self, test_guild, db_session):
|
||||||
"""Test guild-settings relationship."""
|
"""Test guild-settings relationship."""
|
||||||
# Load guild with settings
|
# Load guild with settings
|
||||||
result = await db_session.execute(
|
result = await db_session.execute(select(Guild).where(Guild.id == test_guild.id))
|
||||||
select(Guild).where(Guild.id == test_guild.id)
|
|
||||||
)
|
|
||||||
guild_with_settings = result.scalar_one()
|
guild_with_settings = result.scalar_one()
|
||||||
|
|
||||||
# Test relationship loading
|
# Test relationship loading
|
||||||
await db_session.refresh(guild_with_settings, ["settings"])
|
await db_session.refresh(guild_with_settings, ["settings"])
|
||||||
assert guild_with_settings.settings is not None
|
assert guild_with_settings.settings is not None
|
||||||
@@ -67,24 +66,20 @@ class TestDatabaseModels:
|
|||||||
)
|
)
|
||||||
db_session.add(banned_word)
|
db_session.add(banned_word)
|
||||||
await db_session.commit()
|
await db_session.commit()
|
||||||
|
|
||||||
# Verify creation
|
# Verify creation
|
||||||
result = await db_session.execute(
|
result = await db_session.execute(
|
||||||
select(BannedWord).where(BannedWord.guild_id == test_guild.id)
|
select(BannedWord).where(BannedWord.guild_id == test_guild.id)
|
||||||
)
|
)
|
||||||
created_word = result.scalar_one()
|
created_word = result.scalar_one()
|
||||||
|
|
||||||
assert created_word.pattern == "testbadword"
|
assert created_word.pattern == "testbadword"
|
||||||
assert not created_word.is_regex
|
assert not created_word.is_regex
|
||||||
assert created_word.action == "delete"
|
assert created_word.action == "delete"
|
||||||
assert created_word.added_by == sample_moderator_id
|
assert created_word.added_by == sample_moderator_id
|
||||||
|
|
||||||
async def test_moderation_log_creation(
|
async def test_moderation_log_creation(
|
||||||
self,
|
self, test_guild, db_session, sample_user_id, sample_moderator_id
|
||||||
test_guild,
|
|
||||||
db_session,
|
|
||||||
sample_user_id,
|
|
||||||
sample_moderator_id
|
|
||||||
):
|
):
|
||||||
"""Test moderation log creation."""
|
"""Test moderation log creation."""
|
||||||
mod_log = ModerationLog(
|
mod_log = ModerationLog(
|
||||||
@@ -99,24 +94,20 @@ class TestDatabaseModels:
|
|||||||
)
|
)
|
||||||
db_session.add(mod_log)
|
db_session.add(mod_log)
|
||||||
await db_session.commit()
|
await db_session.commit()
|
||||||
|
|
||||||
# Verify creation
|
# Verify creation
|
||||||
result = await db_session.execute(
|
result = await db_session.execute(
|
||||||
select(ModerationLog).where(ModerationLog.guild_id == test_guild.id)
|
select(ModerationLog).where(ModerationLog.guild_id == test_guild.id)
|
||||||
)
|
)
|
||||||
created_log = result.scalar_one()
|
created_log = result.scalar_one()
|
||||||
|
|
||||||
assert created_log.action == "ban"
|
assert created_log.action == "ban"
|
||||||
assert created_log.target_id == sample_user_id
|
assert created_log.target_id == sample_user_id
|
||||||
assert created_log.moderator_id == sample_moderator_id
|
assert created_log.moderator_id == sample_moderator_id
|
||||||
assert not created_log.is_automatic
|
assert not created_log.is_automatic
|
||||||
|
|
||||||
async def test_strike_creation(
|
async def test_strike_creation(
|
||||||
self,
|
self, test_guild, db_session, sample_user_id, sample_moderator_id
|
||||||
test_guild,
|
|
||||||
db_session,
|
|
||||||
sample_user_id,
|
|
||||||
sample_moderator_id
|
|
||||||
):
|
):
|
||||||
"""Test strike creation and tracking."""
|
"""Test strike creation and tracking."""
|
||||||
strike = Strike(
|
strike = Strike(
|
||||||
@@ -130,26 +121,19 @@ class TestDatabaseModels:
|
|||||||
)
|
)
|
||||||
db_session.add(strike)
|
db_session.add(strike)
|
||||||
await db_session.commit()
|
await db_session.commit()
|
||||||
|
|
||||||
# Verify creation
|
# Verify creation
|
||||||
result = await db_session.execute(
|
result = await db_session.execute(
|
||||||
select(Strike).where(
|
select(Strike).where(Strike.guild_id == test_guild.id, Strike.user_id == sample_user_id)
|
||||||
Strike.guild_id == test_guild.id,
|
|
||||||
Strike.user_id == sample_user_id
|
|
||||||
)
|
|
||||||
)
|
)
|
||||||
created_strike = result.scalar_one()
|
created_strike = result.scalar_one()
|
||||||
|
|
||||||
assert created_strike.points == 1
|
assert created_strike.points == 1
|
||||||
assert created_strike.is_active
|
assert created_strike.is_active
|
||||||
assert created_strike.user_id == sample_user_id
|
assert created_strike.user_id == sample_user_id
|
||||||
|
|
||||||
async def test_cascade_deletion(
|
async def test_cascade_deletion(
|
||||||
self,
|
self, test_guild, db_session, sample_user_id, sample_moderator_id
|
||||||
test_guild,
|
|
||||||
db_session,
|
|
||||||
sample_user_id,
|
|
||||||
sample_moderator_id
|
|
||||||
):
|
):
|
||||||
"""Test that deleting a guild cascades to related records."""
|
"""Test that deleting a guild cascades to related records."""
|
||||||
# Add some related records
|
# Add some related records
|
||||||
@@ -160,7 +144,7 @@ class TestDatabaseModels:
|
|||||||
action="delete",
|
action="delete",
|
||||||
added_by=sample_moderator_id,
|
added_by=sample_moderator_id,
|
||||||
)
|
)
|
||||||
|
|
||||||
mod_log = ModerationLog(
|
mod_log = ModerationLog(
|
||||||
guild_id=test_guild.id,
|
guild_id=test_guild.id,
|
||||||
target_id=sample_user_id,
|
target_id=sample_user_id,
|
||||||
@@ -171,7 +155,7 @@ class TestDatabaseModels:
|
|||||||
reason="Test warning",
|
reason="Test warning",
|
||||||
is_automatic=False,
|
is_automatic=False,
|
||||||
)
|
)
|
||||||
|
|
||||||
strike = Strike(
|
strike = Strike(
|
||||||
guild_id=test_guild.id,
|
guild_id=test_guild.id,
|
||||||
user_id=sample_user_id,
|
user_id=sample_user_id,
|
||||||
@@ -181,28 +165,26 @@ class TestDatabaseModels:
|
|||||||
points=1,
|
points=1,
|
||||||
is_active=True,
|
is_active=True,
|
||||||
)
|
)
|
||||||
|
|
||||||
db_session.add_all([banned_word, mod_log, strike])
|
db_session.add_all([banned_word, mod_log, strike])
|
||||||
await db_session.commit()
|
await db_session.commit()
|
||||||
|
|
||||||
# Delete the guild
|
# Delete the guild
|
||||||
await db_session.delete(test_guild)
|
await db_session.delete(test_guild)
|
||||||
await db_session.commit()
|
await db_session.commit()
|
||||||
|
|
||||||
# Verify related records were deleted
|
# Verify related records were deleted
|
||||||
banned_words = await db_session.execute(
|
banned_words = await db_session.execute(
|
||||||
select(BannedWord).where(BannedWord.guild_id == test_guild.id)
|
select(BannedWord).where(BannedWord.guild_id == test_guild.id)
|
||||||
)
|
)
|
||||||
assert len(banned_words.scalars().all()) == 0
|
assert len(banned_words.scalars().all()) == 0
|
||||||
|
|
||||||
mod_logs = await db_session.execute(
|
mod_logs = await db_session.execute(
|
||||||
select(ModerationLog).where(ModerationLog.guild_id == test_guild.id)
|
select(ModerationLog).where(ModerationLog.guild_id == test_guild.id)
|
||||||
)
|
)
|
||||||
assert len(mod_logs.scalars().all()) == 0
|
assert len(mod_logs.scalars().all()) == 0
|
||||||
|
|
||||||
strikes = await db_session.execute(
|
strikes = await db_session.execute(select(Strike).where(Strike.guild_id == test_guild.id))
|
||||||
select(Strike).where(Strike.guild_id == test_guild.id)
|
|
||||||
)
|
|
||||||
assert len(strikes.scalars().all()) == 0
|
assert len(strikes.scalars().all()) == 0
|
||||||
|
|
||||||
|
|
||||||
@@ -210,11 +192,7 @@ class TestDatabaseIndexes:
|
|||||||
"""Test that database indexes work as expected."""
|
"""Test that database indexes work as expected."""
|
||||||
|
|
||||||
async def test_moderation_log_indexes(
|
async def test_moderation_log_indexes(
|
||||||
self,
|
self, test_guild, db_session, sample_user_id, sample_moderator_id
|
||||||
test_guild,
|
|
||||||
db_session,
|
|
||||||
sample_user_id,
|
|
||||||
sample_moderator_id
|
|
||||||
):
|
):
|
||||||
"""Test moderation log indexing for performance."""
|
"""Test moderation log indexing for performance."""
|
||||||
# Create multiple moderation logs
|
# Create multiple moderation logs
|
||||||
@@ -231,23 +209,23 @@ class TestDatabaseIndexes:
|
|||||||
is_automatic=bool(i % 2),
|
is_automatic=bool(i % 2),
|
||||||
)
|
)
|
||||||
logs.append(log)
|
logs.append(log)
|
||||||
|
|
||||||
db_session.add_all(logs)
|
db_session.add_all(logs)
|
||||||
await db_session.commit()
|
await db_session.commit()
|
||||||
|
|
||||||
# Test queries that should use indexes
|
# Test queries that should use indexes
|
||||||
# Query by guild_id
|
# Query by guild_id
|
||||||
guild_logs = await db_session.execute(
|
guild_logs = await db_session.execute(
|
||||||
select(ModerationLog).where(ModerationLog.guild_id == test_guild.id)
|
select(ModerationLog).where(ModerationLog.guild_id == test_guild.id)
|
||||||
)
|
)
|
||||||
assert len(guild_logs.scalars().all()) == 10
|
assert len(guild_logs.scalars().all()) == 10
|
||||||
|
|
||||||
# Query by target_id
|
# Query by target_id
|
||||||
target_logs = await db_session.execute(
|
target_logs = await db_session.execute(
|
||||||
select(ModerationLog).where(ModerationLog.target_id == sample_user_id)
|
select(ModerationLog).where(ModerationLog.target_id == sample_user_id)
|
||||||
)
|
)
|
||||||
assert len(target_logs.scalars().all()) == 1
|
assert len(target_logs.scalars().all()) == 1
|
||||||
|
|
||||||
# Query by is_automatic
|
# Query by is_automatic
|
||||||
auto_logs = await db_session.execute(
|
auto_logs = await db_session.execute(
|
||||||
select(ModerationLog).where(ModerationLog.is_automatic == True)
|
select(ModerationLog).where(ModerationLog.is_automatic == True)
|
||||||
@@ -255,11 +233,7 @@ class TestDatabaseIndexes:
|
|||||||
assert len(auto_logs.scalars().all()) == 5
|
assert len(auto_logs.scalars().all()) == 5
|
||||||
|
|
||||||
async def test_strike_indexes(
|
async def test_strike_indexes(
|
||||||
self,
|
self, test_guild, db_session, sample_user_id, sample_moderator_id
|
||||||
test_guild,
|
|
||||||
db_session,
|
|
||||||
sample_user_id,
|
|
||||||
sample_moderator_id
|
|
||||||
):
|
):
|
||||||
"""Test strike indexing for performance."""
|
"""Test strike indexing for performance."""
|
||||||
# Create multiple strikes
|
# Create multiple strikes
|
||||||
@@ -275,18 +249,15 @@ class TestDatabaseIndexes:
|
|||||||
is_active=bool(i % 2),
|
is_active=bool(i % 2),
|
||||||
)
|
)
|
||||||
strikes.append(strike)
|
strikes.append(strike)
|
||||||
|
|
||||||
db_session.add_all(strikes)
|
db_session.add_all(strikes)
|
||||||
await db_session.commit()
|
await db_session.commit()
|
||||||
|
|
||||||
# Test active strikes query
|
# Test active strikes query
|
||||||
active_strikes = await db_session.execute(
|
active_strikes = await db_session.execute(
|
||||||
select(Strike).where(
|
select(Strike).where(Strike.guild_id == test_guild.id, Strike.is_active == True)
|
||||||
Strike.guild_id == test_guild.id,
|
|
||||||
Strike.is_active == True
|
|
||||||
)
|
|
||||||
)
|
)
|
||||||
assert len(active_strikes.scalars().all()) == 3 # indices 1, 3
|
assert len(active_strikes.scalars().all()) == 2 # indices 1, 3
|
||||||
|
|
||||||
|
|
||||||
class TestDatabaseSecurity:
|
class TestDatabaseSecurity:
|
||||||
@@ -304,11 +275,9 @@ class TestDatabaseSecurity:
|
|||||||
)
|
)
|
||||||
db_session.add(guild)
|
db_session.add(guild)
|
||||||
await db_session.commit()
|
await db_session.commit()
|
||||||
|
|
||||||
# Verify it was stored correctly
|
# Verify it was stored correctly
|
||||||
result = await db_session.execute(
|
result = await db_session.execute(select(Guild).where(Guild.id == valid_guild_id))
|
||||||
select(Guild).where(Guild.id == valid_guild_id)
|
|
||||||
)
|
|
||||||
stored_guild = result.scalar_one()
|
stored_guild = result.scalar_one()
|
||||||
assert stored_guild.id == valid_guild_id
|
assert stored_guild.id == valid_guild_id
|
||||||
|
|
||||||
@@ -321,13 +290,11 @@ class TestDatabaseSecurity:
|
|||||||
"' OR '1'='1",
|
"' OR '1'='1",
|
||||||
"<script>alert('xss')</script>",
|
"<script>alert('xss')</script>",
|
||||||
]
|
]
|
||||||
|
|
||||||
for malicious_input in malicious_inputs:
|
for malicious_input in malicious_inputs:
|
||||||
# Try to use malicious input in a query
|
# Try to use malicious input in a query
|
||||||
# SQLAlchemy should prevent injection through parameterized queries
|
# SQLAlchemy should prevent injection through parameterized queries
|
||||||
result = await db_session.execute(
|
result = await db_session.execute(select(Guild).where(Guild.name == malicious_input))
|
||||||
select(Guild).where(Guild.name == malicious_input)
|
|
||||||
)
|
|
||||||
# Should not find anything (and not crash)
|
# Should not find anything (and not crash)
|
||||||
assert result.scalar_one_or_none() is None
|
assert result.scalar_one_or_none() is None
|
||||||
|
|
||||||
@@ -343,4 +310,5 @@ class TestDatabaseSecurity:
|
|||||||
added_by=123456789012345678,
|
added_by=123456789012345678,
|
||||||
)
|
)
|
||||||
db_session.add(banned_word)
|
db_session.add(banned_word)
|
||||||
await db_session.commit()
|
await db_session.commit()
|
||||||
|
await db_session.rollback()
|
||||||
|
|||||||
Reference in New Issue
Block a user