Implement GuardDen Discord moderation bot
Features: - Core moderation: warn, kick, ban, timeout, strike system - Automod: banned words filter, scam detection, anti-spam, link filtering - AI moderation: Claude/OpenAI integration, NSFW detection, phishing analysis - Verification system: button, captcha, math, emoji challenges - Rate limiting system with configurable scopes - Event logging: joins, leaves, message edits/deletes, voice activity - Per-guild configuration with caching - Docker deployment support Bug fixes applied: - Fixed await on session.delete() in guild_config.py - Fixed memory leak in AI moderation message tracking (use deque) - Added error handling to bot shutdown - Added error handling to timeout command - Removed unused Literal import - Added prefix validation - Added image analysis limit (3 per message) - Fixed test mock for SQLAlchemy model
This commit is contained in:
21
.env.example
Normal file
21
.env.example
Normal file
@@ -0,0 +1,21 @@
|
|||||||
|
# Discord Bot Configuration
|
||||||
|
GUARDDEN_DISCORD_TOKEN=your_discord_bot_token_here
|
||||||
|
GUARDDEN_DISCORD_PREFIX=!
|
||||||
|
|
||||||
|
# Database Configuration (for local development without Docker)
|
||||||
|
GUARDDEN_DATABASE_URL=postgresql://guardden:guardden@localhost:5432/guardden
|
||||||
|
|
||||||
|
# Logging
|
||||||
|
GUARDDEN_LOG_LEVEL=INFO
|
||||||
|
|
||||||
|
# AI Configuration (optional)
|
||||||
|
# Options: none, anthropic, openai
|
||||||
|
GUARDDEN_AI_PROVIDER=none
|
||||||
|
|
||||||
|
# Anthropic API key (required if AI_PROVIDER=anthropic)
|
||||||
|
# Get your key at: https://console.anthropic.com/
|
||||||
|
GUARDDEN_ANTHROPIC_API_KEY=
|
||||||
|
|
||||||
|
# OpenAI API key (required if AI_PROVIDER=openai)
|
||||||
|
# Get your key at: https://platform.openai.com/api-keys
|
||||||
|
GUARDDEN_OPENAI_API_KEY=
|
||||||
56
.gitignore
vendored
Normal file
56
.gitignore
vendored
Normal file
@@ -0,0 +1,56 @@
|
|||||||
|
# Python
|
||||||
|
__pycache__/
|
||||||
|
*.py[cod]
|
||||||
|
*$py.class
|
||||||
|
*.so
|
||||||
|
.Python
|
||||||
|
build/
|
||||||
|
develop-eggs/
|
||||||
|
dist/
|
||||||
|
downloads/
|
||||||
|
eggs/
|
||||||
|
.eggs/
|
||||||
|
lib/
|
||||||
|
lib64/
|
||||||
|
parts/
|
||||||
|
sdist/
|
||||||
|
var/
|
||||||
|
wheels/
|
||||||
|
*.egg-info/
|
||||||
|
.installed.cfg
|
||||||
|
*.egg
|
||||||
|
|
||||||
|
# Virtual environments
|
||||||
|
venv/
|
||||||
|
ENV/
|
||||||
|
env/
|
||||||
|
.venv/
|
||||||
|
|
||||||
|
# IDE
|
||||||
|
.idea/
|
||||||
|
.vscode/
|
||||||
|
*.swp
|
||||||
|
*.swo
|
||||||
|
*~
|
||||||
|
|
||||||
|
# Environment
|
||||||
|
.env
|
||||||
|
.env.local
|
||||||
|
|
||||||
|
# Data
|
||||||
|
data/
|
||||||
|
*.db
|
||||||
|
*.sqlite3
|
||||||
|
|
||||||
|
# Logs
|
||||||
|
*.log
|
||||||
|
logs/
|
||||||
|
|
||||||
|
# Testing
|
||||||
|
.coverage
|
||||||
|
htmlcov/
|
||||||
|
.pytest_cache/
|
||||||
|
.mypy_cache/
|
||||||
|
|
||||||
|
# Docker
|
||||||
|
docker-compose.override.yml
|
||||||
103
CLAUDE.md
Normal file
103
CLAUDE.md
Normal file
@@ -0,0 +1,103 @@
|
|||||||
|
# CLAUDE.md
|
||||||
|
|
||||||
|
This file provides guidance to Claude Code (claude.ai/code) when working with code in this repository.
|
||||||
|
|
||||||
|
## Project Overview
|
||||||
|
|
||||||
|
GuardDen is a Discord moderation bot built with discord.py, PostgreSQL, and optional AI integration (Claude/OpenAI). Self-hosted with Docker support.
|
||||||
|
|
||||||
|
## Commands
|
||||||
|
|
||||||
|
```bash
|
||||||
|
# Install dependencies
|
||||||
|
pip install -e ".[dev,ai]"
|
||||||
|
|
||||||
|
# Run the bot
|
||||||
|
python -m guardden
|
||||||
|
|
||||||
|
# Run tests
|
||||||
|
pytest
|
||||||
|
|
||||||
|
# Run single test
|
||||||
|
pytest tests/test_verification.py::TestVerificationService::test_verify_correct
|
||||||
|
|
||||||
|
# Lint and format
|
||||||
|
ruff check src tests
|
||||||
|
ruff format src tests
|
||||||
|
|
||||||
|
# Type checking
|
||||||
|
mypy src
|
||||||
|
|
||||||
|
# Docker deployment
|
||||||
|
docker compose up -d
|
||||||
|
```
|
||||||
|
|
||||||
|
## Architecture
|
||||||
|
|
||||||
|
- `src/guardden/bot.py` - Main bot class (`GuardDen`) extending `commands.Bot`, manages lifecycle and services
|
||||||
|
- `src/guardden/config.py` - Pydantic settings loaded from environment variables (prefix: `GUARDDEN_`)
|
||||||
|
- `src/guardden/models/` - SQLAlchemy 2.0 async models for PostgreSQL
|
||||||
|
- `src/guardden/services/` - Business logic (database, guild config, automod, AI, verification, rate limiting)
|
||||||
|
- `src/guardden/cogs/` - Discord command groups (events, moderation, admin, automod, ai_moderation, verification)
|
||||||
|
|
||||||
|
## Key Patterns
|
||||||
|
|
||||||
|
- All database operations use async SQLAlchemy with `asyncpg`
|
||||||
|
- Guild configurations are cached in `GuildConfigService._cache`
|
||||||
|
- Discord snowflake IDs stored as `BigInteger` in PostgreSQL
|
||||||
|
- Moderation actions logged to `ModerationLog` table with automatic strike escalation
|
||||||
|
- Environment variables: `GUARDDEN_DISCORD_TOKEN`, `GUARDDEN_DATABASE_URL`
|
||||||
|
|
||||||
|
## Automod System
|
||||||
|
|
||||||
|
- `AutomodService` in `services/automod.py` handles rule-based content filtering
|
||||||
|
- Checks run in order: banned words → scam links → spam → invite links
|
||||||
|
- Spam tracking uses per-guild, per-user trackers with automatic cleanup
|
||||||
|
- Scam detection uses compiled regex patterns in `SCAM_PATTERNS` list
|
||||||
|
- Results return `AutomodResult` dataclass with actions to take
|
||||||
|
|
||||||
|
## AI Moderation System
|
||||||
|
|
||||||
|
- `services/ai/` contains provider abstraction and implementations
|
||||||
|
- `AIProvider` base class defines interface: `moderate_text()`, `analyze_image()`, `analyze_phishing()`
|
||||||
|
- `AnthropicProvider` and `OpenAIProvider` implement the interface
|
||||||
|
- `NullProvider` used when AI is disabled (returns empty results)
|
||||||
|
- Factory pattern via `create_ai_provider(provider, api_key)`
|
||||||
|
- `ModerationResult` includes severity scoring based on confidence + category weights
|
||||||
|
- Sensitivity setting (0-100) adjusts thresholds per guild
|
||||||
|
|
||||||
|
## Verification System
|
||||||
|
|
||||||
|
- `VerificationService` in `services/verification.py` manages challenges
|
||||||
|
- Challenge types: button, captcha, math, emoji (via `ChallengeGenerator` classes)
|
||||||
|
- `PendingVerification` tracks user challenges with expiry and attempt limits
|
||||||
|
- Discord UI components in `cogs/verification.py`: `VerifyButton`, `EmojiButton`, `CaptchaModal`
|
||||||
|
- Background task cleans up expired verifications every 5 minutes
|
||||||
|
|
||||||
|
## Rate Limiting System
|
||||||
|
|
||||||
|
- `RateLimiter` in `services/ratelimit.py` provides general-purpose rate limiting
|
||||||
|
- Scopes: USER (global), MEMBER (per-guild), CHANNEL, GUILD
|
||||||
|
- `@ratelimit()` decorator for easy command rate limiting
|
||||||
|
- `get_rate_limiter()` returns singleton instance
|
||||||
|
- Default limits configured for commands, moderation, verification, messages
|
||||||
|
|
||||||
|
## Adding New Cogs
|
||||||
|
|
||||||
|
1. Create file in `src/guardden/cogs/`
|
||||||
|
2. Implement `setup(bot)` async function
|
||||||
|
3. Add cog path to `_load_cogs()` in `bot.py`
|
||||||
|
|
||||||
|
## Adding New AI Provider
|
||||||
|
|
||||||
|
1. Create `src/guardden/services/ai/newprovider.py`
|
||||||
|
2. Implement `AIProvider` abstract class
|
||||||
|
3. Add to factory in `services/ai/factory.py`
|
||||||
|
4. Add config option in `config.py`
|
||||||
|
|
||||||
|
## Adding New Challenge Type
|
||||||
|
|
||||||
|
1. Create new `ChallengeGenerator` subclass in `services/verification.py`
|
||||||
|
2. Add to `ChallengeType` enum
|
||||||
|
3. Register in `VerificationService._generators`
|
||||||
|
4. Create corresponding UI components in `cogs/verification.py` if needed
|
||||||
27
Dockerfile
Normal file
27
Dockerfile
Normal file
@@ -0,0 +1,27 @@
|
|||||||
|
FROM python:3.11-slim
|
||||||
|
|
||||||
|
WORKDIR /app
|
||||||
|
|
||||||
|
# Install system dependencies
|
||||||
|
RUN apt-get update && apt-get install -y --no-install-recommends \
|
||||||
|
gcc \
|
||||||
|
libpq-dev \
|
||||||
|
&& rm -rf /var/lib/apt/lists/*
|
||||||
|
|
||||||
|
# Copy dependency files
|
||||||
|
COPY pyproject.toml ./
|
||||||
|
|
||||||
|
# Install Python dependencies
|
||||||
|
RUN pip install --no-cache-dir -e .
|
||||||
|
|
||||||
|
# Copy application code
|
||||||
|
COPY src/ ./src/
|
||||||
|
COPY migrations/ ./migrations/
|
||||||
|
COPY alembic.ini ./
|
||||||
|
|
||||||
|
# Create non-root user
|
||||||
|
RUN useradd -m -u 1000 guardden && chown -R guardden:guardden /app
|
||||||
|
USER guardden
|
||||||
|
|
||||||
|
# Run the bot
|
||||||
|
CMD ["python", "-m", "guardden"]
|
||||||
319
README.md
319
README.md
@@ -1,3 +1,320 @@
|
|||||||
# GuardDen
|
# GuardDen
|
||||||
|
|
||||||
GuardDen is a comprehensive Discord moderation bot designed to protect your community while maintaining a warm, welcoming environment. Built with privacy and self-hosting in mind, GuardDen combines AI-powered content filtering with traditional moderation tools to create a safe space for your members.
|
GuardDen is a comprehensive Discord moderation bot designed to protect your community while maintaining a warm, welcoming environment. Built with privacy and self-hosting in mind, GuardDen combines AI-powered content filtering with traditional moderation tools to create a safe space for your members.
|
||||||
|
|
||||||
|
## Features
|
||||||
|
|
||||||
|
### Core Moderation
|
||||||
|
- **Warn, Kick, Ban, Timeout** - Standard moderation commands with logging
|
||||||
|
- **Strike System** - Configurable point-based system with automatic escalation
|
||||||
|
- **Moderation History** - Track all actions taken against users
|
||||||
|
- **Bulk Message Deletion** - Purge up to 100 messages at once
|
||||||
|
|
||||||
|
### Automod
|
||||||
|
- **Banned Words Filter** - Block words/phrases with regex support
|
||||||
|
- **Scam Detection** - Automatic detection of phishing/scam links
|
||||||
|
- **Anti-Spam** - Rate limiting, duplicate detection, mass mention protection
|
||||||
|
- **Link Filtering** - Block Discord invites and suspicious URLs
|
||||||
|
|
||||||
|
### AI Moderation
|
||||||
|
- **Text Analysis** - AI-powered content moderation using Claude or GPT
|
||||||
|
- **NSFW Image Detection** - Automatic flagging of inappropriate images
|
||||||
|
- **Phishing Analysis** - AI-enhanced detection of scam URLs
|
||||||
|
- **Configurable Sensitivity** - Adjust strictness per server (0-100)
|
||||||
|
|
||||||
|
### Verification System
|
||||||
|
- **Multiple Challenge Types** - Button, captcha, math problems, emoji selection
|
||||||
|
- **Automatic New Member Verification** - Challenge users on join
|
||||||
|
- **Configurable Verified Role** - Auto-assign role on successful verification
|
||||||
|
- **Rate Limited** - Prevents verification spam
|
||||||
|
|
||||||
|
### Logging
|
||||||
|
- Member joins/leaves
|
||||||
|
- Message edits and deletions
|
||||||
|
- Voice channel activity
|
||||||
|
- Ban/unban events
|
||||||
|
- All moderation actions
|
||||||
|
|
||||||
|
## Quick Start
|
||||||
|
|
||||||
|
### Prerequisites
|
||||||
|
- Python 3.11+
|
||||||
|
- PostgreSQL 15+
|
||||||
|
- Discord Bot Token ([Discord Developer Portal](https://discord.com/developers/applications))
|
||||||
|
- (Optional) Anthropic or OpenAI API key for AI features
|
||||||
|
|
||||||
|
### Docker Deployment (Recommended)
|
||||||
|
|
||||||
|
1. Clone the repository:
|
||||||
|
```bash
|
||||||
|
git clone https://github.com/yourusername/guardden.git
|
||||||
|
cd guardden
|
||||||
|
```
|
||||||
|
|
||||||
|
2. Create your environment file:
|
||||||
|
```bash
|
||||||
|
cp .env.example .env
|
||||||
|
# Edit .env and add your Discord token
|
||||||
|
```
|
||||||
|
|
||||||
|
3. Start with Docker Compose:
|
||||||
|
```bash
|
||||||
|
docker compose up -d
|
||||||
|
```
|
||||||
|
|
||||||
|
### Local Development
|
||||||
|
|
||||||
|
1. Create a virtual environment:
|
||||||
|
```bash
|
||||||
|
python -m venv venv
|
||||||
|
source venv/bin/activate # On Windows: venv\Scripts\activate
|
||||||
|
```
|
||||||
|
|
||||||
|
2. Install dependencies:
|
||||||
|
```bash
|
||||||
|
pip install -e ".[dev,ai]"
|
||||||
|
```
|
||||||
|
|
||||||
|
3. Set up environment variables:
|
||||||
|
```bash
|
||||||
|
cp .env.example .env
|
||||||
|
# Edit .env with your configuration
|
||||||
|
```
|
||||||
|
|
||||||
|
4. Start PostgreSQL (or use Docker):
|
||||||
|
```bash
|
||||||
|
docker compose up db -d
|
||||||
|
```
|
||||||
|
|
||||||
|
5. Run the bot:
|
||||||
|
```bash
|
||||||
|
python -m guardden
|
||||||
|
```
|
||||||
|
|
||||||
|
## Configuration
|
||||||
|
|
||||||
|
### Environment Variables
|
||||||
|
|
||||||
|
| Variable | Description | Default |
|
||||||
|
|----------|-------------|---------|
|
||||||
|
| `GUARDDEN_DISCORD_TOKEN` | Your Discord bot token | Required |
|
||||||
|
| `GUARDDEN_DISCORD_PREFIX` | Default command prefix | `!` |
|
||||||
|
| `GUARDDEN_DATABASE_URL` | PostgreSQL connection URL | `postgresql://guardden:guardden@localhost:5432/guardden` |
|
||||||
|
| `GUARDDEN_LOG_LEVEL` | Logging level | `INFO` |
|
||||||
|
| `GUARDDEN_AI_PROVIDER` | AI provider (anthropic/openai/none) | `none` |
|
||||||
|
| `GUARDDEN_ANTHROPIC_API_KEY` | Anthropic API key (if using Claude) | - |
|
||||||
|
| `GUARDDEN_OPENAI_API_KEY` | OpenAI API key (if using GPT) | - |
|
||||||
|
|
||||||
|
### Per-Guild Settings
|
||||||
|
|
||||||
|
Each server can configure:
|
||||||
|
- Command prefix
|
||||||
|
- Log channels (general and moderation)
|
||||||
|
- Welcome channel
|
||||||
|
- Mute role and verified role
|
||||||
|
- Automod toggles (spam, links, banned words)
|
||||||
|
- Strike action thresholds
|
||||||
|
- AI moderation settings (enabled, sensitivity, NSFW detection)
|
||||||
|
- Verification settings (type, enabled)
|
||||||
|
|
||||||
|
## Commands
|
||||||
|
|
||||||
|
### Moderation
|
||||||
|
|
||||||
|
| Command | Permission | Description |
|
||||||
|
|---------|------------|-------------|
|
||||||
|
| `!warn <user> [reason]` | Kick Members | Warn a user |
|
||||||
|
| `!strike <user> [points] [reason]` | Kick Members | Add strikes to a user |
|
||||||
|
| `!strikes <user>` | Kick Members | View user's strikes |
|
||||||
|
| `!timeout <user> <duration> [reason]` | Moderate Members | Timeout a user (e.g., 1h, 30m, 7d) |
|
||||||
|
| `!untimeout <user>` | Moderate Members | Remove timeout |
|
||||||
|
| `!kick <user> [reason]` | Kick Members | Kick a user |
|
||||||
|
| `!ban <user> [reason]` | Ban Members | Ban a user |
|
||||||
|
| `!unban <user_id> [reason]` | Ban Members | Unban a user by ID |
|
||||||
|
| `!purge <amount>` | Manage Messages | Delete multiple messages (max 100) |
|
||||||
|
| `!modlogs <user>` | Kick Members | View moderation history |
|
||||||
|
|
||||||
|
### Configuration (Admin only)
|
||||||
|
|
||||||
|
| Command | Description |
|
||||||
|
|---------|-------------|
|
||||||
|
| `!config` | View current configuration |
|
||||||
|
| `!config prefix <prefix>` | Set command prefix |
|
||||||
|
| `!config logchannel [#channel]` | Set general log channel |
|
||||||
|
| `!config modlogchannel [#channel]` | Set moderation log channel |
|
||||||
|
| `!config welcomechannel [#channel]` | Set welcome channel |
|
||||||
|
| `!config muterole [@role]` | Set mute role |
|
||||||
|
| `!config automod <true/false>` | Toggle automod |
|
||||||
|
| `!config antispam <true/false>` | Toggle anti-spam |
|
||||||
|
| `!config linkfilter <true/false>` | Toggle link filtering |
|
||||||
|
|
||||||
|
### Banned Words
|
||||||
|
|
||||||
|
| Command | Description |
|
||||||
|
|---------|-------------|
|
||||||
|
| `!bannedwords` | List all banned words |
|
||||||
|
| `!bannedwords add <word> [action] [is_regex]` | Add a banned word |
|
||||||
|
| `!bannedwords remove <id>` | Remove a banned word by ID |
|
||||||
|
|
||||||
|
### Automod
|
||||||
|
|
||||||
|
| Command | Description |
|
||||||
|
|---------|-------------|
|
||||||
|
| `!automod` | View automod status |
|
||||||
|
| `!automod test <text>` | Test text against filters |
|
||||||
|
|
||||||
|
### AI Moderation (Admin only)
|
||||||
|
|
||||||
|
| Command | Description |
|
||||||
|
|---------|-------------|
|
||||||
|
| `!ai` | View AI moderation settings |
|
||||||
|
| `!ai enable` | Enable AI moderation |
|
||||||
|
| `!ai disable` | Disable AI moderation |
|
||||||
|
| `!ai sensitivity <0-100>` | Set AI sensitivity level |
|
||||||
|
| `!ai nsfw <true/false>` | Toggle NSFW image detection |
|
||||||
|
| `!ai analyze <text>` | Test AI analysis on text |
|
||||||
|
|
||||||
|
### Verification (Admin only)
|
||||||
|
|
||||||
|
| Command | Description |
|
||||||
|
|---------|-------------|
|
||||||
|
| `!verify` | Request verification (for users) |
|
||||||
|
| `!verify setup` | View verification setup status |
|
||||||
|
| `!verify enable` | Enable verification for new members |
|
||||||
|
| `!verify disable` | Disable verification |
|
||||||
|
| `!verify role @role` | Set the verified role |
|
||||||
|
| `!verify type <type>` | Set verification type (button/captcha/math/emoji) |
|
||||||
|
| `!verify test [type]` | Test a verification challenge |
|
||||||
|
| `!verify reset @user` | Reset verification for a user |
|
||||||
|
|
||||||
|
## Project Structure
|
||||||
|
|
||||||
|
```
|
||||||
|
guardden/
|
||||||
|
├── src/guardden/
|
||||||
|
│ ├── bot.py # Main bot class
|
||||||
|
│ ├── config.py # Settings management
|
||||||
|
│ ├── cogs/ # Discord command groups
|
||||||
|
│ │ ├── admin.py # Configuration commands
|
||||||
|
│ │ ├── ai_moderation.py # AI-powered moderation
|
||||||
|
│ │ ├── automod.py # Automatic moderation
|
||||||
|
│ │ ├── events.py # Event logging
|
||||||
|
│ │ ├── moderation.py # Moderation commands
|
||||||
|
│ │ └── verification.py # Member verification
|
||||||
|
│ ├── models/ # Database models
|
||||||
|
│ │ ├── guild.py # Guild settings, banned words
|
||||||
|
│ │ └── moderation.py # Logs, strikes, notes
|
||||||
|
│ └── services/ # Business logic
|
||||||
|
│ ├── ai/ # AI provider implementations
|
||||||
|
│ ├── automod.py # Content filtering
|
||||||
|
│ ├── database.py # DB connections
|
||||||
|
│ ├── guild_config.py # Config caching
|
||||||
|
│ ├── ratelimit.py # Rate limiting
|
||||||
|
│ └── verification.py # Verification challenges
|
||||||
|
├── tests/ # Test suite
|
||||||
|
├── migrations/ # Database migrations
|
||||||
|
├── docker-compose.yml # Docker deployment
|
||||||
|
└── pyproject.toml # Dependencies
|
||||||
|
```
|
||||||
|
|
||||||
|
## Verification System
|
||||||
|
|
||||||
|
GuardDen includes a verification system to protect your server from bots and raids.
|
||||||
|
|
||||||
|
### Challenge Types
|
||||||
|
|
||||||
|
| Type | Description |
|
||||||
|
|------|-------------|
|
||||||
|
| `button` | Simple button click (default, easiest) |
|
||||||
|
| `captcha` | Text-based captcha code entry |
|
||||||
|
| `math` | Solve a simple math problem |
|
||||||
|
| `emoji` | Select the correct emoji from options |
|
||||||
|
|
||||||
|
### Setup
|
||||||
|
|
||||||
|
1. Create a verified role in your server
|
||||||
|
2. Configure the role permissions (verified members get full access)
|
||||||
|
3. Set up verification:
|
||||||
|
```
|
||||||
|
!verify role @Verified
|
||||||
|
!verify type captcha
|
||||||
|
!verify enable
|
||||||
|
```
|
||||||
|
|
||||||
|
### How It Works
|
||||||
|
|
||||||
|
1. New member joins the server
|
||||||
|
2. Bot sends verification challenge via DM (or channel if DMs disabled)
|
||||||
|
3. Member completes the challenge
|
||||||
|
4. Bot assigns the verified role
|
||||||
|
5. Member gains access to the server
|
||||||
|
|
||||||
|
## AI Moderation
|
||||||
|
|
||||||
|
GuardDen supports AI-powered content moderation using either Anthropic's Claude or OpenAI's GPT models.
|
||||||
|
|
||||||
|
### Setup
|
||||||
|
|
||||||
|
1. Set the AI provider in your environment:
|
||||||
|
```bash
|
||||||
|
GUARDDEN_AI_PROVIDER=anthropic # or "openai"
|
||||||
|
GUARDDEN_ANTHROPIC_API_KEY=sk-ant-... # if using Claude
|
||||||
|
GUARDDEN_OPENAI_API_KEY=sk-... # if using OpenAI
|
||||||
|
```
|
||||||
|
|
||||||
|
2. Enable AI moderation per server:
|
||||||
|
```
|
||||||
|
!ai enable
|
||||||
|
!ai sensitivity 50 # 0=lenient, 100=strict
|
||||||
|
!ai nsfw true # Enable NSFW image detection
|
||||||
|
```
|
||||||
|
|
||||||
|
### Content Categories
|
||||||
|
|
||||||
|
The AI analyzes content for:
|
||||||
|
- **Harassment** - Personal attacks, bullying
|
||||||
|
- **Hate Speech** - Discrimination, slurs
|
||||||
|
- **Sexual Content** - Explicit material
|
||||||
|
- **Violence** - Threats, graphic content
|
||||||
|
- **Self-Harm** - Suicide/self-injury content
|
||||||
|
- **Scams** - Phishing, fraud attempts
|
||||||
|
- **Spam** - Promotional, low-quality content
|
||||||
|
|
||||||
|
### How It Works
|
||||||
|
|
||||||
|
1. Messages are analyzed by the AI provider
|
||||||
|
2. Results include confidence scores and severity ratings
|
||||||
|
3. Actions are taken based on guild sensitivity settings
|
||||||
|
4. All AI actions are logged to the mod log channel
|
||||||
|
|
||||||
|
## Development
|
||||||
|
|
||||||
|
### Running Tests
|
||||||
|
|
||||||
|
```bash
|
||||||
|
pytest
|
||||||
|
pytest -v # Verbose output
|
||||||
|
pytest tests/test_automod.py # Specific file
|
||||||
|
pytest -k "test_scam" # Filter by name
|
||||||
|
```
|
||||||
|
|
||||||
|
### Code Quality
|
||||||
|
|
||||||
|
```bash
|
||||||
|
ruff check src tests # Linting
|
||||||
|
ruff format src tests # Formatting
|
||||||
|
mypy src # Type checking
|
||||||
|
```
|
||||||
|
|
||||||
|
## License
|
||||||
|
|
||||||
|
MIT License - see LICENSE file for details.
|
||||||
|
|
||||||
|
## Roadmap
|
||||||
|
|
||||||
|
- [x] AI-powered content moderation (Claude/OpenAI integration)
|
||||||
|
- [x] NSFW image detection
|
||||||
|
- [x] Verification/captcha system
|
||||||
|
- [x] Rate limiting
|
||||||
|
- [ ] Voice channel moderation
|
||||||
|
- [ ] Web dashboard
|
||||||
|
|||||||
43
alembic.ini
Normal file
43
alembic.ini
Normal file
@@ -0,0 +1,43 @@
|
|||||||
|
[alembic]
|
||||||
|
script_location = migrations
|
||||||
|
prepend_sys_path = .
|
||||||
|
version_path_separator = os
|
||||||
|
|
||||||
|
# Use async driver
|
||||||
|
sqlalchemy.url = postgresql+asyncpg://guardden:guardden@localhost:5432/guardden
|
||||||
|
|
||||||
|
[post_write_hooks]
|
||||||
|
|
||||||
|
[loggers]
|
||||||
|
keys = root,sqlalchemy,alembic
|
||||||
|
|
||||||
|
[handlers]
|
||||||
|
keys = console
|
||||||
|
|
||||||
|
[formatters]
|
||||||
|
keys = generic
|
||||||
|
|
||||||
|
[logger_root]
|
||||||
|
level = WARN
|
||||||
|
handlers = console
|
||||||
|
qualname =
|
||||||
|
|
||||||
|
[logger_sqlalchemy]
|
||||||
|
level = WARN
|
||||||
|
handlers =
|
||||||
|
qualname = sqlalchemy.engine
|
||||||
|
|
||||||
|
[logger_alembic]
|
||||||
|
level = INFO
|
||||||
|
handlers =
|
||||||
|
qualname = alembic
|
||||||
|
|
||||||
|
[handler_console]
|
||||||
|
class = StreamHandler
|
||||||
|
args = (sys.stderr,)
|
||||||
|
level = NOTSET
|
||||||
|
formatter = generic
|
||||||
|
|
||||||
|
[formatter_generic]
|
||||||
|
format = %(levelname)-5.5s [%(name)s] %(message)s
|
||||||
|
datefmt = %H:%M:%S
|
||||||
44
docker-compose.yml
Normal file
44
docker-compose.yml
Normal file
@@ -0,0 +1,44 @@
|
|||||||
|
services:
|
||||||
|
bot:
|
||||||
|
build: .
|
||||||
|
container_name: guardden-bot
|
||||||
|
restart: unless-stopped
|
||||||
|
depends_on:
|
||||||
|
db:
|
||||||
|
condition: service_healthy
|
||||||
|
environment:
|
||||||
|
- GUARDDEN_DISCORD_TOKEN=${GUARDDEN_DISCORD_TOKEN}
|
||||||
|
- GUARDDEN_DATABASE_URL=postgresql://guardden:guardden@db:5432/guardden
|
||||||
|
- GUARDDEN_LOG_LEVEL=${GUARDDEN_LOG_LEVEL:-INFO}
|
||||||
|
- GUARDDEN_AI_PROVIDER=${GUARDDEN_AI_PROVIDER:-none}
|
||||||
|
- GUARDDEN_ANTHROPIC_API_KEY=${GUARDDEN_ANTHROPIC_API_KEY:-}
|
||||||
|
- GUARDDEN_OPENAI_API_KEY=${GUARDDEN_OPENAI_API_KEY:-}
|
||||||
|
volumes:
|
||||||
|
- ./data:/app/data
|
||||||
|
networks:
|
||||||
|
- guardden
|
||||||
|
|
||||||
|
db:
|
||||||
|
image: postgres:15-alpine
|
||||||
|
container_name: guardden-db
|
||||||
|
restart: unless-stopped
|
||||||
|
environment:
|
||||||
|
- POSTGRES_USER=guardden
|
||||||
|
- POSTGRES_PASSWORD=guardden
|
||||||
|
- POSTGRES_DB=guardden
|
||||||
|
volumes:
|
||||||
|
- postgres_data:/var/lib/postgresql/data
|
||||||
|
healthcheck:
|
||||||
|
test: ["CMD-SHELL", "pg_isready -U guardden -d guardden"]
|
||||||
|
interval: 5s
|
||||||
|
timeout: 5s
|
||||||
|
retries: 5
|
||||||
|
networks:
|
||||||
|
- guardden
|
||||||
|
|
||||||
|
networks:
|
||||||
|
guardden:
|
||||||
|
driver: bridge
|
||||||
|
|
||||||
|
volumes:
|
||||||
|
postgres_data:
|
||||||
65
migrations/env.py
Normal file
65
migrations/env.py
Normal file
@@ -0,0 +1,65 @@
|
|||||||
|
"""Alembic environment configuration."""
|
||||||
|
|
||||||
|
import asyncio
|
||||||
|
from logging.config import fileConfig
|
||||||
|
|
||||||
|
from alembic import context
|
||||||
|
from sqlalchemy import pool
|
||||||
|
from sqlalchemy.engine import Connection
|
||||||
|
from sqlalchemy.ext.asyncio import async_engine_from_config
|
||||||
|
|
||||||
|
from guardden.models import Base
|
||||||
|
|
||||||
|
config = context.config
|
||||||
|
|
||||||
|
if config.config_file_name is not None:
|
||||||
|
fileConfig(config.config_file_name)
|
||||||
|
|
||||||
|
target_metadata = Base.metadata
|
||||||
|
|
||||||
|
|
||||||
|
def run_migrations_offline() -> None:
|
||||||
|
"""Run migrations in 'offline' mode."""
|
||||||
|
url = config.get_main_option("sqlalchemy.url")
|
||||||
|
context.configure(
|
||||||
|
url=url,
|
||||||
|
target_metadata=target_metadata,
|
||||||
|
literal_binds=True,
|
||||||
|
dialect_opts={"paramstyle": "named"},
|
||||||
|
)
|
||||||
|
|
||||||
|
with context.begin_transaction():
|
||||||
|
context.run_migrations()
|
||||||
|
|
||||||
|
|
||||||
|
def do_run_migrations(connection: Connection) -> None:
|
||||||
|
"""Run migrations with the given connection."""
|
||||||
|
context.configure(connection=connection, target_metadata=target_metadata)
|
||||||
|
|
||||||
|
with context.begin_transaction():
|
||||||
|
context.run_migrations()
|
||||||
|
|
||||||
|
|
||||||
|
async def run_async_migrations() -> None:
|
||||||
|
"""Run migrations in 'online' mode with async engine."""
|
||||||
|
connectable = async_engine_from_config(
|
||||||
|
config.get_section(config.config_ini_section, {}),
|
||||||
|
prefix="sqlalchemy.",
|
||||||
|
poolclass=pool.NullPool,
|
||||||
|
)
|
||||||
|
|
||||||
|
async with connectable.connect() as connection:
|
||||||
|
await connection.run_sync(do_run_migrations)
|
||||||
|
|
||||||
|
await connectable.dispose()
|
||||||
|
|
||||||
|
|
||||||
|
def run_migrations_online() -> None:
|
||||||
|
"""Run migrations in 'online' mode."""
|
||||||
|
asyncio.run(run_async_migrations())
|
||||||
|
|
||||||
|
|
||||||
|
if context.is_offline_mode():
|
||||||
|
run_migrations_offline()
|
||||||
|
else:
|
||||||
|
run_migrations_online()
|
||||||
26
migrations/script.py.mako
Normal file
26
migrations/script.py.mako
Normal file
@@ -0,0 +1,26 @@
|
|||||||
|
"""${message}
|
||||||
|
|
||||||
|
Revision ID: ${up_revision}
|
||||||
|
Revises: ${down_revision | comma,n}
|
||||||
|
Create Date: ${create_date}
|
||||||
|
|
||||||
|
"""
|
||||||
|
from typing import Sequence, Union
|
||||||
|
|
||||||
|
from alembic import op
|
||||||
|
import sqlalchemy as sa
|
||||||
|
${imports if imports else ""}
|
||||||
|
|
||||||
|
# revision identifiers, used by Alembic.
|
||||||
|
revision: str = ${repr(up_revision)}
|
||||||
|
down_revision: Union[str, None] = ${repr(down_revision)}
|
||||||
|
branch_labels: Union[str, Sequence[str], None] = ${repr(branch_labels)}
|
||||||
|
depends_on: Union[str, Sequence[str], None] = ${repr(depends_on)}
|
||||||
|
|
||||||
|
|
||||||
|
def upgrade() -> None:
|
||||||
|
${upgrades if upgrades else "pass"}
|
||||||
|
|
||||||
|
|
||||||
|
def downgrade() -> None:
|
||||||
|
${downgrades if downgrades else "pass"}
|
||||||
98
pyproject.toml
Normal file
98
pyproject.toml
Normal file
@@ -0,0 +1,98 @@
|
|||||||
|
[build-system]
|
||||||
|
requires = ["setuptools>=61.0", "wheel"]
|
||||||
|
build-backend = "setuptools.build_meta"
|
||||||
|
|
||||||
|
[project]
|
||||||
|
name = "guardden"
|
||||||
|
version = "0.1.0"
|
||||||
|
description = "A comprehensive Discord moderation bot with AI-powered content filtering"
|
||||||
|
readme = "README.md"
|
||||||
|
license = {text = "MIT"}
|
||||||
|
requires-python = ">=3.11"
|
||||||
|
authors = [
|
||||||
|
{name = "GuardDen Team"}
|
||||||
|
]
|
||||||
|
classifiers = [
|
||||||
|
"Development Status :: 3 - Alpha",
|
||||||
|
"Intended Audience :: Developers",
|
||||||
|
"License :: OSI Approved :: MIT License",
|
||||||
|
"Programming Language :: Python :: 3.11",
|
||||||
|
"Programming Language :: Python :: 3.12",
|
||||||
|
]
|
||||||
|
|
||||||
|
dependencies = [
|
||||||
|
"discord.py>=2.3.0",
|
||||||
|
"asyncpg>=0.29.0",
|
||||||
|
"pydantic>=2.5.0",
|
||||||
|
"pydantic-settings>=2.1.0",
|
||||||
|
"aiohttp>=3.9.0",
|
||||||
|
"python-dotenv>=1.0.0",
|
||||||
|
"alembic>=1.13.0",
|
||||||
|
"sqlalchemy>=2.0.0",
|
||||||
|
]
|
||||||
|
|
||||||
|
[project.optional-dependencies]
|
||||||
|
dev = [
|
||||||
|
"pytest>=7.4.0",
|
||||||
|
"pytest-asyncio>=0.23.0",
|
||||||
|
"pytest-cov>=4.1.0",
|
||||||
|
"ruff>=0.1.0",
|
||||||
|
"mypy>=1.7.0",
|
||||||
|
"pre-commit>=3.6.0",
|
||||||
|
]
|
||||||
|
ai = [
|
||||||
|
"anthropic>=0.18.0",
|
||||||
|
"openai>=1.10.0",
|
||||||
|
"pillow>=10.2.0",
|
||||||
|
]
|
||||||
|
voice = [
|
||||||
|
"speechrecognition>=3.10.0",
|
||||||
|
"pydub>=0.25.0",
|
||||||
|
]
|
||||||
|
|
||||||
|
[project.scripts]
|
||||||
|
guardden = "guardden.__main__:main"
|
||||||
|
|
||||||
|
[tool.setuptools.packages.find]
|
||||||
|
where = ["src"]
|
||||||
|
|
||||||
|
[tool.ruff]
|
||||||
|
target-version = "py311"
|
||||||
|
line-length = 100
|
||||||
|
src = ["src", "tests"]
|
||||||
|
|
||||||
|
[tool.ruff.lint]
|
||||||
|
select = [
|
||||||
|
"E", # pycodestyle errors
|
||||||
|
"W", # pycodestyle warnings
|
||||||
|
"F", # Pyflakes
|
||||||
|
"I", # isort
|
||||||
|
"B", # flake8-bugbear
|
||||||
|
"C4", # flake8-comprehensions
|
||||||
|
"UP", # pyupgrade
|
||||||
|
"ARG", # flake8-unused-arguments
|
||||||
|
"SIM", # flake8-simplify
|
||||||
|
]
|
||||||
|
ignore = [
|
||||||
|
"E501", # line too long (handled by formatter)
|
||||||
|
"B008", # do not perform function calls in argument defaults
|
||||||
|
]
|
||||||
|
|
||||||
|
[tool.ruff.lint.isort]
|
||||||
|
known-first-party = ["guardden"]
|
||||||
|
|
||||||
|
[tool.pytest.ini_options]
|
||||||
|
asyncio_mode = "auto"
|
||||||
|
testpaths = ["tests"]
|
||||||
|
addopts = "-v --tb=short"
|
||||||
|
|
||||||
|
[tool.mypy]
|
||||||
|
python_version = "3.11"
|
||||||
|
strict = true
|
||||||
|
warn_return_any = true
|
||||||
|
warn_unused_ignores = true
|
||||||
|
plugins = ["pydantic.mypy"]
|
||||||
|
|
||||||
|
[[tool.mypy.overrides]]
|
||||||
|
module = ["discord.*", "asyncpg.*"]
|
||||||
|
ignore_missing_imports = true
|
||||||
3
src/guardden/__init__.py
Normal file
3
src/guardden/__init__.py
Normal file
@@ -0,0 +1,3 @@
|
|||||||
|
"""GuardDen - A comprehensive Discord moderation bot."""
|
||||||
|
|
||||||
|
__version__ = "0.1.0"
|
||||||
40
src/guardden/__main__.py
Normal file
40
src/guardden/__main__.py
Normal file
@@ -0,0 +1,40 @@
|
|||||||
|
"""Entry point for GuardDen bot."""
|
||||||
|
|
||||||
|
import asyncio
|
||||||
|
import logging
|
||||||
|
import sys
|
||||||
|
|
||||||
|
from guardden.bot import GuardDen
|
||||||
|
from guardden.config import get_settings
|
||||||
|
from guardden.utils import setup_logging
|
||||||
|
|
||||||
|
|
||||||
|
def main() -> None:
|
||||||
|
"""Run the GuardDen bot."""
|
||||||
|
try:
|
||||||
|
settings = get_settings()
|
||||||
|
except Exception as e:
|
||||||
|
print(f"Failed to load configuration: {e}", file=sys.stderr)
|
||||||
|
print("Make sure GUARDDEN_DISCORD_TOKEN is set.", file=sys.stderr)
|
||||||
|
sys.exit(1)
|
||||||
|
|
||||||
|
setup_logging(settings.log_level)
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
bot = GuardDen(settings)
|
||||||
|
|
||||||
|
async def runner() -> None:
|
||||||
|
async with bot:
|
||||||
|
await bot.start(settings.discord_token.get_secret_value())
|
||||||
|
|
||||||
|
try:
|
||||||
|
asyncio.run(runner())
|
||||||
|
except KeyboardInterrupt:
|
||||||
|
logger.info("Received keyboard interrupt, shutting down...")
|
||||||
|
except Exception as e:
|
||||||
|
logger.exception(f"Fatal error: {e}")
|
||||||
|
sys.exit(1)
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
main()
|
||||||
131
src/guardden/bot.py
Normal file
131
src/guardden/bot.py
Normal file
@@ -0,0 +1,131 @@
|
|||||||
|
"""Main bot class for GuardDen."""
|
||||||
|
|
||||||
|
import logging
|
||||||
|
from typing import TYPE_CHECKING
|
||||||
|
|
||||||
|
import discord
|
||||||
|
from discord.ext import commands
|
||||||
|
|
||||||
|
from guardden.config import Settings
|
||||||
|
from guardden.services.ai import AIProvider, create_ai_provider
|
||||||
|
from guardden.services.database import Database
|
||||||
|
|
||||||
|
if TYPE_CHECKING:
|
||||||
|
from guardden.services.guild_config import GuildConfigService
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
class GuardDen(commands.Bot):
|
||||||
|
"""The main GuardDen Discord bot."""
|
||||||
|
|
||||||
|
def __init__(self, settings: Settings) -> None:
|
||||||
|
self.settings = settings
|
||||||
|
|
||||||
|
intents = discord.Intents.default()
|
||||||
|
intents.message_content = True
|
||||||
|
intents.members = True
|
||||||
|
intents.voice_states = True
|
||||||
|
|
||||||
|
super().__init__(
|
||||||
|
command_prefix=self._get_prefix,
|
||||||
|
intents=intents,
|
||||||
|
help_command=commands.DefaultHelpCommand(),
|
||||||
|
)
|
||||||
|
|
||||||
|
# Services
|
||||||
|
self.database = Database(settings)
|
||||||
|
self.guild_config: "GuildConfigService | None" = None
|
||||||
|
self.ai_provider: AIProvider | None = None
|
||||||
|
|
||||||
|
async def _get_prefix(self, bot: "GuardDen", message: discord.Message) -> list[str]:
|
||||||
|
"""Get the command prefix for a guild."""
|
||||||
|
if not message.guild:
|
||||||
|
return [self.settings.discord_prefix]
|
||||||
|
|
||||||
|
if self.guild_config:
|
||||||
|
config = await self.guild_config.get_config(message.guild.id)
|
||||||
|
if config:
|
||||||
|
return [config.prefix]
|
||||||
|
|
||||||
|
return [self.settings.discord_prefix]
|
||||||
|
|
||||||
|
async def setup_hook(self) -> None:
|
||||||
|
"""Called when the bot is starting up."""
|
||||||
|
logger.info("Starting GuardDen setup...")
|
||||||
|
|
||||||
|
# Connect to database
|
||||||
|
await self.database.connect()
|
||||||
|
await self.database.create_tables()
|
||||||
|
|
||||||
|
# Initialize services
|
||||||
|
from guardden.services.guild_config import GuildConfigService
|
||||||
|
|
||||||
|
self.guild_config = GuildConfigService(self.database)
|
||||||
|
|
||||||
|
# Initialize AI provider
|
||||||
|
api_key = None
|
||||||
|
if self.settings.ai_provider == "anthropic" and self.settings.anthropic_api_key:
|
||||||
|
api_key = self.settings.anthropic_api_key.get_secret_value()
|
||||||
|
elif self.settings.ai_provider == "openai" and self.settings.openai_api_key:
|
||||||
|
api_key = self.settings.openai_api_key.get_secret_value()
|
||||||
|
|
||||||
|
self.ai_provider = create_ai_provider(self.settings.ai_provider, api_key)
|
||||||
|
|
||||||
|
# Load cogs
|
||||||
|
await self._load_cogs()
|
||||||
|
|
||||||
|
logger.info("GuardDen setup complete")
|
||||||
|
|
||||||
|
async def _load_cogs(self) -> None:
|
||||||
|
"""Load all cog extensions."""
|
||||||
|
cogs = [
|
||||||
|
"guardden.cogs.events",
|
||||||
|
"guardden.cogs.moderation",
|
||||||
|
"guardden.cogs.admin",
|
||||||
|
"guardden.cogs.automod",
|
||||||
|
"guardden.cogs.ai_moderation",
|
||||||
|
"guardden.cogs.verification",
|
||||||
|
]
|
||||||
|
|
||||||
|
for cog in cogs:
|
||||||
|
try:
|
||||||
|
await self.load_extension(cog)
|
||||||
|
logger.info(f"Loaded cog: {cog}")
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Failed to load cog {cog}: {e}")
|
||||||
|
|
||||||
|
async def on_ready(self) -> None:
|
||||||
|
"""Called when the bot is fully connected and ready."""
|
||||||
|
if self.user:
|
||||||
|
logger.info(f"Logged in as {self.user} (ID: {self.user.id})")
|
||||||
|
logger.info(f"Connected to {len(self.guilds)} guild(s)")
|
||||||
|
|
||||||
|
# Set presence
|
||||||
|
activity = discord.Activity(
|
||||||
|
type=discord.ActivityType.watching,
|
||||||
|
name="over your community",
|
||||||
|
)
|
||||||
|
await self.change_presence(activity=activity)
|
||||||
|
|
||||||
|
async def close(self) -> None:
|
||||||
|
"""Clean up when shutting down."""
|
||||||
|
logger.info("Shutting down GuardDen...")
|
||||||
|
if self.ai_provider:
|
||||||
|
try:
|
||||||
|
await self.ai_provider.close()
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Error closing AI provider: {e}")
|
||||||
|
await self.database.disconnect()
|
||||||
|
await super().close()
|
||||||
|
|
||||||
|
async def on_guild_join(self, guild: discord.Guild) -> None:
|
||||||
|
"""Called when the bot joins a new guild."""
|
||||||
|
logger.info(f"Joined guild: {guild.name} (ID: {guild.id})")
|
||||||
|
|
||||||
|
if self.guild_config:
|
||||||
|
await self.guild_config.create_guild(guild)
|
||||||
|
|
||||||
|
async def on_guild_remove(self, guild: discord.Guild) -> None:
|
||||||
|
"""Called when the bot is removed from a guild."""
|
||||||
|
logger.info(f"Removed from guild: {guild.name} (ID: {guild.id})")
|
||||||
1
src/guardden/cogs/__init__.py
Normal file
1
src/guardden/cogs/__init__.py
Normal file
@@ -0,0 +1 @@
|
|||||||
|
"""Discord cogs for GuardDen."""
|
||||||
255
src/guardden/cogs/admin.py
Normal file
255
src/guardden/cogs/admin.py
Normal file
@@ -0,0 +1,255 @@
|
|||||||
|
"""Admin commands for bot configuration."""
|
||||||
|
|
||||||
|
import logging
|
||||||
|
from typing import Literal
|
||||||
|
|
||||||
|
import discord
|
||||||
|
from discord.ext import commands
|
||||||
|
|
||||||
|
from guardden.bot import GuardDen
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
class Admin(commands.Cog):
|
||||||
|
"""Administrative commands for bot configuration."""
|
||||||
|
|
||||||
|
def __init__(self, bot: GuardDen) -> None:
|
||||||
|
self.bot = bot
|
||||||
|
|
||||||
|
async def cog_check(self, ctx: commands.Context) -> bool:
|
||||||
|
"""Ensure only administrators can use these commands."""
|
||||||
|
if not ctx.guild:
|
||||||
|
return False
|
||||||
|
return ctx.author.guild_permissions.administrator
|
||||||
|
|
||||||
|
@commands.group(name="config", invoke_without_command=True)
|
||||||
|
@commands.guild_only()
|
||||||
|
async def config(self, ctx: commands.Context) -> None:
|
||||||
|
"""View or modify bot configuration."""
|
||||||
|
config = await self.bot.guild_config.get_config(ctx.guild.id)
|
||||||
|
|
||||||
|
if not config:
|
||||||
|
await ctx.send("No configuration found. Run a config command to initialize.")
|
||||||
|
return
|
||||||
|
|
||||||
|
embed = discord.Embed(
|
||||||
|
title=f"Configuration for {ctx.guild.name}",
|
||||||
|
color=discord.Color.blue(),
|
||||||
|
)
|
||||||
|
|
||||||
|
# General settings
|
||||||
|
embed.add_field(name="Prefix", value=f"`{config.prefix}`", inline=True)
|
||||||
|
embed.add_field(name="Locale", value=config.locale, inline=True)
|
||||||
|
embed.add_field(name="\u200b", value="\u200b", inline=True)
|
||||||
|
|
||||||
|
# Channels
|
||||||
|
log_ch = ctx.guild.get_channel(config.log_channel_id) if config.log_channel_id else None
|
||||||
|
mod_log_ch = (
|
||||||
|
ctx.guild.get_channel(config.mod_log_channel_id) if config.mod_log_channel_id else None
|
||||||
|
)
|
||||||
|
welcome_ch = (
|
||||||
|
ctx.guild.get_channel(config.welcome_channel_id) if config.welcome_channel_id else None
|
||||||
|
)
|
||||||
|
|
||||||
|
embed.add_field(
|
||||||
|
name="Log Channel", value=log_ch.mention if log_ch else "Not set", inline=True
|
||||||
|
)
|
||||||
|
embed.add_field(
|
||||||
|
name="Mod Log Channel",
|
||||||
|
value=mod_log_ch.mention if mod_log_ch else "Not set",
|
||||||
|
inline=True,
|
||||||
|
)
|
||||||
|
embed.add_field(
|
||||||
|
name="Welcome Channel",
|
||||||
|
value=welcome_ch.mention if welcome_ch else "Not set",
|
||||||
|
inline=True,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Features
|
||||||
|
features = []
|
||||||
|
if config.automod_enabled:
|
||||||
|
features.append("AutoMod")
|
||||||
|
if config.anti_spam_enabled:
|
||||||
|
features.append("Anti-Spam")
|
||||||
|
if config.link_filter_enabled:
|
||||||
|
features.append("Link Filter")
|
||||||
|
if config.ai_moderation_enabled:
|
||||||
|
features.append("AI Moderation")
|
||||||
|
if config.verification_enabled:
|
||||||
|
features.append("Verification")
|
||||||
|
|
||||||
|
embed.add_field(
|
||||||
|
name="Enabled Features",
|
||||||
|
value=", ".join(features) if features else "None",
|
||||||
|
inline=False,
|
||||||
|
)
|
||||||
|
|
||||||
|
await ctx.send(embed=embed)
|
||||||
|
|
||||||
|
@config.command(name="prefix")
|
||||||
|
@commands.guild_only()
|
||||||
|
async def config_prefix(self, ctx: commands.Context, prefix: str) -> None:
|
||||||
|
"""Set the command prefix for this server."""
|
||||||
|
if not prefix or not prefix.strip():
|
||||||
|
await ctx.send("Prefix cannot be empty or whitespace only.")
|
||||||
|
return
|
||||||
|
|
||||||
|
if len(prefix) > 10:
|
||||||
|
await ctx.send("Prefix must be 10 characters or less.")
|
||||||
|
return
|
||||||
|
|
||||||
|
await self.bot.guild_config.update_settings(ctx.guild.id, prefix=prefix)
|
||||||
|
await ctx.send(f"Command prefix set to `{prefix}`")
|
||||||
|
|
||||||
|
@config.command(name="logchannel")
|
||||||
|
@commands.guild_only()
|
||||||
|
async def config_log_channel(
|
||||||
|
self, ctx: commands.Context, channel: discord.TextChannel | None = None
|
||||||
|
) -> None:
|
||||||
|
"""Set the channel for general event logs."""
|
||||||
|
channel_id = channel.id if channel else None
|
||||||
|
await self.bot.guild_config.update_settings(ctx.guild.id, log_channel_id=channel_id)
|
||||||
|
|
||||||
|
if channel:
|
||||||
|
await ctx.send(f"Log channel set to {channel.mention}")
|
||||||
|
else:
|
||||||
|
await ctx.send("Log channel has been disabled.")
|
||||||
|
|
||||||
|
@config.command(name="modlogchannel")
|
||||||
|
@commands.guild_only()
|
||||||
|
async def config_mod_log_channel(
|
||||||
|
self, ctx: commands.Context, channel: discord.TextChannel | None = None
|
||||||
|
) -> None:
|
||||||
|
"""Set the channel for moderation action logs."""
|
||||||
|
channel_id = channel.id if channel else None
|
||||||
|
await self.bot.guild_config.update_settings(ctx.guild.id, mod_log_channel_id=channel_id)
|
||||||
|
|
||||||
|
if channel:
|
||||||
|
await ctx.send(f"Moderation log channel set to {channel.mention}")
|
||||||
|
else:
|
||||||
|
await ctx.send("Moderation log channel has been disabled.")
|
||||||
|
|
||||||
|
@config.command(name="welcomechannel")
|
||||||
|
@commands.guild_only()
|
||||||
|
async def config_welcome_channel(
|
||||||
|
self, ctx: commands.Context, channel: discord.TextChannel | None = None
|
||||||
|
) -> None:
|
||||||
|
"""Set the welcome channel for new members."""
|
||||||
|
channel_id = channel.id if channel else None
|
||||||
|
await self.bot.guild_config.update_settings(ctx.guild.id, welcome_channel_id=channel_id)
|
||||||
|
|
||||||
|
if channel:
|
||||||
|
await ctx.send(f"Welcome channel set to {channel.mention}")
|
||||||
|
else:
|
||||||
|
await ctx.send("Welcome channel has been disabled.")
|
||||||
|
|
||||||
|
@config.command(name="muterole")
|
||||||
|
@commands.guild_only()
|
||||||
|
async def config_mute_role(
|
||||||
|
self, ctx: commands.Context, role: discord.Role | None = None
|
||||||
|
) -> None:
|
||||||
|
"""Set the role to assign when muting members."""
|
||||||
|
role_id = role.id if role else None
|
||||||
|
await self.bot.guild_config.update_settings(ctx.guild.id, mute_role_id=role_id)
|
||||||
|
|
||||||
|
if role:
|
||||||
|
await ctx.send(f"Mute role set to {role.mention}")
|
||||||
|
else:
|
||||||
|
await ctx.send("Mute role has been cleared.")
|
||||||
|
|
||||||
|
@config.command(name="automod")
|
||||||
|
@commands.guild_only()
|
||||||
|
async def config_automod(self, ctx: commands.Context, enabled: bool) -> None:
|
||||||
|
"""Enable or disable automod features."""
|
||||||
|
await self.bot.guild_config.update_settings(ctx.guild.id, automod_enabled=enabled)
|
||||||
|
status = "enabled" if enabled else "disabled"
|
||||||
|
await ctx.send(f"AutoMod has been {status}.")
|
||||||
|
|
||||||
|
@config.command(name="antispam")
|
||||||
|
@commands.guild_only()
|
||||||
|
async def config_antispam(self, ctx: commands.Context, enabled: bool) -> None:
|
||||||
|
"""Enable or disable anti-spam protection."""
|
||||||
|
await self.bot.guild_config.update_settings(ctx.guild.id, anti_spam_enabled=enabled)
|
||||||
|
status = "enabled" if enabled else "disabled"
|
||||||
|
await ctx.send(f"Anti-spam has been {status}.")
|
||||||
|
|
||||||
|
@config.command(name="linkfilter")
|
||||||
|
@commands.guild_only()
|
||||||
|
async def config_linkfilter(self, ctx: commands.Context, enabled: bool) -> None:
|
||||||
|
"""Enable or disable link filtering."""
|
||||||
|
await self.bot.guild_config.update_settings(ctx.guild.id, link_filter_enabled=enabled)
|
||||||
|
status = "enabled" if enabled else "disabled"
|
||||||
|
await ctx.send(f"Link filter has been {status}.")
|
||||||
|
|
||||||
|
@commands.group(name="bannedwords", aliases=["bw"], invoke_without_command=True)
|
||||||
|
@commands.guild_only()
|
||||||
|
async def banned_words(self, ctx: commands.Context) -> None:
|
||||||
|
"""Manage banned words list."""
|
||||||
|
words = await self.bot.guild_config.get_banned_words(ctx.guild.id)
|
||||||
|
|
||||||
|
if not words:
|
||||||
|
await ctx.send("No banned words configured.")
|
||||||
|
return
|
||||||
|
|
||||||
|
embed = discord.Embed(
|
||||||
|
title="Banned Words",
|
||||||
|
color=discord.Color.red(),
|
||||||
|
)
|
||||||
|
|
||||||
|
for word in words[:25]: # Discord embed limit
|
||||||
|
word_type = "Regex" if word.is_regex else "Text"
|
||||||
|
embed.add_field(
|
||||||
|
name=f"#{word.id}: {word.pattern[:30]}",
|
||||||
|
value=f"Type: {word_type} | Action: {word.action}",
|
||||||
|
inline=True,
|
||||||
|
)
|
||||||
|
|
||||||
|
if len(words) > 25:
|
||||||
|
embed.set_footer(text=f"Showing 25 of {len(words)} banned words")
|
||||||
|
|
||||||
|
await ctx.send(embed=embed)
|
||||||
|
|
||||||
|
@banned_words.command(name="add")
|
||||||
|
@commands.guild_only()
|
||||||
|
async def banned_words_add(
|
||||||
|
self,
|
||||||
|
ctx: commands.Context,
|
||||||
|
pattern: str,
|
||||||
|
action: Literal["delete", "warn", "strike"] = "delete",
|
||||||
|
is_regex: bool = False,
|
||||||
|
) -> None:
|
||||||
|
"""Add a banned word or pattern."""
|
||||||
|
word = await self.bot.guild_config.add_banned_word(
|
||||||
|
guild_id=ctx.guild.id,
|
||||||
|
pattern=pattern,
|
||||||
|
added_by=ctx.author.id,
|
||||||
|
is_regex=is_regex,
|
||||||
|
action=action,
|
||||||
|
)
|
||||||
|
|
||||||
|
word_type = "regex pattern" if is_regex else "word"
|
||||||
|
await ctx.send(f"Added banned {word_type}: `{pattern}` (ID: {word.id}, Action: {action})")
|
||||||
|
|
||||||
|
@banned_words.command(name="remove", aliases=["delete"])
|
||||||
|
@commands.guild_only()
|
||||||
|
async def banned_words_remove(self, ctx: commands.Context, word_id: int) -> None:
|
||||||
|
"""Remove a banned word by ID."""
|
||||||
|
success = await self.bot.guild_config.remove_banned_word(ctx.guild.id, word_id)
|
||||||
|
|
||||||
|
if success:
|
||||||
|
await ctx.send(f"Removed banned word #{word_id}")
|
||||||
|
else:
|
||||||
|
await ctx.send(f"Banned word #{word_id} not found.")
|
||||||
|
|
||||||
|
@commands.command(name="sync")
|
||||||
|
@commands.is_owner()
|
||||||
|
async def sync_commands(self, ctx: commands.Context) -> None:
|
||||||
|
"""Sync slash commands (bot owner only)."""
|
||||||
|
await self.bot.tree.sync()
|
||||||
|
await ctx.send("Slash commands synced.")
|
||||||
|
|
||||||
|
|
||||||
|
async def setup(bot: GuardDen) -> None:
|
||||||
|
"""Load the Admin cog."""
|
||||||
|
await bot.add_cog(Admin(bot))
|
||||||
366
src/guardden/cogs/ai_moderation.py
Normal file
366
src/guardden/cogs/ai_moderation.py
Normal file
@@ -0,0 +1,366 @@
|
|||||||
|
"""AI-powered moderation cog."""
|
||||||
|
|
||||||
|
import logging
|
||||||
|
import re
|
||||||
|
from collections import deque
|
||||||
|
from datetime import datetime, timedelta, timezone
|
||||||
|
|
||||||
|
import discord
|
||||||
|
from discord.ext import commands
|
||||||
|
|
||||||
|
from guardden.bot import GuardDen
|
||||||
|
from guardden.services.ai.base import ContentCategory, ModerationResult
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
# URL pattern for extraction
|
||||||
|
URL_PATTERN = re.compile(
|
||||||
|
r"https?://(?:[-\w.]|(?:%[\da-fA-F]{2}))+[^\s]*",
|
||||||
|
re.IGNORECASE,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class AIModeration(commands.Cog):
|
||||||
|
"""AI-powered content moderation."""
|
||||||
|
|
||||||
|
def __init__(self, bot: GuardDen) -> None:
|
||||||
|
self.bot = bot
|
||||||
|
# Track recently analyzed messages to avoid duplicates (deque auto-removes oldest)
|
||||||
|
self._analyzed_messages: deque[int] = deque(maxlen=1000)
|
||||||
|
|
||||||
|
def _should_analyze(self, message: discord.Message) -> bool:
|
||||||
|
"""Determine if a message should be analyzed by AI."""
|
||||||
|
# Skip if already analyzed
|
||||||
|
if message.id in self._analyzed_messages:
|
||||||
|
return False
|
||||||
|
|
||||||
|
# Skip short messages
|
||||||
|
if len(message.content) < 20 and not message.attachments:
|
||||||
|
return False
|
||||||
|
|
||||||
|
# Skip messages from bots
|
||||||
|
if message.author.bot:
|
||||||
|
return False
|
||||||
|
|
||||||
|
return True
|
||||||
|
|
||||||
|
def _track_message(self, message_id: int) -> None:
|
||||||
|
"""Track that a message has been analyzed."""
|
||||||
|
self._analyzed_messages.append(message_id)
|
||||||
|
|
||||||
|
async def _handle_ai_result(
|
||||||
|
self,
|
||||||
|
message: discord.Message,
|
||||||
|
result: ModerationResult,
|
||||||
|
analysis_type: str,
|
||||||
|
) -> None:
|
||||||
|
"""Handle the result of AI analysis."""
|
||||||
|
if not result.is_flagged:
|
||||||
|
return
|
||||||
|
|
||||||
|
config = await self.bot.guild_config.get_config(message.guild.id)
|
||||||
|
if not config:
|
||||||
|
return
|
||||||
|
|
||||||
|
# Check if severity meets threshold based on sensitivity
|
||||||
|
# Higher sensitivity = lower threshold needed to trigger
|
||||||
|
threshold = 100 - config.ai_sensitivity # e.g., sensitivity 70 = threshold 30
|
||||||
|
if result.severity < threshold:
|
||||||
|
logger.debug(
|
||||||
|
f"AI flagged content but below threshold: "
|
||||||
|
f"severity={result.severity}, threshold={threshold}"
|
||||||
|
)
|
||||||
|
return
|
||||||
|
|
||||||
|
# Determine action based on suggested action and severity
|
||||||
|
should_delete = result.suggested_action in ("delete", "timeout", "ban")
|
||||||
|
should_timeout = result.suggested_action in ("timeout", "ban") and result.severity > 70
|
||||||
|
|
||||||
|
# Delete message if needed
|
||||||
|
if should_delete:
|
||||||
|
try:
|
||||||
|
await message.delete()
|
||||||
|
except discord.Forbidden:
|
||||||
|
logger.warning(f"Cannot delete message: missing permissions")
|
||||||
|
except discord.NotFound:
|
||||||
|
pass
|
||||||
|
|
||||||
|
# Timeout user for severe violations
|
||||||
|
if should_timeout and isinstance(message.author, discord.Member):
|
||||||
|
timeout_duration = 300 if result.severity < 90 else 3600 # 5 min or 1 hour
|
||||||
|
try:
|
||||||
|
await message.author.timeout(
|
||||||
|
timedelta(seconds=timeout_duration),
|
||||||
|
reason=f"AI Moderation: {result.explanation[:100]}",
|
||||||
|
)
|
||||||
|
except discord.Forbidden:
|
||||||
|
pass
|
||||||
|
|
||||||
|
# Log to mod channel
|
||||||
|
await self._log_ai_action(message, result, analysis_type)
|
||||||
|
|
||||||
|
# Notify user
|
||||||
|
try:
|
||||||
|
embed = discord.Embed(
|
||||||
|
title=f"Message Flagged in {message.guild.name}",
|
||||||
|
description=result.explanation,
|
||||||
|
color=discord.Color.red(),
|
||||||
|
timestamp=datetime.now(timezone.utc),
|
||||||
|
)
|
||||||
|
embed.add_field(
|
||||||
|
name="Categories",
|
||||||
|
value=", ".join(cat.value for cat in result.categories) or "Unknown",
|
||||||
|
)
|
||||||
|
if should_timeout:
|
||||||
|
embed.add_field(name="Action", value="You have been timed out")
|
||||||
|
await message.author.send(embed=embed)
|
||||||
|
except discord.Forbidden:
|
||||||
|
pass
|
||||||
|
|
||||||
|
async def _log_ai_action(
|
||||||
|
self,
|
||||||
|
message: discord.Message,
|
||||||
|
result: ModerationResult,
|
||||||
|
analysis_type: str,
|
||||||
|
) -> None:
|
||||||
|
"""Log an AI moderation action."""
|
||||||
|
config = await self.bot.guild_config.get_config(message.guild.id)
|
||||||
|
if not config or not config.mod_log_channel_id:
|
||||||
|
return
|
||||||
|
|
||||||
|
channel = message.guild.get_channel(config.mod_log_channel_id)
|
||||||
|
if not channel or not isinstance(channel, discord.TextChannel):
|
||||||
|
return
|
||||||
|
|
||||||
|
embed = discord.Embed(
|
||||||
|
title=f"AI Moderation - {analysis_type}",
|
||||||
|
color=discord.Color.red(),
|
||||||
|
timestamp=datetime.now(timezone.utc),
|
||||||
|
)
|
||||||
|
embed.set_author(
|
||||||
|
name=str(message.author),
|
||||||
|
icon_url=message.author.display_avatar.url,
|
||||||
|
)
|
||||||
|
|
||||||
|
embed.add_field(name="Confidence", value=f"{result.confidence:.0%}", inline=True)
|
||||||
|
embed.add_field(name="Severity", value=f"{result.severity}/100", inline=True)
|
||||||
|
embed.add_field(name="Action", value=result.suggested_action, inline=True)
|
||||||
|
|
||||||
|
categories = ", ".join(cat.value for cat in result.categories)
|
||||||
|
embed.add_field(name="Categories", value=categories or "None", inline=False)
|
||||||
|
embed.add_field(name="Explanation", value=result.explanation[:500], inline=False)
|
||||||
|
|
||||||
|
if message.content:
|
||||||
|
content = (
|
||||||
|
message.content[:500] + "..." if len(message.content) > 500 else message.content
|
||||||
|
)
|
||||||
|
embed.add_field(name="Content", value=f"```{content}```", inline=False)
|
||||||
|
|
||||||
|
embed.set_footer(text=f"User ID: {message.author.id} | Channel: #{message.channel.name}")
|
||||||
|
|
||||||
|
await channel.send(embed=embed)
|
||||||
|
|
||||||
|
@commands.Cog.listener()
|
||||||
|
async def on_message(self, message: discord.Message) -> None:
|
||||||
|
"""Analyze messages with AI moderation."""
|
||||||
|
if not message.guild:
|
||||||
|
return
|
||||||
|
|
||||||
|
# Check if AI moderation is enabled for this guild
|
||||||
|
config = await self.bot.guild_config.get_config(message.guild.id)
|
||||||
|
if not config or not config.ai_moderation_enabled:
|
||||||
|
return
|
||||||
|
|
||||||
|
# Skip users with manage_messages permission
|
||||||
|
if isinstance(message.author, discord.Member):
|
||||||
|
if message.author.guild_permissions.manage_messages:
|
||||||
|
return
|
||||||
|
|
||||||
|
if not self._should_analyze(message):
|
||||||
|
return
|
||||||
|
|
||||||
|
self._track_message(message.id)
|
||||||
|
|
||||||
|
# Analyze text content
|
||||||
|
if message.content and len(message.content) >= 20:
|
||||||
|
result = await self.bot.ai_provider.moderate_text(
|
||||||
|
content=message.content,
|
||||||
|
context=f"Discord server: {message.guild.name}, channel: {message.channel.name}",
|
||||||
|
sensitivity=config.ai_sensitivity,
|
||||||
|
)
|
||||||
|
|
||||||
|
if result.is_flagged:
|
||||||
|
await self._handle_ai_result(message, result, "Text Analysis")
|
||||||
|
return # Don't continue if already flagged
|
||||||
|
|
||||||
|
# Analyze images if NSFW detection is enabled (limit to 3 per message)
|
||||||
|
if config.nsfw_detection_enabled and message.attachments:
|
||||||
|
images_analyzed = 0
|
||||||
|
for attachment in message.attachments:
|
||||||
|
if images_analyzed >= 3:
|
||||||
|
break
|
||||||
|
if attachment.content_type and attachment.content_type.startswith("image/"):
|
||||||
|
images_analyzed += 1
|
||||||
|
image_result = await self.bot.ai_provider.analyze_image(
|
||||||
|
image_url=attachment.url,
|
||||||
|
sensitivity=config.ai_sensitivity,
|
||||||
|
)
|
||||||
|
|
||||||
|
if (
|
||||||
|
image_result.is_nsfw
|
||||||
|
or image_result.is_violent
|
||||||
|
or image_result.is_disturbing
|
||||||
|
):
|
||||||
|
# Convert to ModerationResult format
|
||||||
|
categories = []
|
||||||
|
if image_result.is_nsfw:
|
||||||
|
categories.append(ContentCategory.SEXUAL)
|
||||||
|
if image_result.is_violent:
|
||||||
|
categories.append(ContentCategory.VIOLENCE)
|
||||||
|
|
||||||
|
result = ModerationResult(
|
||||||
|
is_flagged=True,
|
||||||
|
confidence=image_result.confidence,
|
||||||
|
categories=categories,
|
||||||
|
explanation=image_result.description,
|
||||||
|
suggested_action="delete",
|
||||||
|
)
|
||||||
|
await self._handle_ai_result(message, result, "Image Analysis")
|
||||||
|
return
|
||||||
|
|
||||||
|
# Analyze URLs for phishing
|
||||||
|
urls = URL_PATTERN.findall(message.content)
|
||||||
|
for url in urls[:3]: # Limit to first 3 URLs
|
||||||
|
phishing_result = await self.bot.ai_provider.analyze_phishing(
|
||||||
|
url=url,
|
||||||
|
message_content=message.content,
|
||||||
|
)
|
||||||
|
|
||||||
|
if phishing_result.is_phishing and phishing_result.confidence > 0.7:
|
||||||
|
result = ModerationResult(
|
||||||
|
is_flagged=True,
|
||||||
|
confidence=phishing_result.confidence,
|
||||||
|
categories=[ContentCategory.SCAM],
|
||||||
|
explanation=phishing_result.explanation,
|
||||||
|
suggested_action="delete",
|
||||||
|
)
|
||||||
|
await self._handle_ai_result(message, result, "Phishing Detection")
|
||||||
|
return
|
||||||
|
|
||||||
|
@commands.group(name="ai", invoke_without_command=True)
|
||||||
|
@commands.has_permissions(administrator=True)
|
||||||
|
@commands.guild_only()
|
||||||
|
async def ai_cmd(self, ctx: commands.Context) -> None:
|
||||||
|
"""View AI moderation settings."""
|
||||||
|
config = await self.bot.guild_config.get_config(ctx.guild.id)
|
||||||
|
|
||||||
|
embed = discord.Embed(
|
||||||
|
title="AI Moderation Settings",
|
||||||
|
color=discord.Color.blue(),
|
||||||
|
)
|
||||||
|
|
||||||
|
embed.add_field(
|
||||||
|
name="AI Moderation",
|
||||||
|
value="✅ Enabled" if config and config.ai_moderation_enabled else "❌ Disabled",
|
||||||
|
inline=True,
|
||||||
|
)
|
||||||
|
embed.add_field(
|
||||||
|
name="NSFW Detection",
|
||||||
|
value="✅ Enabled" if config and config.nsfw_detection_enabled else "❌ Disabled",
|
||||||
|
inline=True,
|
||||||
|
)
|
||||||
|
embed.add_field(
|
||||||
|
name="Sensitivity",
|
||||||
|
value=f"{config.ai_sensitivity}/100" if config else "50/100",
|
||||||
|
inline=True,
|
||||||
|
)
|
||||||
|
embed.add_field(
|
||||||
|
name="AI Provider",
|
||||||
|
value=self.bot.settings.ai_provider.capitalize(),
|
||||||
|
inline=True,
|
||||||
|
)
|
||||||
|
|
||||||
|
await ctx.send(embed=embed)
|
||||||
|
|
||||||
|
@ai_cmd.command(name="enable")
|
||||||
|
@commands.has_permissions(administrator=True)
|
||||||
|
@commands.guild_only()
|
||||||
|
async def ai_enable(self, ctx: commands.Context) -> None:
|
||||||
|
"""Enable AI moderation."""
|
||||||
|
if self.bot.settings.ai_provider == "none":
|
||||||
|
await ctx.send(
|
||||||
|
"AI moderation is not configured. Set `GUARDDEN_AI_PROVIDER` and API key."
|
||||||
|
)
|
||||||
|
return
|
||||||
|
|
||||||
|
await self.bot.guild_config.update_settings(ctx.guild.id, ai_moderation_enabled=True)
|
||||||
|
await ctx.send("✅ AI moderation enabled.")
|
||||||
|
|
||||||
|
@ai_cmd.command(name="disable")
|
||||||
|
@commands.has_permissions(administrator=True)
|
||||||
|
@commands.guild_only()
|
||||||
|
async def ai_disable(self, ctx: commands.Context) -> None:
|
||||||
|
"""Disable AI moderation."""
|
||||||
|
await self.bot.guild_config.update_settings(ctx.guild.id, ai_moderation_enabled=False)
|
||||||
|
await ctx.send("❌ AI moderation disabled.")
|
||||||
|
|
||||||
|
@ai_cmd.command(name="sensitivity")
|
||||||
|
@commands.has_permissions(administrator=True)
|
||||||
|
@commands.guild_only()
|
||||||
|
async def ai_sensitivity(self, ctx: commands.Context, level: int) -> None:
|
||||||
|
"""Set AI sensitivity level (0-100). Higher = more strict."""
|
||||||
|
if not 0 <= level <= 100:
|
||||||
|
await ctx.send("Sensitivity must be between 0 and 100.")
|
||||||
|
return
|
||||||
|
|
||||||
|
await self.bot.guild_config.update_settings(ctx.guild.id, ai_sensitivity=level)
|
||||||
|
await ctx.send(f"AI sensitivity set to {level}/100.")
|
||||||
|
|
||||||
|
@ai_cmd.command(name="nsfw")
|
||||||
|
@commands.has_permissions(administrator=True)
|
||||||
|
@commands.guild_only()
|
||||||
|
async def ai_nsfw(self, ctx: commands.Context, enabled: bool) -> None:
|
||||||
|
"""Enable or disable NSFW image detection."""
|
||||||
|
await self.bot.guild_config.update_settings(ctx.guild.id, nsfw_detection_enabled=enabled)
|
||||||
|
status = "enabled" if enabled else "disabled"
|
||||||
|
await ctx.send(f"NSFW detection {status}.")
|
||||||
|
|
||||||
|
@ai_cmd.command(name="analyze")
|
||||||
|
@commands.has_permissions(administrator=True)
|
||||||
|
@commands.guild_only()
|
||||||
|
async def ai_analyze(self, ctx: commands.Context, *, text: str) -> None:
|
||||||
|
"""Test AI analysis on text (does not take action)."""
|
||||||
|
if self.bot.settings.ai_provider == "none":
|
||||||
|
await ctx.send("AI moderation is not configured.")
|
||||||
|
return
|
||||||
|
|
||||||
|
async with ctx.typing():
|
||||||
|
result = await self.bot.ai_provider.moderate_text(
|
||||||
|
content=text,
|
||||||
|
context=f"Test analysis in {ctx.guild.name}",
|
||||||
|
sensitivity=50,
|
||||||
|
)
|
||||||
|
|
||||||
|
embed = discord.Embed(
|
||||||
|
title="AI Analysis Result",
|
||||||
|
color=discord.Color.red() if result.is_flagged else discord.Color.green(),
|
||||||
|
)
|
||||||
|
|
||||||
|
embed.add_field(name="Flagged", value="Yes" if result.is_flagged else "No", inline=True)
|
||||||
|
embed.add_field(name="Confidence", value=f"{result.confidence:.0%}", inline=True)
|
||||||
|
embed.add_field(name="Severity", value=f"{result.severity}/100", inline=True)
|
||||||
|
embed.add_field(name="Suggested Action", value=result.suggested_action, inline=True)
|
||||||
|
|
||||||
|
if result.categories:
|
||||||
|
categories = ", ".join(cat.value for cat in result.categories)
|
||||||
|
embed.add_field(name="Categories", value=categories, inline=False)
|
||||||
|
|
||||||
|
if result.explanation:
|
||||||
|
embed.add_field(name="Explanation", value=result.explanation[:1000], inline=False)
|
||||||
|
|
||||||
|
await ctx.send(embed=embed)
|
||||||
|
|
||||||
|
|
||||||
|
async def setup(bot: GuardDen) -> None:
|
||||||
|
"""Load the AI Moderation cog."""
|
||||||
|
await bot.add_cog(AIModeration(bot))
|
||||||
267
src/guardden/cogs/automod.py
Normal file
267
src/guardden/cogs/automod.py
Normal file
@@ -0,0 +1,267 @@
|
|||||||
|
"""Automod cog for automatic content moderation."""
|
||||||
|
|
||||||
|
import logging
|
||||||
|
from datetime import datetime, timedelta, timezone
|
||||||
|
|
||||||
|
import discord
|
||||||
|
from discord.ext import commands
|
||||||
|
|
||||||
|
from guardden.bot import GuardDen
|
||||||
|
from guardden.services.automod import AutomodResult, AutomodService
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
class Automod(commands.Cog):
|
||||||
|
"""Automatic content moderation."""
|
||||||
|
|
||||||
|
def __init__(self, bot: GuardDen) -> None:
|
||||||
|
self.bot = bot
|
||||||
|
self.automod = AutomodService()
|
||||||
|
|
||||||
|
async def _handle_violation(
|
||||||
|
self,
|
||||||
|
message: discord.Message,
|
||||||
|
result: AutomodResult,
|
||||||
|
) -> None:
|
||||||
|
"""Handle an automod violation."""
|
||||||
|
# Delete the message
|
||||||
|
if result.should_delete:
|
||||||
|
try:
|
||||||
|
await message.delete()
|
||||||
|
except discord.Forbidden:
|
||||||
|
logger.warning(f"Cannot delete message in {message.guild}: missing permissions")
|
||||||
|
except discord.NotFound:
|
||||||
|
pass # Already deleted
|
||||||
|
|
||||||
|
# Apply timeout
|
||||||
|
if result.should_timeout and result.timeout_duration > 0:
|
||||||
|
try:
|
||||||
|
await message.author.timeout(
|
||||||
|
timedelta(seconds=result.timeout_duration),
|
||||||
|
reason=f"Automod: {result.reason}",
|
||||||
|
)
|
||||||
|
except discord.Forbidden:
|
||||||
|
logger.warning(f"Cannot timeout {message.author}: missing permissions")
|
||||||
|
|
||||||
|
# Log the action
|
||||||
|
await self._log_automod_action(message, result)
|
||||||
|
|
||||||
|
# Notify the user via DM
|
||||||
|
try:
|
||||||
|
embed = discord.Embed(
|
||||||
|
title=f"Message Removed in {message.guild.name}",
|
||||||
|
description=result.reason,
|
||||||
|
color=discord.Color.orange(),
|
||||||
|
timestamp=datetime.now(timezone.utc),
|
||||||
|
)
|
||||||
|
if result.should_timeout:
|
||||||
|
embed.add_field(
|
||||||
|
name="Timeout",
|
||||||
|
value=f"You have been timed out for {result.timeout_duration} seconds.",
|
||||||
|
)
|
||||||
|
await message.author.send(embed=embed)
|
||||||
|
except discord.Forbidden:
|
||||||
|
pass # User has DMs disabled
|
||||||
|
|
||||||
|
async def _log_automod_action(
|
||||||
|
self,
|
||||||
|
message: discord.Message,
|
||||||
|
result: AutomodResult,
|
||||||
|
) -> None:
|
||||||
|
"""Log an automod action to the mod log channel."""
|
||||||
|
config = await self.bot.guild_config.get_config(message.guild.id)
|
||||||
|
if not config or not config.mod_log_channel_id:
|
||||||
|
return
|
||||||
|
|
||||||
|
channel = message.guild.get_channel(config.mod_log_channel_id)
|
||||||
|
if not channel or not isinstance(channel, discord.TextChannel):
|
||||||
|
return
|
||||||
|
|
||||||
|
embed = discord.Embed(
|
||||||
|
title="Automod Action",
|
||||||
|
color=discord.Color.orange(),
|
||||||
|
timestamp=datetime.now(timezone.utc),
|
||||||
|
)
|
||||||
|
embed.set_author(
|
||||||
|
name=str(message.author),
|
||||||
|
icon_url=message.author.display_avatar.url,
|
||||||
|
)
|
||||||
|
embed.add_field(name="Filter", value=result.matched_filter, inline=True)
|
||||||
|
embed.add_field(name="Channel", value=message.channel.mention, inline=True)
|
||||||
|
embed.add_field(name="Reason", value=result.reason, inline=False)
|
||||||
|
|
||||||
|
if message.content:
|
||||||
|
content = (
|
||||||
|
message.content[:500] + "..." if len(message.content) > 500 else message.content
|
||||||
|
)
|
||||||
|
embed.add_field(name="Message Content", value=f"```{content}```", inline=False)
|
||||||
|
|
||||||
|
actions = []
|
||||||
|
if result.should_delete:
|
||||||
|
actions.append("Message deleted")
|
||||||
|
if result.should_warn:
|
||||||
|
actions.append("User warned")
|
||||||
|
if result.should_strike:
|
||||||
|
actions.append("Strike added")
|
||||||
|
if result.should_timeout:
|
||||||
|
actions.append(f"Timeout ({result.timeout_duration}s)")
|
||||||
|
|
||||||
|
embed.add_field(name="Actions Taken", value=", ".join(actions) or "None", inline=False)
|
||||||
|
embed.set_footer(text=f"User ID: {message.author.id}")
|
||||||
|
|
||||||
|
await channel.send(embed=embed)
|
||||||
|
|
||||||
|
@commands.Cog.listener()
|
||||||
|
async def on_message(self, message: discord.Message) -> None:
|
||||||
|
"""Check all messages for automod violations."""
|
||||||
|
# Ignore DMs, bots, and empty messages
|
||||||
|
if not message.guild or message.author.bot or not message.content:
|
||||||
|
return
|
||||||
|
|
||||||
|
# Ignore users with manage_messages permission
|
||||||
|
if isinstance(message.author, discord.Member):
|
||||||
|
if message.author.guild_permissions.manage_messages:
|
||||||
|
return
|
||||||
|
|
||||||
|
# Get guild config
|
||||||
|
config = await self.bot.guild_config.get_config(message.guild.id)
|
||||||
|
if not config or not config.automod_enabled:
|
||||||
|
return
|
||||||
|
|
||||||
|
result: AutomodResult | None = None
|
||||||
|
|
||||||
|
# Check banned words
|
||||||
|
banned_words = await self.bot.guild_config.get_banned_words(message.guild.id)
|
||||||
|
if banned_words:
|
||||||
|
result = self.automod.check_banned_words(message.content, banned_words)
|
||||||
|
|
||||||
|
# Check scam links (if link filter enabled)
|
||||||
|
if not result and config.link_filter_enabled:
|
||||||
|
result = self.automod.check_scam_links(message.content)
|
||||||
|
|
||||||
|
# Check spam
|
||||||
|
if not result and config.anti_spam_enabled:
|
||||||
|
result = self.automod.check_spam(message, anti_spam_enabled=True)
|
||||||
|
|
||||||
|
# Check invite links (if link filter enabled)
|
||||||
|
if not result and config.link_filter_enabled:
|
||||||
|
result = self.automod.check_invite_links(message.content, allow_invites=False)
|
||||||
|
|
||||||
|
# Handle violation if found
|
||||||
|
if result:
|
||||||
|
logger.info(
|
||||||
|
f"Automod triggered in {message.guild.name}: "
|
||||||
|
f"{result.matched_filter} by {message.author}"
|
||||||
|
)
|
||||||
|
await self._handle_violation(message, result)
|
||||||
|
|
||||||
|
@commands.Cog.listener()
|
||||||
|
async def on_message_edit(self, before: discord.Message, after: discord.Message) -> None:
|
||||||
|
"""Check edited messages for automod violations."""
|
||||||
|
# Only check if content changed
|
||||||
|
if before.content == after.content:
|
||||||
|
return
|
||||||
|
|
||||||
|
# Reuse on_message logic
|
||||||
|
await self.on_message(after)
|
||||||
|
|
||||||
|
@commands.group(name="automod", invoke_without_command=True)
|
||||||
|
@commands.has_permissions(administrator=True)
|
||||||
|
@commands.guild_only()
|
||||||
|
async def automod_cmd(self, ctx: commands.Context) -> None:
|
||||||
|
"""View automod status and configuration."""
|
||||||
|
config = await self.bot.guild_config.get_config(ctx.guild.id)
|
||||||
|
|
||||||
|
embed = discord.Embed(
|
||||||
|
title="Automod Configuration",
|
||||||
|
color=discord.Color.blue(),
|
||||||
|
)
|
||||||
|
|
||||||
|
embed.add_field(
|
||||||
|
name="Automod Enabled",
|
||||||
|
value="✅ Yes" if config and config.automod_enabled else "❌ No",
|
||||||
|
inline=True,
|
||||||
|
)
|
||||||
|
embed.add_field(
|
||||||
|
name="Anti-Spam",
|
||||||
|
value="✅ Yes" if config and config.anti_spam_enabled else "❌ No",
|
||||||
|
inline=True,
|
||||||
|
)
|
||||||
|
embed.add_field(
|
||||||
|
name="Link Filter",
|
||||||
|
value="✅ Yes" if config and config.link_filter_enabled else "❌ No",
|
||||||
|
inline=True,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Show thresholds
|
||||||
|
embed.add_field(
|
||||||
|
name="Rate Limit",
|
||||||
|
value=f"{self.automod.message_rate_limit} msgs / {self.automod.message_rate_window}s",
|
||||||
|
inline=True,
|
||||||
|
)
|
||||||
|
embed.add_field(
|
||||||
|
name="Duplicate Threshold",
|
||||||
|
value=f"{self.automod.duplicate_threshold} same messages",
|
||||||
|
inline=True,
|
||||||
|
)
|
||||||
|
embed.add_field(
|
||||||
|
name="Mention Limit",
|
||||||
|
value=f"{self.automod.mention_limit} per message",
|
||||||
|
inline=True,
|
||||||
|
)
|
||||||
|
|
||||||
|
banned_words = await self.bot.guild_config.get_banned_words(ctx.guild.id)
|
||||||
|
embed.add_field(
|
||||||
|
name="Banned Words",
|
||||||
|
value=f"{len(banned_words)} configured",
|
||||||
|
inline=True,
|
||||||
|
)
|
||||||
|
|
||||||
|
await ctx.send(embed=embed)
|
||||||
|
|
||||||
|
@automod_cmd.command(name="test")
|
||||||
|
@commands.has_permissions(administrator=True)
|
||||||
|
@commands.guild_only()
|
||||||
|
async def automod_test(self, ctx: commands.Context, *, text: str) -> None:
|
||||||
|
"""Test a message against automod filters (does not take action)."""
|
||||||
|
config = await self.bot.guild_config.get_config(ctx.guild.id)
|
||||||
|
results = []
|
||||||
|
|
||||||
|
# Check banned words
|
||||||
|
banned_words = await self.bot.guild_config.get_banned_words(ctx.guild.id)
|
||||||
|
result = self.automod.check_banned_words(text, banned_words)
|
||||||
|
if result:
|
||||||
|
results.append(f"**Banned Words**: {result.reason}")
|
||||||
|
|
||||||
|
# Check scam links
|
||||||
|
result = self.automod.check_scam_links(text)
|
||||||
|
if result:
|
||||||
|
results.append(f"**Scam Detection**: {result.reason}")
|
||||||
|
|
||||||
|
# Check invite links
|
||||||
|
result = self.automod.check_invite_links(text, allow_invites=False)
|
||||||
|
if result:
|
||||||
|
results.append(f"**Invite Links**: {result.reason}")
|
||||||
|
|
||||||
|
# Check caps
|
||||||
|
result = self.automod.check_all_caps(text)
|
||||||
|
if result:
|
||||||
|
results.append(f"**Excessive Caps**: {result.reason}")
|
||||||
|
|
||||||
|
embed = discord.Embed(
|
||||||
|
title="Automod Test Results",
|
||||||
|
color=discord.Color.red() if results else discord.Color.green(),
|
||||||
|
)
|
||||||
|
|
||||||
|
if results:
|
||||||
|
embed.description = "\n".join(results)
|
||||||
|
else:
|
||||||
|
embed.description = "✅ No violations detected"
|
||||||
|
|
||||||
|
await ctx.send(embed=embed)
|
||||||
|
|
||||||
|
|
||||||
|
async def setup(bot: GuardDen) -> None:
|
||||||
|
"""Load the Automod cog."""
|
||||||
|
await bot.add_cog(Automod(bot))
|
||||||
237
src/guardden/cogs/events.py
Normal file
237
src/guardden/cogs/events.py
Normal file
@@ -0,0 +1,237 @@
|
|||||||
|
"""Event handlers for logging and monitoring."""
|
||||||
|
|
||||||
|
import logging
|
||||||
|
from datetime import datetime, timezone
|
||||||
|
|
||||||
|
import discord
|
||||||
|
from discord.ext import commands
|
||||||
|
|
||||||
|
from guardden.bot import GuardDen
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
class Events(commands.Cog):
|
||||||
|
"""Handles Discord events for logging and monitoring."""
|
||||||
|
|
||||||
|
def __init__(self, bot: GuardDen) -> None:
|
||||||
|
self.bot = bot
|
||||||
|
|
||||||
|
@commands.Cog.listener()
|
||||||
|
async def on_member_join(self, member: discord.Member) -> None:
|
||||||
|
"""Called when a member joins a guild."""
|
||||||
|
logger.debug(f"Member joined: {member} in {member.guild}")
|
||||||
|
|
||||||
|
config = await self.bot.guild_config.get_config(member.guild.id)
|
||||||
|
if not config or not config.log_channel_id:
|
||||||
|
return
|
||||||
|
|
||||||
|
channel = member.guild.get_channel(config.log_channel_id)
|
||||||
|
if not channel or not isinstance(channel, discord.TextChannel):
|
||||||
|
return
|
||||||
|
|
||||||
|
embed = discord.Embed(
|
||||||
|
title="Member Joined",
|
||||||
|
description=f"{member.mention} ({member})",
|
||||||
|
color=discord.Color.green(),
|
||||||
|
timestamp=datetime.now(timezone.utc),
|
||||||
|
)
|
||||||
|
embed.set_thumbnail(url=member.display_avatar.url)
|
||||||
|
embed.add_field(
|
||||||
|
name="Account Created", value=discord.utils.format_dt(member.created_at, "R")
|
||||||
|
)
|
||||||
|
embed.add_field(name="Member ID", value=str(member.id))
|
||||||
|
|
||||||
|
await channel.send(embed=embed)
|
||||||
|
|
||||||
|
@commands.Cog.listener()
|
||||||
|
async def on_member_remove(self, member: discord.Member) -> None:
|
||||||
|
"""Called when a member leaves a guild."""
|
||||||
|
logger.debug(f"Member left: {member} from {member.guild}")
|
||||||
|
|
||||||
|
config = await self.bot.guild_config.get_config(member.guild.id)
|
||||||
|
if not config or not config.log_channel_id:
|
||||||
|
return
|
||||||
|
|
||||||
|
channel = member.guild.get_channel(config.log_channel_id)
|
||||||
|
if not channel or not isinstance(channel, discord.TextChannel):
|
||||||
|
return
|
||||||
|
|
||||||
|
embed = discord.Embed(
|
||||||
|
title="Member Left",
|
||||||
|
description=f"{member} ({member.id})",
|
||||||
|
color=discord.Color.orange(),
|
||||||
|
timestamp=datetime.now(timezone.utc),
|
||||||
|
)
|
||||||
|
embed.set_thumbnail(url=member.display_avatar.url)
|
||||||
|
|
||||||
|
if member.joined_at:
|
||||||
|
embed.add_field(name="Joined", value=discord.utils.format_dt(member.joined_at, "R"))
|
||||||
|
|
||||||
|
roles = [r.mention for r in member.roles if r != member.guild.default_role]
|
||||||
|
if roles:
|
||||||
|
embed.add_field(name="Roles", value=", ".join(roles[:10]), inline=False)
|
||||||
|
|
||||||
|
await channel.send(embed=embed)
|
||||||
|
|
||||||
|
@commands.Cog.listener()
|
||||||
|
async def on_message_delete(self, message: discord.Message) -> None:
|
||||||
|
"""Called when a message is deleted."""
|
||||||
|
if message.author.bot or not message.guild:
|
||||||
|
return
|
||||||
|
|
||||||
|
config = await self.bot.guild_config.get_config(message.guild.id)
|
||||||
|
if not config or not config.log_channel_id:
|
||||||
|
return
|
||||||
|
|
||||||
|
channel = message.guild.get_channel(config.log_channel_id)
|
||||||
|
if not channel or not isinstance(channel, discord.TextChannel):
|
||||||
|
return
|
||||||
|
|
||||||
|
embed = discord.Embed(
|
||||||
|
title="Message Deleted",
|
||||||
|
description=f"In {message.channel.mention}",
|
||||||
|
color=discord.Color.red(),
|
||||||
|
timestamp=datetime.now(timezone.utc),
|
||||||
|
)
|
||||||
|
embed.set_author(name=str(message.author), icon_url=message.author.display_avatar.url)
|
||||||
|
|
||||||
|
if message.content:
|
||||||
|
content = message.content[:1024] if len(message.content) > 1024 else message.content
|
||||||
|
embed.add_field(name="Content", value=content, inline=False)
|
||||||
|
|
||||||
|
if message.attachments:
|
||||||
|
attachments = "\n".join(a.filename for a in message.attachments)
|
||||||
|
embed.add_field(name="Attachments", value=attachments, inline=False)
|
||||||
|
|
||||||
|
embed.set_footer(text=f"Author ID: {message.author.id} | Message ID: {message.id}")
|
||||||
|
|
||||||
|
await channel.send(embed=embed)
|
||||||
|
|
||||||
|
@commands.Cog.listener()
|
||||||
|
async def on_message_edit(self, before: discord.Message, after: discord.Message) -> None:
|
||||||
|
"""Called when a message is edited."""
|
||||||
|
if before.author.bot or not before.guild:
|
||||||
|
return
|
||||||
|
|
||||||
|
if before.content == after.content:
|
||||||
|
return
|
||||||
|
|
||||||
|
config = await self.bot.guild_config.get_config(before.guild.id)
|
||||||
|
if not config or not config.log_channel_id:
|
||||||
|
return
|
||||||
|
|
||||||
|
channel = before.guild.get_channel(config.log_channel_id)
|
||||||
|
if not channel or not isinstance(channel, discord.TextChannel):
|
||||||
|
return
|
||||||
|
|
||||||
|
embed = discord.Embed(
|
||||||
|
title="Message Edited",
|
||||||
|
description=f"In {before.channel.mention} | [Jump to message]({after.jump_url})",
|
||||||
|
color=discord.Color.blue(),
|
||||||
|
timestamp=datetime.now(timezone.utc),
|
||||||
|
)
|
||||||
|
embed.set_author(name=str(before.author), icon_url=before.author.display_avatar.url)
|
||||||
|
|
||||||
|
before_content = before.content[:1024] if len(before.content) > 1024 else before.content
|
||||||
|
after_content = after.content[:1024] if len(after.content) > 1024 else after.content
|
||||||
|
|
||||||
|
embed.add_field(name="Before", value=before_content or "*empty*", inline=False)
|
||||||
|
embed.add_field(name="After", value=after_content or "*empty*", inline=False)
|
||||||
|
embed.set_footer(text=f"Author ID: {before.author.id}")
|
||||||
|
|
||||||
|
await channel.send(embed=embed)
|
||||||
|
|
||||||
|
@commands.Cog.listener()
|
||||||
|
async def on_voice_state_update(
|
||||||
|
self,
|
||||||
|
member: discord.Member,
|
||||||
|
before: discord.VoiceState,
|
||||||
|
after: discord.VoiceState,
|
||||||
|
) -> None:
|
||||||
|
"""Called when a member's voice state changes."""
|
||||||
|
if member.bot:
|
||||||
|
return
|
||||||
|
|
||||||
|
config = await self.bot.guild_config.get_config(member.guild.id)
|
||||||
|
if not config or not config.log_channel_id:
|
||||||
|
return
|
||||||
|
|
||||||
|
channel = member.guild.get_channel(config.log_channel_id)
|
||||||
|
if not channel or not isinstance(channel, discord.TextChannel):
|
||||||
|
return
|
||||||
|
|
||||||
|
embed = None
|
||||||
|
|
||||||
|
if before.channel is None and after.channel is not None:
|
||||||
|
embed = discord.Embed(
|
||||||
|
title="Voice Channel Joined",
|
||||||
|
description=f"{member.mention} joined {after.channel.mention}",
|
||||||
|
color=discord.Color.green(),
|
||||||
|
timestamp=datetime.now(timezone.utc),
|
||||||
|
)
|
||||||
|
elif before.channel is not None and after.channel is None:
|
||||||
|
embed = discord.Embed(
|
||||||
|
title="Voice Channel Left",
|
||||||
|
description=f"{member.mention} left {before.channel.mention}",
|
||||||
|
color=discord.Color.orange(),
|
||||||
|
timestamp=datetime.now(timezone.utc),
|
||||||
|
)
|
||||||
|
elif before.channel != after.channel and before.channel and after.channel:
|
||||||
|
embed = discord.Embed(
|
||||||
|
title="Voice Channel Moved",
|
||||||
|
description=f"{member.mention} moved from {before.channel.mention} to {after.channel.mention}",
|
||||||
|
color=discord.Color.blue(),
|
||||||
|
timestamp=datetime.now(timezone.utc),
|
||||||
|
)
|
||||||
|
|
||||||
|
if embed:
|
||||||
|
embed.set_author(name=str(member), icon_url=member.display_avatar.url)
|
||||||
|
await channel.send(embed=embed)
|
||||||
|
|
||||||
|
@commands.Cog.listener()
|
||||||
|
async def on_member_ban(self, guild: discord.Guild, user: discord.User) -> None:
|
||||||
|
"""Called when a user is banned."""
|
||||||
|
config = await self.bot.guild_config.get_config(guild.id)
|
||||||
|
if not config or not config.mod_log_channel_id:
|
||||||
|
return
|
||||||
|
|
||||||
|
channel = guild.get_channel(config.mod_log_channel_id)
|
||||||
|
if not channel or not isinstance(channel, discord.TextChannel):
|
||||||
|
return
|
||||||
|
|
||||||
|
embed = discord.Embed(
|
||||||
|
title="Member Banned",
|
||||||
|
description=f"{user} ({user.id})",
|
||||||
|
color=discord.Color.dark_red(),
|
||||||
|
timestamp=datetime.now(timezone.utc),
|
||||||
|
)
|
||||||
|
embed.set_thumbnail(url=user.display_avatar.url)
|
||||||
|
|
||||||
|
await channel.send(embed=embed)
|
||||||
|
|
||||||
|
@commands.Cog.listener()
|
||||||
|
async def on_member_unban(self, guild: discord.Guild, user: discord.User) -> None:
|
||||||
|
"""Called when a user is unbanned."""
|
||||||
|
config = await self.bot.guild_config.get_config(guild.id)
|
||||||
|
if not config or not config.mod_log_channel_id:
|
||||||
|
return
|
||||||
|
|
||||||
|
channel = guild.get_channel(config.mod_log_channel_id)
|
||||||
|
if not channel or not isinstance(channel, discord.TextChannel):
|
||||||
|
return
|
||||||
|
|
||||||
|
embed = discord.Embed(
|
||||||
|
title="Member Unbanned",
|
||||||
|
description=f"{user} ({user.id})",
|
||||||
|
color=discord.Color.green(),
|
||||||
|
timestamp=datetime.now(timezone.utc),
|
||||||
|
)
|
||||||
|
embed.set_thumbnail(url=user.display_avatar.url)
|
||||||
|
|
||||||
|
await channel.send(embed=embed)
|
||||||
|
|
||||||
|
|
||||||
|
async def setup(bot: GuardDen) -> None:
|
||||||
|
"""Load the Events cog."""
|
||||||
|
await bot.add_cog(Events(bot))
|
||||||
466
src/guardden/cogs/moderation.py
Normal file
466
src/guardden/cogs/moderation.py
Normal file
@@ -0,0 +1,466 @@
|
|||||||
|
"""Moderation commands and automod features."""
|
||||||
|
|
||||||
|
import logging
|
||||||
|
import re
|
||||||
|
from datetime import datetime, timedelta, timezone
|
||||||
|
|
||||||
|
import discord
|
||||||
|
from discord.ext import commands
|
||||||
|
from sqlalchemy import func, select
|
||||||
|
|
||||||
|
from guardden.bot import GuardDen
|
||||||
|
from guardden.models import ModerationLog, Strike
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
def parse_duration(duration_str: str) -> timedelta | None:
|
||||||
|
"""Parse a duration string like '1h', '30m', '7d' into a timedelta."""
|
||||||
|
match = re.match(r"^(\d+)([smhdw])$", duration_str.lower())
|
||||||
|
if not match:
|
||||||
|
return None
|
||||||
|
|
||||||
|
amount = int(match.group(1))
|
||||||
|
unit = match.group(2)
|
||||||
|
|
||||||
|
units = {
|
||||||
|
"s": timedelta(seconds=amount),
|
||||||
|
"m": timedelta(minutes=amount),
|
||||||
|
"h": timedelta(hours=amount),
|
||||||
|
"d": timedelta(days=amount),
|
||||||
|
"w": timedelta(weeks=amount),
|
||||||
|
}
|
||||||
|
|
||||||
|
return units.get(unit)
|
||||||
|
|
||||||
|
|
||||||
|
class Moderation(commands.Cog):
|
||||||
|
"""Moderation commands for server management."""
|
||||||
|
|
||||||
|
def __init__(self, bot: GuardDen) -> None:
|
||||||
|
self.bot = bot
|
||||||
|
|
||||||
|
async def _log_action(
|
||||||
|
self,
|
||||||
|
guild: discord.Guild,
|
||||||
|
target: discord.Member | discord.User,
|
||||||
|
moderator: discord.Member | discord.User,
|
||||||
|
action: str,
|
||||||
|
reason: str | None = None,
|
||||||
|
duration: int | None = None,
|
||||||
|
channel: discord.TextChannel | None = None,
|
||||||
|
message: discord.Message | None = None,
|
||||||
|
is_automatic: bool = False,
|
||||||
|
) -> None:
|
||||||
|
"""Log a moderation action to the database."""
|
||||||
|
expires_at = None
|
||||||
|
if duration:
|
||||||
|
expires_at = datetime.now(timezone.utc) + timedelta(seconds=duration)
|
||||||
|
|
||||||
|
async with self.bot.database.session() as session:
|
||||||
|
log_entry = ModerationLog(
|
||||||
|
guild_id=guild.id,
|
||||||
|
target_id=target.id,
|
||||||
|
target_name=str(target),
|
||||||
|
moderator_id=moderator.id,
|
||||||
|
moderator_name=str(moderator),
|
||||||
|
action=action,
|
||||||
|
reason=reason,
|
||||||
|
duration=duration,
|
||||||
|
expires_at=expires_at,
|
||||||
|
channel_id=channel.id if channel else None,
|
||||||
|
message_id=message.id if message else None,
|
||||||
|
message_content=message.content if message else None,
|
||||||
|
is_automatic=is_automatic,
|
||||||
|
)
|
||||||
|
session.add(log_entry)
|
||||||
|
|
||||||
|
async def _get_strike_count(self, guild_id: int, user_id: int) -> int:
|
||||||
|
"""Get the total active strike count for a user."""
|
||||||
|
async with self.bot.database.session() as session:
|
||||||
|
result = await session.execute(
|
||||||
|
select(func.sum(Strike.points)).where(
|
||||||
|
Strike.guild_id == guild_id,
|
||||||
|
Strike.user_id == user_id,
|
||||||
|
Strike.is_active == True,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
total = result.scalar()
|
||||||
|
return total or 0
|
||||||
|
|
||||||
|
async def _add_strike(
|
||||||
|
self,
|
||||||
|
guild: discord.Guild,
|
||||||
|
user: discord.Member,
|
||||||
|
moderator: discord.Member | discord.User,
|
||||||
|
reason: str,
|
||||||
|
points: int = 1,
|
||||||
|
) -> int:
|
||||||
|
"""Add a strike to a user and return their new total."""
|
||||||
|
async with self.bot.database.session() as session:
|
||||||
|
strike = Strike(
|
||||||
|
guild_id=guild.id,
|
||||||
|
user_id=user.id,
|
||||||
|
user_name=str(user),
|
||||||
|
moderator_id=moderator.id,
|
||||||
|
reason=reason,
|
||||||
|
points=points,
|
||||||
|
)
|
||||||
|
session.add(strike)
|
||||||
|
|
||||||
|
return await self._get_strike_count(guild.id, user.id)
|
||||||
|
|
||||||
|
@commands.command(name="warn")
|
||||||
|
@commands.has_permissions(kick_members=True)
|
||||||
|
@commands.guild_only()
|
||||||
|
async def warn(
|
||||||
|
self, ctx: commands.Context, member: discord.Member, *, reason: str = "No reason provided"
|
||||||
|
) -> None:
|
||||||
|
"""Warn a member."""
|
||||||
|
if member.top_role >= ctx.author.top_role and ctx.author != ctx.guild.owner:
|
||||||
|
await ctx.send("You cannot warn someone with a higher or equal role.")
|
||||||
|
return
|
||||||
|
|
||||||
|
await self._log_action(ctx.guild, member, ctx.author, "warn", reason)
|
||||||
|
|
||||||
|
embed = discord.Embed(
|
||||||
|
title="Warning Issued",
|
||||||
|
description=f"{member.mention} has been warned.",
|
||||||
|
color=discord.Color.yellow(),
|
||||||
|
timestamp=datetime.now(timezone.utc),
|
||||||
|
)
|
||||||
|
embed.add_field(name="Reason", value=reason, inline=False)
|
||||||
|
embed.set_footer(text=f"Moderator: {ctx.author}")
|
||||||
|
|
||||||
|
await ctx.send(embed=embed)
|
||||||
|
|
||||||
|
# Try to DM the user
|
||||||
|
try:
|
||||||
|
dm_embed = discord.Embed(
|
||||||
|
title=f"Warning in {ctx.guild.name}",
|
||||||
|
description=f"You have been warned.",
|
||||||
|
color=discord.Color.yellow(),
|
||||||
|
)
|
||||||
|
dm_embed.add_field(name="Reason", value=reason)
|
||||||
|
await member.send(embed=dm_embed)
|
||||||
|
except discord.Forbidden:
|
||||||
|
pass
|
||||||
|
|
||||||
|
@commands.command(name="strike")
|
||||||
|
@commands.has_permissions(kick_members=True)
|
||||||
|
@commands.guild_only()
|
||||||
|
async def strike(
|
||||||
|
self,
|
||||||
|
ctx: commands.Context,
|
||||||
|
member: discord.Member,
|
||||||
|
points: int = 1,
|
||||||
|
*,
|
||||||
|
reason: str = "No reason provided",
|
||||||
|
) -> None:
|
||||||
|
"""Add a strike to a member."""
|
||||||
|
if member.top_role >= ctx.author.top_role and ctx.author != ctx.guild.owner:
|
||||||
|
await ctx.send("You cannot strike someone with a higher or equal role.")
|
||||||
|
return
|
||||||
|
|
||||||
|
total_strikes = await self._add_strike(ctx.guild, member, ctx.author, reason, points)
|
||||||
|
await self._log_action(ctx.guild, member, ctx.author, "strike", reason)
|
||||||
|
|
||||||
|
embed = discord.Embed(
|
||||||
|
title="Strike Added",
|
||||||
|
description=f"{member.mention} has received {points} strike(s).",
|
||||||
|
color=discord.Color.orange(),
|
||||||
|
timestamp=datetime.now(timezone.utc),
|
||||||
|
)
|
||||||
|
embed.add_field(name="Reason", value=reason, inline=False)
|
||||||
|
embed.add_field(name="Total Strikes", value=str(total_strikes))
|
||||||
|
embed.set_footer(text=f"Moderator: {ctx.author}")
|
||||||
|
|
||||||
|
await ctx.send(embed=embed)
|
||||||
|
|
||||||
|
# Check for automatic actions based on strike thresholds
|
||||||
|
config = await self.bot.guild_config.get_config(ctx.guild.id)
|
||||||
|
if config and config.strike_actions:
|
||||||
|
for threshold, action_config in sorted(
|
||||||
|
config.strike_actions.items(), key=lambda x: int(x[0]), reverse=True
|
||||||
|
):
|
||||||
|
if total_strikes >= int(threshold):
|
||||||
|
action = action_config.get("action")
|
||||||
|
if action == "ban":
|
||||||
|
await ctx.invoke(
|
||||||
|
self.ban, member=member, reason=f"Automatic: {total_strikes} strikes"
|
||||||
|
)
|
||||||
|
elif action == "kick":
|
||||||
|
await ctx.invoke(
|
||||||
|
self.kick, member=member, reason=f"Automatic: {total_strikes} strikes"
|
||||||
|
)
|
||||||
|
elif action == "timeout":
|
||||||
|
duration = action_config.get("duration", 3600)
|
||||||
|
await ctx.invoke(
|
||||||
|
self.timeout,
|
||||||
|
member=member,
|
||||||
|
duration=f"{duration}s",
|
||||||
|
reason=f"Automatic: {total_strikes} strikes",
|
||||||
|
)
|
||||||
|
break
|
||||||
|
|
||||||
|
@commands.command(name="strikes")
|
||||||
|
@commands.has_permissions(kick_members=True)
|
||||||
|
@commands.guild_only()
|
||||||
|
async def strikes(self, ctx: commands.Context, member: discord.Member) -> None:
|
||||||
|
"""View strikes for a member."""
|
||||||
|
async with self.bot.database.session() as session:
|
||||||
|
result = await session.execute(
|
||||||
|
select(Strike)
|
||||||
|
.where(
|
||||||
|
Strike.guild_id == ctx.guild.id,
|
||||||
|
Strike.user_id == member.id,
|
||||||
|
Strike.is_active == True,
|
||||||
|
)
|
||||||
|
.order_by(Strike.created_at.desc())
|
||||||
|
.limit(10)
|
||||||
|
)
|
||||||
|
user_strikes = result.scalars().all()
|
||||||
|
|
||||||
|
total = await self._get_strike_count(ctx.guild.id, member.id)
|
||||||
|
|
||||||
|
embed = discord.Embed(
|
||||||
|
title=f"Strikes for {member}",
|
||||||
|
description=f"Total active strikes: **{total}**",
|
||||||
|
color=discord.Color.orange(),
|
||||||
|
)
|
||||||
|
|
||||||
|
if user_strikes:
|
||||||
|
for strike in user_strikes:
|
||||||
|
embed.add_field(
|
||||||
|
name=f"Strike #{strike.id} ({strike.points} pts)",
|
||||||
|
value=f"{strike.reason}\n*{strike.created_at.strftime('%Y-%m-%d')}*",
|
||||||
|
inline=False,
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
embed.description = f"{member.mention} has no active strikes."
|
||||||
|
|
||||||
|
await ctx.send(embed=embed)
|
||||||
|
|
||||||
|
@commands.command(name="timeout", aliases=["mute"])
|
||||||
|
@commands.has_permissions(moderate_members=True)
|
||||||
|
@commands.guild_only()
|
||||||
|
async def timeout(
|
||||||
|
self,
|
||||||
|
ctx: commands.Context,
|
||||||
|
member: discord.Member,
|
||||||
|
duration: str = "1h",
|
||||||
|
*,
|
||||||
|
reason: str = "No reason provided",
|
||||||
|
) -> None:
|
||||||
|
"""Timeout a member (e.g., !timeout @user 1h Spamming)."""
|
||||||
|
if member.top_role >= ctx.author.top_role and ctx.author != ctx.guild.owner:
|
||||||
|
await ctx.send("You cannot timeout someone with a higher or equal role.")
|
||||||
|
return
|
||||||
|
|
||||||
|
delta = parse_duration(duration)
|
||||||
|
if not delta:
|
||||||
|
await ctx.send("Invalid duration. Use format like: 30m, 1h, 7d")
|
||||||
|
return
|
||||||
|
|
||||||
|
if delta > timedelta(days=28):
|
||||||
|
await ctx.send("Timeout duration cannot exceed 28 days.")
|
||||||
|
return
|
||||||
|
|
||||||
|
try:
|
||||||
|
await member.timeout(delta, reason=f"{ctx.author}: {reason}")
|
||||||
|
except discord.Forbidden:
|
||||||
|
await ctx.send("I don't have permission to timeout this user.")
|
||||||
|
return
|
||||||
|
except discord.HTTPException as e:
|
||||||
|
await ctx.send(f"Failed to timeout user: {e}")
|
||||||
|
return
|
||||||
|
|
||||||
|
await self._log_action(
|
||||||
|
ctx.guild, member, ctx.author, "timeout", reason, int(delta.total_seconds())
|
||||||
|
)
|
||||||
|
|
||||||
|
embed = discord.Embed(
|
||||||
|
title="Member Timed Out",
|
||||||
|
description=f"{member.mention} has been timed out for {duration}.",
|
||||||
|
color=discord.Color.orange(),
|
||||||
|
timestamp=datetime.now(timezone.utc),
|
||||||
|
)
|
||||||
|
embed.add_field(name="Reason", value=reason, inline=False)
|
||||||
|
embed.set_footer(text=f"Moderator: {ctx.author}")
|
||||||
|
|
||||||
|
await ctx.send(embed=embed)
|
||||||
|
|
||||||
|
@commands.command(name="untimeout", aliases=["unmute"])
|
||||||
|
@commands.has_permissions(moderate_members=True)
|
||||||
|
@commands.guild_only()
|
||||||
|
async def untimeout(
|
||||||
|
self, ctx: commands.Context, member: discord.Member, *, reason: str = "No reason provided"
|
||||||
|
) -> None:
|
||||||
|
"""Remove timeout from a member."""
|
||||||
|
await member.timeout(None, reason=f"{ctx.author}: {reason}")
|
||||||
|
await self._log_action(ctx.guild, member, ctx.author, "unmute", reason)
|
||||||
|
|
||||||
|
embed = discord.Embed(
|
||||||
|
title="Timeout Removed",
|
||||||
|
description=f"{member.mention}'s timeout has been removed.",
|
||||||
|
color=discord.Color.green(),
|
||||||
|
timestamp=datetime.now(timezone.utc),
|
||||||
|
)
|
||||||
|
embed.add_field(name="Reason", value=reason, inline=False)
|
||||||
|
embed.set_footer(text=f"Moderator: {ctx.author}")
|
||||||
|
|
||||||
|
await ctx.send(embed=embed)
|
||||||
|
|
||||||
|
@commands.command(name="kick")
|
||||||
|
@commands.has_permissions(kick_members=True)
|
||||||
|
@commands.guild_only()
|
||||||
|
async def kick(
|
||||||
|
self, ctx: commands.Context, member: discord.Member, *, reason: str = "No reason provided"
|
||||||
|
) -> None:
|
||||||
|
"""Kick a member from the server."""
|
||||||
|
if member.top_role >= ctx.author.top_role and ctx.author != ctx.guild.owner:
|
||||||
|
await ctx.send("You cannot kick someone with a higher or equal role.")
|
||||||
|
return
|
||||||
|
|
||||||
|
# Try to DM the user before kicking
|
||||||
|
try:
|
||||||
|
dm_embed = discord.Embed(
|
||||||
|
title=f"Kicked from {ctx.guild.name}",
|
||||||
|
description=f"You have been kicked from the server.",
|
||||||
|
color=discord.Color.red(),
|
||||||
|
)
|
||||||
|
dm_embed.add_field(name="Reason", value=reason)
|
||||||
|
await member.send(embed=dm_embed)
|
||||||
|
except discord.Forbidden:
|
||||||
|
pass
|
||||||
|
|
||||||
|
await member.kick(reason=f"{ctx.author}: {reason}")
|
||||||
|
await self._log_action(ctx.guild, member, ctx.author, "kick", reason)
|
||||||
|
|
||||||
|
embed = discord.Embed(
|
||||||
|
title="Member Kicked",
|
||||||
|
description=f"{member} has been kicked from the server.",
|
||||||
|
color=discord.Color.red(),
|
||||||
|
timestamp=datetime.now(timezone.utc),
|
||||||
|
)
|
||||||
|
embed.add_field(name="Reason", value=reason, inline=False)
|
||||||
|
embed.set_footer(text=f"Moderator: {ctx.author}")
|
||||||
|
|
||||||
|
await ctx.send(embed=embed)
|
||||||
|
|
||||||
|
@commands.command(name="ban")
|
||||||
|
@commands.has_permissions(ban_members=True)
|
||||||
|
@commands.guild_only()
|
||||||
|
async def ban(
|
||||||
|
self,
|
||||||
|
ctx: commands.Context,
|
||||||
|
member: discord.Member | discord.User,
|
||||||
|
*,
|
||||||
|
reason: str = "No reason provided",
|
||||||
|
) -> None:
|
||||||
|
"""Ban a member from the server."""
|
||||||
|
if isinstance(member, discord.Member):
|
||||||
|
if member.top_role >= ctx.author.top_role and ctx.author != ctx.guild.owner:
|
||||||
|
await ctx.send("You cannot ban someone with a higher or equal role.")
|
||||||
|
return
|
||||||
|
|
||||||
|
# Try to DM the user before banning
|
||||||
|
try:
|
||||||
|
dm_embed = discord.Embed(
|
||||||
|
title=f"Banned from {ctx.guild.name}",
|
||||||
|
description=f"You have been banned from the server.",
|
||||||
|
color=discord.Color.dark_red(),
|
||||||
|
)
|
||||||
|
dm_embed.add_field(name="Reason", value=reason)
|
||||||
|
await member.send(embed=dm_embed)
|
||||||
|
except discord.Forbidden:
|
||||||
|
pass
|
||||||
|
|
||||||
|
await ctx.guild.ban(member, reason=f"{ctx.author}: {reason}", delete_message_days=0)
|
||||||
|
await self._log_action(ctx.guild, member, ctx.author, "ban", reason)
|
||||||
|
|
||||||
|
embed = discord.Embed(
|
||||||
|
title="Member Banned",
|
||||||
|
description=f"{member} has been banned from the server.",
|
||||||
|
color=discord.Color.dark_red(),
|
||||||
|
timestamp=datetime.now(timezone.utc),
|
||||||
|
)
|
||||||
|
embed.add_field(name="Reason", value=reason, inline=False)
|
||||||
|
embed.set_footer(text=f"Moderator: {ctx.author}")
|
||||||
|
|
||||||
|
await ctx.send(embed=embed)
|
||||||
|
|
||||||
|
@commands.command(name="unban")
|
||||||
|
@commands.has_permissions(ban_members=True)
|
||||||
|
@commands.guild_only()
|
||||||
|
async def unban(
|
||||||
|
self, ctx: commands.Context, user_id: int, *, reason: str = "No reason provided"
|
||||||
|
) -> None:
|
||||||
|
"""Unban a user by their ID."""
|
||||||
|
try:
|
||||||
|
user = await self.bot.fetch_user(user_id)
|
||||||
|
await ctx.guild.unban(user, reason=f"{ctx.author}: {reason}")
|
||||||
|
await self._log_action(ctx.guild, user, ctx.author, "unban", reason)
|
||||||
|
|
||||||
|
embed = discord.Embed(
|
||||||
|
title="User Unbanned",
|
||||||
|
description=f"{user} has been unbanned.",
|
||||||
|
color=discord.Color.green(),
|
||||||
|
timestamp=datetime.now(timezone.utc),
|
||||||
|
)
|
||||||
|
embed.add_field(name="Reason", value=reason, inline=False)
|
||||||
|
embed.set_footer(text=f"Moderator: {ctx.author}")
|
||||||
|
|
||||||
|
await ctx.send(embed=embed)
|
||||||
|
|
||||||
|
except discord.NotFound:
|
||||||
|
await ctx.send("User not found or not banned.")
|
||||||
|
except discord.Forbidden:
|
||||||
|
await ctx.send("I don't have permission to unban this user.")
|
||||||
|
|
||||||
|
@commands.command(name="purge", aliases=["clear"])
|
||||||
|
@commands.has_permissions(manage_messages=True)
|
||||||
|
@commands.guild_only()
|
||||||
|
async def purge(self, ctx: commands.Context, amount: int) -> None:
|
||||||
|
"""Delete multiple messages at once (max 100)."""
|
||||||
|
if amount < 1 or amount > 100:
|
||||||
|
await ctx.send("Please specify a number between 1 and 100.")
|
||||||
|
return
|
||||||
|
|
||||||
|
deleted = await ctx.channel.purge(limit=amount + 1) # +1 to include the command message
|
||||||
|
|
||||||
|
msg = await ctx.send(f"Deleted {len(deleted) - 1} message(s).")
|
||||||
|
await msg.delete(delay=3)
|
||||||
|
|
||||||
|
@commands.command(name="modlogs", aliases=["history"])
|
||||||
|
@commands.has_permissions(kick_members=True)
|
||||||
|
@commands.guild_only()
|
||||||
|
async def modlogs(self, ctx: commands.Context, member: discord.Member | discord.User) -> None:
|
||||||
|
"""View moderation history for a user."""
|
||||||
|
async with self.bot.database.session() as session:
|
||||||
|
result = await session.execute(
|
||||||
|
select(ModerationLog)
|
||||||
|
.where(ModerationLog.guild_id == ctx.guild.id, ModerationLog.target_id == member.id)
|
||||||
|
.order_by(ModerationLog.created_at.desc())
|
||||||
|
.limit(10)
|
||||||
|
)
|
||||||
|
logs = result.scalars().all()
|
||||||
|
|
||||||
|
embed = discord.Embed(
|
||||||
|
title=f"Moderation History for {member}",
|
||||||
|
color=discord.Color.blue(),
|
||||||
|
)
|
||||||
|
|
||||||
|
if logs:
|
||||||
|
for log in logs:
|
||||||
|
value = f"**Reason:** {log.reason or 'None'}\n**By:** {log.moderator_name}\n*{log.created_at.strftime('%Y-%m-%d %H:%M')}*"
|
||||||
|
embed.add_field(name=f"{log.action.upper()} (#{log.id})", value=value, inline=False)
|
||||||
|
else:
|
||||||
|
embed.description = "No moderation history found."
|
||||||
|
|
||||||
|
await ctx.send(embed=embed)
|
||||||
|
|
||||||
|
|
||||||
|
async def setup(bot: GuardDen) -> None:
|
||||||
|
"""Load the Moderation cog."""
|
||||||
|
await bot.add_cog(Moderation(bot))
|
||||||
423
src/guardden/cogs/verification.py
Normal file
423
src/guardden/cogs/verification.py
Normal file
@@ -0,0 +1,423 @@
|
|||||||
|
"""Verification cog for new member verification."""
|
||||||
|
|
||||||
|
import logging
|
||||||
|
from datetime import datetime, timezone
|
||||||
|
|
||||||
|
import discord
|
||||||
|
from discord import ui
|
||||||
|
from discord.ext import commands, tasks
|
||||||
|
|
||||||
|
from guardden.bot import GuardDen
|
||||||
|
from guardden.services.verification import (
|
||||||
|
ChallengeType,
|
||||||
|
PendingVerification,
|
||||||
|
VerificationService,
|
||||||
|
)
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
class VerifyButton(ui.Button["VerificationView"]):
|
||||||
|
"""Button for simple verification."""
|
||||||
|
|
||||||
|
def __init__(self) -> None:
|
||||||
|
super().__init__(
|
||||||
|
style=discord.ButtonStyle.success,
|
||||||
|
label="Verify",
|
||||||
|
custom_id="verify_button",
|
||||||
|
)
|
||||||
|
|
||||||
|
async def callback(self, interaction: discord.Interaction) -> None:
|
||||||
|
if self.view is None:
|
||||||
|
return
|
||||||
|
|
||||||
|
success, message = await self.view.cog.complete_verification(
|
||||||
|
interaction.guild.id,
|
||||||
|
interaction.user.id,
|
||||||
|
"verified",
|
||||||
|
)
|
||||||
|
|
||||||
|
if success:
|
||||||
|
await interaction.response.send_message(message, ephemeral=True)
|
||||||
|
# Disable the button
|
||||||
|
self.disabled = True
|
||||||
|
self.label = "Verified"
|
||||||
|
await interaction.message.edit(view=self.view)
|
||||||
|
else:
|
||||||
|
await interaction.response.send_message(message, ephemeral=True)
|
||||||
|
|
||||||
|
|
||||||
|
class EmojiButton(ui.Button["EmojiVerificationView"]):
|
||||||
|
"""Button for emoji selection verification."""
|
||||||
|
|
||||||
|
def __init__(self, emoji: str, row: int = 0) -> None:
|
||||||
|
super().__init__(
|
||||||
|
style=discord.ButtonStyle.secondary,
|
||||||
|
label=emoji,
|
||||||
|
custom_id=f"emoji_{emoji}",
|
||||||
|
row=row,
|
||||||
|
)
|
||||||
|
self.emoji_value = emoji
|
||||||
|
|
||||||
|
async def callback(self, interaction: discord.Interaction) -> None:
|
||||||
|
if self.view is None:
|
||||||
|
return
|
||||||
|
|
||||||
|
success, message = await self.view.cog.complete_verification(
|
||||||
|
interaction.guild.id,
|
||||||
|
interaction.user.id,
|
||||||
|
self.emoji_value,
|
||||||
|
)
|
||||||
|
|
||||||
|
if success:
|
||||||
|
await interaction.response.send_message(message, ephemeral=True)
|
||||||
|
# Disable all buttons
|
||||||
|
for item in self.view.children:
|
||||||
|
if isinstance(item, ui.Button):
|
||||||
|
item.disabled = True
|
||||||
|
await interaction.message.edit(view=self.view)
|
||||||
|
else:
|
||||||
|
await interaction.response.send_message(message, ephemeral=True)
|
||||||
|
|
||||||
|
|
||||||
|
class VerificationView(ui.View):
|
||||||
|
"""View for button verification."""
|
||||||
|
|
||||||
|
def __init__(self, cog: "Verification", timeout: float = 600) -> None:
|
||||||
|
super().__init__(timeout=timeout)
|
||||||
|
self.cog = cog
|
||||||
|
self.add_item(VerifyButton())
|
||||||
|
|
||||||
|
|
||||||
|
class EmojiVerificationView(ui.View):
|
||||||
|
"""View for emoji selection verification."""
|
||||||
|
|
||||||
|
def __init__(self, cog: "Verification", options: list[str], timeout: float = 600) -> None:
|
||||||
|
super().__init__(timeout=timeout)
|
||||||
|
self.cog = cog
|
||||||
|
for i, emoji in enumerate(options):
|
||||||
|
self.add_item(EmojiButton(emoji, row=i // 4))
|
||||||
|
|
||||||
|
|
||||||
|
class CaptchaModal(ui.Modal):
|
||||||
|
"""Modal for captcha/math input."""
|
||||||
|
|
||||||
|
answer = ui.TextInput(
|
||||||
|
label="Your Answer",
|
||||||
|
placeholder="Enter the answer here...",
|
||||||
|
max_length=50,
|
||||||
|
)
|
||||||
|
|
||||||
|
def __init__(self, cog: "Verification", title: str = "Verification") -> None:
|
||||||
|
super().__init__(title=title)
|
||||||
|
self.cog = cog
|
||||||
|
|
||||||
|
async def on_submit(self, interaction: discord.Interaction) -> None:
|
||||||
|
success, message = await self.cog.complete_verification(
|
||||||
|
interaction.guild.id,
|
||||||
|
interaction.user.id,
|
||||||
|
self.answer.value,
|
||||||
|
)
|
||||||
|
await interaction.response.send_message(message, ephemeral=True)
|
||||||
|
|
||||||
|
|
||||||
|
class AnswerButton(ui.Button["AnswerView"]):
|
||||||
|
"""Button to open the answer modal."""
|
||||||
|
|
||||||
|
def __init__(self) -> None:
|
||||||
|
super().__init__(
|
||||||
|
style=discord.ButtonStyle.primary,
|
||||||
|
label="Submit Answer",
|
||||||
|
custom_id="submit_answer",
|
||||||
|
)
|
||||||
|
|
||||||
|
async def callback(self, interaction: discord.Interaction) -> None:
|
||||||
|
if self.view is None:
|
||||||
|
return
|
||||||
|
modal = CaptchaModal(self.view.cog)
|
||||||
|
await interaction.response.send_modal(modal)
|
||||||
|
|
||||||
|
|
||||||
|
class AnswerView(ui.View):
|
||||||
|
"""View with button to open answer modal."""
|
||||||
|
|
||||||
|
def __init__(self, cog: "Verification", timeout: float = 600) -> None:
|
||||||
|
super().__init__(timeout=timeout)
|
||||||
|
self.cog = cog
|
||||||
|
self.add_item(AnswerButton())
|
||||||
|
|
||||||
|
|
||||||
|
class Verification(commands.Cog):
|
||||||
|
"""Member verification system."""
|
||||||
|
|
||||||
|
def __init__(self, bot: GuardDen) -> None:
|
||||||
|
self.bot = bot
|
||||||
|
self.service = VerificationService()
|
||||||
|
self.cleanup_task.start()
|
||||||
|
|
||||||
|
def cog_unload(self) -> None:
|
||||||
|
self.cleanup_task.cancel()
|
||||||
|
|
||||||
|
@tasks.loop(minutes=5)
|
||||||
|
async def cleanup_task(self) -> None:
|
||||||
|
"""Periodically clean up expired verifications."""
|
||||||
|
count = self.service.cleanup_expired()
|
||||||
|
if count > 0:
|
||||||
|
logger.debug(f"Cleaned up {count} expired verifications")
|
||||||
|
|
||||||
|
@cleanup_task.before_loop
|
||||||
|
async def before_cleanup(self) -> None:
|
||||||
|
await self.bot.wait_until_ready()
|
||||||
|
|
||||||
|
async def complete_verification(
|
||||||
|
self, guild_id: int, user_id: int, response: str
|
||||||
|
) -> tuple[bool, str]:
|
||||||
|
"""Complete a verification and assign role if successful."""
|
||||||
|
success, message = self.service.verify(guild_id, user_id, response)
|
||||||
|
|
||||||
|
if success:
|
||||||
|
# Assign verified role
|
||||||
|
guild = self.bot.get_guild(guild_id)
|
||||||
|
if guild:
|
||||||
|
member = guild.get_member(user_id)
|
||||||
|
config = await self.bot.guild_config.get_config(guild_id)
|
||||||
|
|
||||||
|
if member and config and config.verified_role_id:
|
||||||
|
role = guild.get_role(config.verified_role_id)
|
||||||
|
if role:
|
||||||
|
try:
|
||||||
|
await member.add_roles(role, reason="Verification completed")
|
||||||
|
logger.info(f"Verified {member} in {guild.name}")
|
||||||
|
except discord.Forbidden:
|
||||||
|
logger.warning(f"Cannot assign verified role in {guild.name}")
|
||||||
|
|
||||||
|
return success, message
|
||||||
|
|
||||||
|
async def send_verification(
|
||||||
|
self,
|
||||||
|
member: discord.Member,
|
||||||
|
channel: discord.TextChannel,
|
||||||
|
challenge_type: ChallengeType,
|
||||||
|
) -> None:
|
||||||
|
"""Send a verification challenge to a member."""
|
||||||
|
pending = self.service.create_challenge(
|
||||||
|
user_id=member.id,
|
||||||
|
guild_id=member.guild.id,
|
||||||
|
challenge_type=challenge_type,
|
||||||
|
)
|
||||||
|
|
||||||
|
embed = discord.Embed(
|
||||||
|
title="Verification Required",
|
||||||
|
description=pending.challenge.question,
|
||||||
|
color=discord.Color.blue(),
|
||||||
|
timestamp=datetime.now(timezone.utc),
|
||||||
|
)
|
||||||
|
embed.set_footer(
|
||||||
|
text=f"Expires in 10 minutes • {pending.challenge.max_attempts} attempts allowed"
|
||||||
|
)
|
||||||
|
|
||||||
|
# Create appropriate view based on challenge type
|
||||||
|
if challenge_type == ChallengeType.BUTTON:
|
||||||
|
view = VerificationView(self)
|
||||||
|
elif challenge_type == ChallengeType.EMOJI:
|
||||||
|
view = EmojiVerificationView(self, pending.challenge.options)
|
||||||
|
else:
|
||||||
|
# Captcha or Math - use modal
|
||||||
|
view = AnswerView(self)
|
||||||
|
|
||||||
|
try:
|
||||||
|
# Try to DM the user first
|
||||||
|
dm_channel = await member.create_dm()
|
||||||
|
msg = await dm_channel.send(embed=embed, view=view)
|
||||||
|
pending.message_id = msg.id
|
||||||
|
pending.channel_id = dm_channel.id
|
||||||
|
except discord.Forbidden:
|
||||||
|
# Fall back to channel mention
|
||||||
|
msg = await channel.send(
|
||||||
|
content=member.mention,
|
||||||
|
embed=embed,
|
||||||
|
view=view,
|
||||||
|
)
|
||||||
|
pending.message_id = msg.id
|
||||||
|
pending.channel_id = channel.id
|
||||||
|
|
||||||
|
@commands.Cog.listener()
|
||||||
|
async def on_member_join(self, member: discord.Member) -> None:
|
||||||
|
"""Handle new member joins for verification."""
|
||||||
|
if member.bot:
|
||||||
|
return
|
||||||
|
|
||||||
|
config = await self.bot.guild_config.get_config(member.guild.id)
|
||||||
|
if not config or not config.verification_enabled:
|
||||||
|
return
|
||||||
|
|
||||||
|
# Determine verification channel
|
||||||
|
channel_id = config.welcome_channel_id or config.log_channel_id
|
||||||
|
if not channel_id:
|
||||||
|
return
|
||||||
|
|
||||||
|
channel = member.guild.get_channel(channel_id)
|
||||||
|
if not channel or not isinstance(channel, discord.TextChannel):
|
||||||
|
return
|
||||||
|
|
||||||
|
# Get challenge type from config
|
||||||
|
try:
|
||||||
|
challenge_type = ChallengeType(config.verification_type)
|
||||||
|
except ValueError:
|
||||||
|
challenge_type = ChallengeType.BUTTON
|
||||||
|
|
||||||
|
await self.send_verification(member, channel, challenge_type)
|
||||||
|
|
||||||
|
@commands.group(name="verify", invoke_without_command=True)
|
||||||
|
@commands.guild_only()
|
||||||
|
async def verify_cmd(self, ctx: commands.Context) -> None:
|
||||||
|
"""Request a verification challenge."""
|
||||||
|
config = await self.bot.guild_config.get_config(ctx.guild.id)
|
||||||
|
|
||||||
|
if not config or not config.verification_enabled:
|
||||||
|
await ctx.send("Verification is not enabled on this server.")
|
||||||
|
return
|
||||||
|
|
||||||
|
# Check if already verified
|
||||||
|
if config.verified_role_id:
|
||||||
|
role = ctx.guild.get_role(config.verified_role_id)
|
||||||
|
if role and role in ctx.author.roles:
|
||||||
|
await ctx.send("You are already verified!")
|
||||||
|
return
|
||||||
|
|
||||||
|
# Check for existing pending verification
|
||||||
|
pending = self.service.get_pending(ctx.guild.id, ctx.author.id)
|
||||||
|
if pending and not pending.challenge.is_expired:
|
||||||
|
await ctx.send("You already have a pending verification. Please complete it first.")
|
||||||
|
return
|
||||||
|
|
||||||
|
# Get challenge type
|
||||||
|
try:
|
||||||
|
challenge_type = ChallengeType(config.verification_type)
|
||||||
|
except ValueError:
|
||||||
|
challenge_type = ChallengeType.BUTTON
|
||||||
|
|
||||||
|
await self.send_verification(ctx.author, ctx.channel, challenge_type)
|
||||||
|
await ctx.message.delete(delay=1)
|
||||||
|
|
||||||
|
@verify_cmd.command(name="setup")
|
||||||
|
@commands.has_permissions(administrator=True)
|
||||||
|
@commands.guild_only()
|
||||||
|
async def verify_setup(self, ctx: commands.Context) -> None:
|
||||||
|
"""View verification setup status."""
|
||||||
|
config = await self.bot.guild_config.get_config(ctx.guild.id)
|
||||||
|
|
||||||
|
embed = discord.Embed(
|
||||||
|
title="Verification Setup",
|
||||||
|
color=discord.Color.blue(),
|
||||||
|
)
|
||||||
|
|
||||||
|
embed.add_field(
|
||||||
|
name="Enabled",
|
||||||
|
value="✅ Yes" if config and config.verification_enabled else "❌ No",
|
||||||
|
inline=True,
|
||||||
|
)
|
||||||
|
embed.add_field(
|
||||||
|
name="Type",
|
||||||
|
value=config.verification_type if config else "button",
|
||||||
|
inline=True,
|
||||||
|
)
|
||||||
|
|
||||||
|
if config and config.verified_role_id:
|
||||||
|
role = ctx.guild.get_role(config.verified_role_id)
|
||||||
|
embed.add_field(
|
||||||
|
name="Verified Role",
|
||||||
|
value=role.mention if role else "Not found",
|
||||||
|
inline=True,
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
embed.add_field(name="Verified Role", value="Not set", inline=True)
|
||||||
|
|
||||||
|
pending_count = self.service.get_pending_count(ctx.guild.id)
|
||||||
|
embed.add_field(name="Pending Verifications", value=str(pending_count), inline=True)
|
||||||
|
|
||||||
|
await ctx.send(embed=embed)
|
||||||
|
|
||||||
|
@verify_cmd.command(name="enable")
|
||||||
|
@commands.has_permissions(administrator=True)
|
||||||
|
@commands.guild_only()
|
||||||
|
async def verify_enable(self, ctx: commands.Context) -> None:
|
||||||
|
"""Enable verification for new members."""
|
||||||
|
config = await self.bot.guild_config.get_config(ctx.guild.id)
|
||||||
|
|
||||||
|
if not config or not config.verified_role_id:
|
||||||
|
await ctx.send("Please set a verified role first with `!verify role @role`")
|
||||||
|
return
|
||||||
|
|
||||||
|
await self.bot.guild_config.update_settings(ctx.guild.id, verification_enabled=True)
|
||||||
|
await ctx.send("✅ Verification enabled for new members.")
|
||||||
|
|
||||||
|
@verify_cmd.command(name="disable")
|
||||||
|
@commands.has_permissions(administrator=True)
|
||||||
|
@commands.guild_only()
|
||||||
|
async def verify_disable(self, ctx: commands.Context) -> None:
|
||||||
|
"""Disable verification."""
|
||||||
|
await self.bot.guild_config.update_settings(ctx.guild.id, verification_enabled=False)
|
||||||
|
await ctx.send("❌ Verification disabled.")
|
||||||
|
|
||||||
|
@verify_cmd.command(name="role")
|
||||||
|
@commands.has_permissions(administrator=True)
|
||||||
|
@commands.guild_only()
|
||||||
|
async def verify_role(self, ctx: commands.Context, role: discord.Role) -> None:
|
||||||
|
"""Set the role given upon verification."""
|
||||||
|
await self.bot.guild_config.update_settings(ctx.guild.id, verified_role_id=role.id)
|
||||||
|
await ctx.send(f"Verified role set to {role.mention}")
|
||||||
|
|
||||||
|
@verify_cmd.command(name="type")
|
||||||
|
@commands.has_permissions(administrator=True)
|
||||||
|
@commands.guild_only()
|
||||||
|
async def verify_type(self, ctx: commands.Context, vtype: str) -> None:
|
||||||
|
"""Set verification type (button, captcha, math, emoji)."""
|
||||||
|
try:
|
||||||
|
challenge_type = ChallengeType(vtype.lower())
|
||||||
|
except ValueError:
|
||||||
|
valid = ", ".join(t.value for t in ChallengeType if t != ChallengeType.QUESTIONS)
|
||||||
|
await ctx.send(f"Invalid type. Valid options: {valid}")
|
||||||
|
return
|
||||||
|
|
||||||
|
await self.bot.guild_config.update_settings(
|
||||||
|
ctx.guild.id, verification_type=challenge_type.value
|
||||||
|
)
|
||||||
|
await ctx.send(f"Verification type set to **{challenge_type.value}**")
|
||||||
|
|
||||||
|
@verify_cmd.command(name="test")
|
||||||
|
@commands.has_permissions(administrator=True)
|
||||||
|
@commands.guild_only()
|
||||||
|
async def verify_test(self, ctx: commands.Context, vtype: str = "button") -> None:
|
||||||
|
"""Test verification (sends challenge to you)."""
|
||||||
|
try:
|
||||||
|
challenge_type = ChallengeType(vtype.lower())
|
||||||
|
except ValueError:
|
||||||
|
challenge_type = ChallengeType.BUTTON
|
||||||
|
|
||||||
|
await self.send_verification(ctx.author, ctx.channel, challenge_type)
|
||||||
|
|
||||||
|
@verify_cmd.command(name="reset")
|
||||||
|
@commands.has_permissions(kick_members=True)
|
||||||
|
@commands.guild_only()
|
||||||
|
async def verify_reset(self, ctx: commands.Context, member: discord.Member) -> None:
|
||||||
|
"""Reset verification for a member (remove role and cancel pending)."""
|
||||||
|
# Cancel any pending verification
|
||||||
|
self.service.cancel(ctx.guild.id, member.id)
|
||||||
|
|
||||||
|
# Remove verified role
|
||||||
|
config = await self.bot.guild_config.get_config(ctx.guild.id)
|
||||||
|
if config and config.verified_role_id:
|
||||||
|
role = ctx.guild.get_role(config.verified_role_id)
|
||||||
|
if role and role in member.roles:
|
||||||
|
try:
|
||||||
|
await member.remove_roles(role, reason=f"Verification reset by {ctx.author}")
|
||||||
|
except discord.Forbidden:
|
||||||
|
pass
|
||||||
|
|
||||||
|
await ctx.send(f"Reset verification for {member.mention}")
|
||||||
|
|
||||||
|
|
||||||
|
async def setup(bot: GuardDen) -> None:
|
||||||
|
"""Load the Verification cog."""
|
||||||
|
await bot.add_cog(Verification(bot))
|
||||||
50
src/guardden/config.py
Normal file
50
src/guardden/config.py
Normal file
@@ -0,0 +1,50 @@
|
|||||||
|
"""Configuration management for GuardDen."""
|
||||||
|
|
||||||
|
from pathlib import Path
|
||||||
|
from typing import Literal
|
||||||
|
|
||||||
|
from pydantic import Field, SecretStr
|
||||||
|
from pydantic_settings import BaseSettings, SettingsConfigDict
|
||||||
|
|
||||||
|
|
||||||
|
class Settings(BaseSettings):
|
||||||
|
"""Application settings loaded from environment variables."""
|
||||||
|
|
||||||
|
model_config = SettingsConfigDict(
|
||||||
|
env_file=".env",
|
||||||
|
env_file_encoding="utf-8",
|
||||||
|
case_sensitive=False,
|
||||||
|
env_prefix="GUARDDEN_",
|
||||||
|
)
|
||||||
|
|
||||||
|
# Discord settings
|
||||||
|
discord_token: SecretStr = Field(..., description="Discord bot token")
|
||||||
|
discord_prefix: str = Field(default="!", description="Default command prefix")
|
||||||
|
|
||||||
|
# Database settings
|
||||||
|
database_url: SecretStr = Field(
|
||||||
|
default=SecretStr("postgresql://guardden:guardden@localhost:5432/guardden"),
|
||||||
|
description="PostgreSQL connection URL",
|
||||||
|
)
|
||||||
|
database_pool_min: int = Field(default=5, description="Minimum database pool size")
|
||||||
|
database_pool_max: int = Field(default=20, description="Maximum database pool size")
|
||||||
|
|
||||||
|
# AI settings (optional)
|
||||||
|
ai_provider: Literal["anthropic", "openai", "none"] = Field(
|
||||||
|
default="none", description="AI provider for content moderation"
|
||||||
|
)
|
||||||
|
anthropic_api_key: SecretStr | None = Field(default=None, description="Anthropic API key")
|
||||||
|
openai_api_key: SecretStr | None = Field(default=None, description="OpenAI API key")
|
||||||
|
|
||||||
|
# Logging
|
||||||
|
log_level: Literal["DEBUG", "INFO", "WARNING", "ERROR"] = Field(
|
||||||
|
default="INFO", description="Logging level"
|
||||||
|
)
|
||||||
|
|
||||||
|
# Paths
|
||||||
|
data_dir: Path = Field(default=Path("data"), description="Data directory for persistent files")
|
||||||
|
|
||||||
|
|
||||||
|
def get_settings() -> Settings:
|
||||||
|
"""Get application settings instance."""
|
||||||
|
return Settings()
|
||||||
15
src/guardden/models/__init__.py
Normal file
15
src/guardden/models/__init__.py
Normal file
@@ -0,0 +1,15 @@
|
|||||||
|
"""Database models for GuardDen."""
|
||||||
|
|
||||||
|
from guardden.models.base import Base
|
||||||
|
from guardden.models.guild import BannedWord, Guild, GuildSettings
|
||||||
|
from guardden.models.moderation import ModerationLog, Strike, UserNote
|
||||||
|
|
||||||
|
__all__ = [
|
||||||
|
"Base",
|
||||||
|
"Guild",
|
||||||
|
"GuildSettings",
|
||||||
|
"BannedWord",
|
||||||
|
"ModerationLog",
|
||||||
|
"Strike",
|
||||||
|
"UserNote",
|
||||||
|
]
|
||||||
32
src/guardden/models/base.py
Normal file
32
src/guardden/models/base.py
Normal file
@@ -0,0 +1,32 @@
|
|||||||
|
"""Base model and database utilities."""
|
||||||
|
|
||||||
|
from datetime import datetime
|
||||||
|
|
||||||
|
from sqlalchemy import BigInteger, DateTime, func
|
||||||
|
from sqlalchemy.orm import DeclarativeBase, Mapped, mapped_column
|
||||||
|
|
||||||
|
|
||||||
|
class Base(DeclarativeBase):
|
||||||
|
"""Base class for all database models."""
|
||||||
|
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
class TimestampMixin:
|
||||||
|
"""Mixin that adds created_at and updated_at timestamps."""
|
||||||
|
|
||||||
|
created_at: Mapped[datetime] = mapped_column(
|
||||||
|
DateTime(timezone=True),
|
||||||
|
server_default=func.now(),
|
||||||
|
nullable=False,
|
||||||
|
)
|
||||||
|
updated_at: Mapped[datetime] = mapped_column(
|
||||||
|
DateTime(timezone=True),
|
||||||
|
server_default=func.now(),
|
||||||
|
onupdate=func.now(),
|
||||||
|
nullable=False,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
# Type alias for Discord snowflake IDs (64-bit integers)
|
||||||
|
SnowflakeID = BigInteger
|
||||||
117
src/guardden/models/guild.py
Normal file
117
src/guardden/models/guild.py
Normal file
@@ -0,0 +1,117 @@
|
|||||||
|
"""Guild-related database models."""
|
||||||
|
|
||||||
|
from datetime import datetime
|
||||||
|
from typing import TYPE_CHECKING
|
||||||
|
|
||||||
|
from sqlalchemy import Boolean, ForeignKey, Integer, String, Text
|
||||||
|
from sqlalchemy.dialects.postgresql import JSONB
|
||||||
|
from sqlalchemy.orm import Mapped, mapped_column, relationship
|
||||||
|
|
||||||
|
from guardden.models.base import Base, SnowflakeID, TimestampMixin
|
||||||
|
|
||||||
|
if TYPE_CHECKING:
|
||||||
|
from guardden.models.moderation import ModerationLog, Strike
|
||||||
|
|
||||||
|
|
||||||
|
class Guild(Base, TimestampMixin):
|
||||||
|
"""Represents a Discord guild (server) configuration."""
|
||||||
|
|
||||||
|
__tablename__ = "guilds"
|
||||||
|
|
||||||
|
id: Mapped[int] = mapped_column(SnowflakeID, primary_key=True)
|
||||||
|
name: Mapped[str] = mapped_column(String(100), nullable=False)
|
||||||
|
owner_id: Mapped[int] = mapped_column(SnowflakeID, nullable=False)
|
||||||
|
premium: Mapped[bool] = mapped_column(Boolean, default=False, nullable=False)
|
||||||
|
|
||||||
|
# Relationships
|
||||||
|
settings: Mapped["GuildSettings"] = relationship(
|
||||||
|
back_populates="guild", uselist=False, cascade="all, delete-orphan"
|
||||||
|
)
|
||||||
|
banned_words: Mapped[list["BannedWord"]] = relationship(
|
||||||
|
back_populates="guild", cascade="all, delete-orphan"
|
||||||
|
)
|
||||||
|
moderation_logs: Mapped[list["ModerationLog"]] = relationship(
|
||||||
|
back_populates="guild", cascade="all, delete-orphan"
|
||||||
|
)
|
||||||
|
strikes: Mapped[list["Strike"]] = relationship(
|
||||||
|
back_populates="guild", cascade="all, delete-orphan"
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class GuildSettings(Base, TimestampMixin):
|
||||||
|
"""Per-guild bot settings and configuration."""
|
||||||
|
|
||||||
|
__tablename__ = "guild_settings"
|
||||||
|
|
||||||
|
guild_id: Mapped[int] = mapped_column(
|
||||||
|
SnowflakeID, ForeignKey("guilds.id", ondelete="CASCADE"), primary_key=True
|
||||||
|
)
|
||||||
|
|
||||||
|
# General settings
|
||||||
|
prefix: Mapped[str] = mapped_column(String(10), default="!", nullable=False)
|
||||||
|
locale: Mapped[str] = mapped_column(String(10), default="en", nullable=False)
|
||||||
|
|
||||||
|
# Channel configuration (stored as snowflake IDs)
|
||||||
|
log_channel_id: Mapped[int | None] = mapped_column(SnowflakeID, nullable=True)
|
||||||
|
mod_log_channel_id: Mapped[int | None] = mapped_column(SnowflakeID, nullable=True)
|
||||||
|
welcome_channel_id: Mapped[int | None] = mapped_column(SnowflakeID, nullable=True)
|
||||||
|
|
||||||
|
# Role configuration
|
||||||
|
mute_role_id: Mapped[int | None] = mapped_column(SnowflakeID, nullable=True)
|
||||||
|
verified_role_id: Mapped[int | None] = mapped_column(SnowflakeID, nullable=True)
|
||||||
|
mod_role_ids: Mapped[dict] = mapped_column(JSONB, default=list, nullable=False)
|
||||||
|
|
||||||
|
# Moderation settings
|
||||||
|
automod_enabled: Mapped[bool] = mapped_column(Boolean, default=True, nullable=False)
|
||||||
|
anti_spam_enabled: Mapped[bool] = mapped_column(Boolean, default=True, nullable=False)
|
||||||
|
link_filter_enabled: Mapped[bool] = mapped_column(Boolean, default=False, nullable=False)
|
||||||
|
|
||||||
|
# Strike thresholds (actions at each threshold)
|
||||||
|
strike_actions: Mapped[dict] = mapped_column(
|
||||||
|
JSONB,
|
||||||
|
default=lambda: {
|
||||||
|
"1": {"action": "warn"},
|
||||||
|
"3": {"action": "timeout", "duration": 3600},
|
||||||
|
"5": {"action": "kick"},
|
||||||
|
"7": {"action": "ban"},
|
||||||
|
},
|
||||||
|
nullable=False,
|
||||||
|
)
|
||||||
|
|
||||||
|
# AI moderation settings
|
||||||
|
ai_moderation_enabled: Mapped[bool] = mapped_column(Boolean, default=False, nullable=False)
|
||||||
|
ai_sensitivity: Mapped[int] = mapped_column(Integer, default=50, nullable=False) # 0-100 scale
|
||||||
|
nsfw_detection_enabled: Mapped[bool] = mapped_column(Boolean, default=False, nullable=False)
|
||||||
|
|
||||||
|
# Verification settings
|
||||||
|
verification_enabled: Mapped[bool] = mapped_column(Boolean, default=False, nullable=False)
|
||||||
|
verification_type: Mapped[str] = mapped_column(
|
||||||
|
String(20), default="button", nullable=False
|
||||||
|
) # button, captcha, questions
|
||||||
|
|
||||||
|
# Relationship
|
||||||
|
guild: Mapped["Guild"] = relationship(back_populates="settings")
|
||||||
|
|
||||||
|
|
||||||
|
class BannedWord(Base, TimestampMixin):
|
||||||
|
"""Banned words/phrases for a guild with regex support."""
|
||||||
|
|
||||||
|
__tablename__ = "banned_words"
|
||||||
|
|
||||||
|
id: Mapped[int] = mapped_column(Integer, primary_key=True, autoincrement=True)
|
||||||
|
guild_id: Mapped[int] = mapped_column(
|
||||||
|
SnowflakeID, ForeignKey("guilds.id", ondelete="CASCADE"), nullable=False
|
||||||
|
)
|
||||||
|
|
||||||
|
pattern: Mapped[str] = mapped_column(Text, nullable=False)
|
||||||
|
is_regex: Mapped[bool] = mapped_column(Boolean, default=False, nullable=False)
|
||||||
|
action: Mapped[str] = mapped_column(
|
||||||
|
String(20), default="delete", nullable=False
|
||||||
|
) # delete, warn, strike
|
||||||
|
reason: Mapped[str | None] = mapped_column(Text, nullable=True)
|
||||||
|
|
||||||
|
# Who added this and when
|
||||||
|
added_by: Mapped[int] = mapped_column(SnowflakeID, nullable=False)
|
||||||
|
|
||||||
|
# Relationship
|
||||||
|
guild: Mapped["Guild"] = relationship(back_populates="banned_words")
|
||||||
101
src/guardden/models/moderation.py
Normal file
101
src/guardden/models/moderation.py
Normal file
@@ -0,0 +1,101 @@
|
|||||||
|
"""Moderation-related database models."""
|
||||||
|
|
||||||
|
from datetime import datetime
|
||||||
|
from enum import Enum
|
||||||
|
|
||||||
|
from sqlalchemy import Boolean, DateTime, ForeignKey, Integer, String, Text
|
||||||
|
from sqlalchemy.orm import Mapped, mapped_column, relationship
|
||||||
|
|
||||||
|
from guardden.models.base import Base, SnowflakeID, TimestampMixin
|
||||||
|
from guardden.models.guild import Guild
|
||||||
|
|
||||||
|
|
||||||
|
class ModAction(str, Enum):
|
||||||
|
"""Types of moderation actions."""
|
||||||
|
|
||||||
|
WARN = "warn"
|
||||||
|
TIMEOUT = "timeout"
|
||||||
|
KICK = "kick"
|
||||||
|
BAN = "ban"
|
||||||
|
UNBAN = "unban"
|
||||||
|
UNMUTE = "unmute"
|
||||||
|
NOTE = "note"
|
||||||
|
STRIKE = "strike"
|
||||||
|
DELETE = "delete"
|
||||||
|
|
||||||
|
|
||||||
|
class ModerationLog(Base, TimestampMixin):
|
||||||
|
"""Log of all moderation actions taken."""
|
||||||
|
|
||||||
|
__tablename__ = "moderation_logs"
|
||||||
|
|
||||||
|
id: Mapped[int] = mapped_column(Integer, primary_key=True, autoincrement=True)
|
||||||
|
guild_id: Mapped[int] = mapped_column(
|
||||||
|
SnowflakeID, ForeignKey("guilds.id", ondelete="CASCADE"), nullable=False
|
||||||
|
)
|
||||||
|
|
||||||
|
# Target and moderator
|
||||||
|
target_id: Mapped[int] = mapped_column(SnowflakeID, nullable=False)
|
||||||
|
target_name: Mapped[str] = mapped_column(String(100), nullable=False)
|
||||||
|
moderator_id: Mapped[int] = mapped_column(SnowflakeID, nullable=False)
|
||||||
|
moderator_name: Mapped[str] = mapped_column(String(100), nullable=False)
|
||||||
|
|
||||||
|
# Action details
|
||||||
|
action: Mapped[str] = mapped_column(String(20), nullable=False)
|
||||||
|
reason: Mapped[str | None] = mapped_column(Text, nullable=True)
|
||||||
|
duration: Mapped[int | None] = mapped_column(Integer, nullable=True) # Duration in seconds
|
||||||
|
expires_at: Mapped[datetime | None] = mapped_column(DateTime(timezone=True), nullable=True)
|
||||||
|
|
||||||
|
# Context
|
||||||
|
channel_id: Mapped[int | None] = mapped_column(SnowflakeID, nullable=True)
|
||||||
|
message_id: Mapped[int | None] = mapped_column(SnowflakeID, nullable=True)
|
||||||
|
message_content: Mapped[str | None] = mapped_column(Text, nullable=True)
|
||||||
|
|
||||||
|
# Was this an automatic action?
|
||||||
|
is_automatic: Mapped[bool] = mapped_column(Boolean, default=False, nullable=False)
|
||||||
|
|
||||||
|
# Relationship
|
||||||
|
guild: Mapped["Guild"] = relationship(back_populates="moderation_logs")
|
||||||
|
|
||||||
|
|
||||||
|
class Strike(Base, TimestampMixin):
|
||||||
|
"""User strikes/warnings tracking."""
|
||||||
|
|
||||||
|
__tablename__ = "strikes"
|
||||||
|
|
||||||
|
id: Mapped[int] = mapped_column(Integer, primary_key=True, autoincrement=True)
|
||||||
|
guild_id: Mapped[int] = mapped_column(
|
||||||
|
SnowflakeID, ForeignKey("guilds.id", ondelete="CASCADE"), nullable=False
|
||||||
|
)
|
||||||
|
|
||||||
|
user_id: Mapped[int] = mapped_column(SnowflakeID, nullable=False)
|
||||||
|
user_name: Mapped[str] = mapped_column(String(100), nullable=False)
|
||||||
|
moderator_id: Mapped[int] = mapped_column(SnowflakeID, nullable=False)
|
||||||
|
|
||||||
|
reason: Mapped[str] = mapped_column(Text, nullable=False)
|
||||||
|
points: Mapped[int] = mapped_column(Integer, default=1, nullable=False)
|
||||||
|
|
||||||
|
# Strikes can expire
|
||||||
|
expires_at: Mapped[datetime | None] = mapped_column(DateTime(timezone=True), nullable=True)
|
||||||
|
is_active: Mapped[bool] = mapped_column(Boolean, default=True, nullable=False)
|
||||||
|
|
||||||
|
# Reference to the moderation log entry
|
||||||
|
mod_log_id: Mapped[int | None] = mapped_column(
|
||||||
|
Integer, ForeignKey("moderation_logs.id", ondelete="SET NULL"), nullable=True
|
||||||
|
)
|
||||||
|
|
||||||
|
# Relationship
|
||||||
|
guild: Mapped["Guild"] = relationship(back_populates="strikes")
|
||||||
|
|
||||||
|
|
||||||
|
class UserNote(Base, TimestampMixin):
|
||||||
|
"""Moderator notes on users."""
|
||||||
|
|
||||||
|
__tablename__ = "user_notes"
|
||||||
|
|
||||||
|
id: Mapped[int] = mapped_column(Integer, primary_key=True, autoincrement=True)
|
||||||
|
guild_id: Mapped[int] = mapped_column(SnowflakeID, nullable=False)
|
||||||
|
|
||||||
|
user_id: Mapped[int] = mapped_column(SnowflakeID, nullable=False)
|
||||||
|
moderator_id: Mapped[int] = mapped_column(SnowflakeID, nullable=False)
|
||||||
|
content: Mapped[str] = mapped_column(Text, nullable=False)
|
||||||
16
src/guardden/services/__init__.py
Normal file
16
src/guardden/services/__init__.py
Normal file
@@ -0,0 +1,16 @@
|
|||||||
|
"""Services for GuardDen."""
|
||||||
|
|
||||||
|
from guardden.services.automod import AutomodService
|
||||||
|
from guardden.services.database import Database
|
||||||
|
from guardden.services.ratelimit import RateLimiter, get_rate_limiter, ratelimit
|
||||||
|
from guardden.services.verification import ChallengeType, VerificationService
|
||||||
|
|
||||||
|
__all__ = [
|
||||||
|
"AutomodService",
|
||||||
|
"ChallengeType",
|
||||||
|
"Database",
|
||||||
|
"RateLimiter",
|
||||||
|
"VerificationService",
|
||||||
|
"get_rate_limiter",
|
||||||
|
"ratelimit",
|
||||||
|
]
|
||||||
6
src/guardden/services/ai/__init__.py
Normal file
6
src/guardden/services/ai/__init__.py
Normal file
@@ -0,0 +1,6 @@
|
|||||||
|
"""AI services for content moderation."""
|
||||||
|
|
||||||
|
from guardden.services.ai.base import AIProvider, ModerationResult
|
||||||
|
from guardden.services.ai.factory import create_ai_provider
|
||||||
|
|
||||||
|
__all__ = ["AIProvider", "ModerationResult", "create_ai_provider"]
|
||||||
261
src/guardden/services/ai/anthropic_provider.py
Normal file
261
src/guardden/services/ai/anthropic_provider.py
Normal file
@@ -0,0 +1,261 @@
|
|||||||
|
"""Anthropic Claude AI provider implementation."""
|
||||||
|
|
||||||
|
import logging
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
|
from guardden.services.ai.base import (
|
||||||
|
AIProvider,
|
||||||
|
ContentCategory,
|
||||||
|
ImageAnalysisResult,
|
||||||
|
ModerationResult,
|
||||||
|
PhishingAnalysisResult,
|
||||||
|
)
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
# Content moderation system prompt
|
||||||
|
MODERATION_SYSTEM_PROMPT = """You are a content moderation AI for a Discord server. Analyze the given message and determine if it violates community guidelines.
|
||||||
|
|
||||||
|
Categories to check:
|
||||||
|
- harassment: Personal attacks, bullying, intimidation
|
||||||
|
- hate_speech: Discrimination, slurs, dehumanization based on identity
|
||||||
|
- sexual: Explicit sexual content, sexual solicitation
|
||||||
|
- violence: Threats, graphic violence, encouraging harm
|
||||||
|
- self_harm: Suicide, self-injury content or encouragement
|
||||||
|
- spam: Repetitive, promotional, or low-quality content
|
||||||
|
- scam: Phishing attempts, fraudulent offers, impersonation
|
||||||
|
- misinformation: Dangerous false information
|
||||||
|
|
||||||
|
Respond in this exact JSON format:
|
||||||
|
{
|
||||||
|
"is_flagged": true/false,
|
||||||
|
"confidence": 0.0-1.0,
|
||||||
|
"categories": ["category1", "category2"],
|
||||||
|
"explanation": "Brief explanation",
|
||||||
|
"suggested_action": "none/warn/delete/timeout/ban"
|
||||||
|
}
|
||||||
|
|
||||||
|
Be balanced - flag genuinely problematic content but allow normal conversation, jokes, and mild language. Consider context."""
|
||||||
|
|
||||||
|
IMAGE_ANALYSIS_PROMPT = """Analyze this image for content moderation purposes. Check for:
|
||||||
|
- NSFW content (nudity, sexual content)
|
||||||
|
- Violence or gore
|
||||||
|
- Disturbing or shocking content
|
||||||
|
- Any content inappropriate for a general audience
|
||||||
|
|
||||||
|
Respond in this exact JSON format:
|
||||||
|
{
|
||||||
|
"is_nsfw": true/false,
|
||||||
|
"is_violent": true/false,
|
||||||
|
"is_disturbing": true/false,
|
||||||
|
"confidence": 0.0-1.0,
|
||||||
|
"description": "Brief description of the image",
|
||||||
|
"categories": ["category1", "category2"]
|
||||||
|
}
|
||||||
|
|
||||||
|
Be accurate but not overly sensitive - artistic nudity or mild violence in appropriate contexts may be acceptable."""
|
||||||
|
|
||||||
|
PHISHING_ANALYSIS_PROMPT = """Analyze this URL and message context for phishing or scam indicators.
|
||||||
|
|
||||||
|
Check for:
|
||||||
|
- Domain impersonation (typosquatting, lookalike domains)
|
||||||
|
- Urgency tactics ("act now", "limited time")
|
||||||
|
- Requests for credentials or personal info
|
||||||
|
- Too-good-to-be-true offers
|
||||||
|
- Suspicious redirects or URL shorteners
|
||||||
|
- Mismatched or hidden URLs
|
||||||
|
|
||||||
|
Respond in this exact JSON format:
|
||||||
|
{
|
||||||
|
"is_phishing": true/false,
|
||||||
|
"confidence": 0.0-1.0,
|
||||||
|
"risk_factors": ["factor1", "factor2"],
|
||||||
|
"explanation": "Brief explanation"
|
||||||
|
}"""
|
||||||
|
|
||||||
|
|
||||||
|
class AnthropicProvider(AIProvider):
|
||||||
|
"""AI provider using Anthropic's Claude API."""
|
||||||
|
|
||||||
|
def __init__(self, api_key: str, model: str = "claude-3-haiku-20240307") -> None:
|
||||||
|
"""
|
||||||
|
Initialize Anthropic provider.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
api_key: Anthropic API key
|
||||||
|
model: Model to use (default: claude-3-haiku for speed/cost)
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
import anthropic
|
||||||
|
except ImportError:
|
||||||
|
raise ImportError("anthropic package required. Install with: pip install anthropic")
|
||||||
|
|
||||||
|
self.client = anthropic.AsyncAnthropic(api_key=api_key)
|
||||||
|
self.model = model
|
||||||
|
logger.info(f"Initialized Anthropic provider with model: {model}")
|
||||||
|
|
||||||
|
async def _call_api(self, system: str, user_content: Any, max_tokens: int = 500) -> str:
|
||||||
|
"""Make an API call to Claude."""
|
||||||
|
try:
|
||||||
|
message = await self.client.messages.create(
|
||||||
|
model=self.model,
|
||||||
|
max_tokens=max_tokens,
|
||||||
|
system=system,
|
||||||
|
messages=[{"role": "user", "content": user_content}],
|
||||||
|
)
|
||||||
|
return message.content[0].text
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Anthropic API error: {e}")
|
||||||
|
raise
|
||||||
|
|
||||||
|
def _parse_json_response(self, response: str) -> dict:
|
||||||
|
"""Parse JSON from response, handling markdown code blocks."""
|
||||||
|
import json
|
||||||
|
|
||||||
|
# Remove markdown code blocks if present
|
||||||
|
text = response.strip()
|
||||||
|
if text.startswith("```"):
|
||||||
|
lines = text.split("\n")
|
||||||
|
# Remove first and last lines (```json and ```)
|
||||||
|
text = "\n".join(lines[1:-1] if lines[-1] == "```" else lines[1:])
|
||||||
|
|
||||||
|
return json.loads(text)
|
||||||
|
|
||||||
|
async def moderate_text(
|
||||||
|
self,
|
||||||
|
content: str,
|
||||||
|
context: str | None = None,
|
||||||
|
sensitivity: int = 50,
|
||||||
|
) -> ModerationResult:
|
||||||
|
"""Analyze text content for policy violations."""
|
||||||
|
# Adjust prompt based on sensitivity
|
||||||
|
sensitivity_note = ""
|
||||||
|
if sensitivity < 30:
|
||||||
|
sensitivity_note = "\n\nBe lenient - only flag clearly problematic content."
|
||||||
|
elif sensitivity > 70:
|
||||||
|
sensitivity_note = "\n\nBe strict - flag anything potentially problematic."
|
||||||
|
|
||||||
|
system = MODERATION_SYSTEM_PROMPT + sensitivity_note
|
||||||
|
|
||||||
|
user_message = f"Message to analyze:\n{content}"
|
||||||
|
if context:
|
||||||
|
user_message = f"Context: {context}\n\n{user_message}"
|
||||||
|
|
||||||
|
try:
|
||||||
|
response = await self._call_api(system, user_message)
|
||||||
|
data = self._parse_json_response(response)
|
||||||
|
|
||||||
|
categories = [
|
||||||
|
ContentCategory(cat)
|
||||||
|
for cat in data.get("categories", [])
|
||||||
|
if cat in ContentCategory.__members__.values()
|
||||||
|
]
|
||||||
|
|
||||||
|
return ModerationResult(
|
||||||
|
is_flagged=data.get("is_flagged", False),
|
||||||
|
confidence=float(data.get("confidence", 0.0)),
|
||||||
|
categories=categories,
|
||||||
|
explanation=data.get("explanation", ""),
|
||||||
|
suggested_action=data.get("suggested_action", "none"),
|
||||||
|
)
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Error moderating text: {e}")
|
||||||
|
return ModerationResult(
|
||||||
|
is_flagged=False,
|
||||||
|
explanation=f"Error analyzing content: {str(e)}",
|
||||||
|
)
|
||||||
|
|
||||||
|
async def analyze_image(
|
||||||
|
self,
|
||||||
|
image_url: str,
|
||||||
|
sensitivity: int = 50,
|
||||||
|
) -> ImageAnalysisResult:
|
||||||
|
"""Analyze an image for NSFW or inappropriate content."""
|
||||||
|
import base64
|
||||||
|
|
||||||
|
import aiohttp
|
||||||
|
|
||||||
|
sensitivity_note = ""
|
||||||
|
if sensitivity < 30:
|
||||||
|
sensitivity_note = "\n\nBe lenient - only flag explicit content."
|
||||||
|
elif sensitivity > 70:
|
||||||
|
sensitivity_note = "\n\nBe strict - flag suggestive content as well."
|
||||||
|
|
||||||
|
system = IMAGE_ANALYSIS_PROMPT + sensitivity_note
|
||||||
|
|
||||||
|
try:
|
||||||
|
# Download image and convert to base64
|
||||||
|
async with aiohttp.ClientSession() as session:
|
||||||
|
async with session.get(image_url) as resp:
|
||||||
|
if resp.status != 200:
|
||||||
|
return ImageAnalysisResult(
|
||||||
|
description=f"Failed to download image: HTTP {resp.status}"
|
||||||
|
)
|
||||||
|
|
||||||
|
content_type = resp.content_type or "image/jpeg"
|
||||||
|
image_data = await resp.read()
|
||||||
|
|
||||||
|
# Check file size (max 20MB for Claude)
|
||||||
|
if len(image_data) > 20 * 1024 * 1024:
|
||||||
|
return ImageAnalysisResult(description="Image too large to analyze")
|
||||||
|
|
||||||
|
base64_image = base64.standard_b64encode(image_data).decode("utf-8")
|
||||||
|
|
||||||
|
# Create multimodal message
|
||||||
|
user_content = [
|
||||||
|
{
|
||||||
|
"type": "image",
|
||||||
|
"source": {
|
||||||
|
"type": "base64",
|
||||||
|
"media_type": content_type,
|
||||||
|
"data": base64_image,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{"type": "text", "text": "Analyze this image for content moderation."},
|
||||||
|
]
|
||||||
|
|
||||||
|
response = await self._call_api(system, user_content)
|
||||||
|
data = self._parse_json_response(response)
|
||||||
|
|
||||||
|
return ImageAnalysisResult(
|
||||||
|
is_nsfw=data.get("is_nsfw", False),
|
||||||
|
is_violent=data.get("is_violent", False),
|
||||||
|
is_disturbing=data.get("is_disturbing", False),
|
||||||
|
confidence=float(data.get("confidence", 0.0)),
|
||||||
|
description=data.get("description", ""),
|
||||||
|
categories=data.get("categories", []),
|
||||||
|
)
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Error analyzing image: {e}")
|
||||||
|
return ImageAnalysisResult(description=f"Error analyzing image: {str(e)}")
|
||||||
|
|
||||||
|
async def analyze_phishing(
|
||||||
|
self,
|
||||||
|
url: str,
|
||||||
|
message_content: str | None = None,
|
||||||
|
) -> PhishingAnalysisResult:
|
||||||
|
"""Analyze a URL for phishing/scam indicators."""
|
||||||
|
user_message = f"URL to analyze: {url}"
|
||||||
|
if message_content:
|
||||||
|
user_message += f"\n\nFull message context:\n{message_content}"
|
||||||
|
|
||||||
|
try:
|
||||||
|
response = await self._call_api(PHISHING_ANALYSIS_PROMPT, user_message)
|
||||||
|
data = self._parse_json_response(response)
|
||||||
|
|
||||||
|
return PhishingAnalysisResult(
|
||||||
|
is_phishing=data.get("is_phishing", False),
|
||||||
|
confidence=float(data.get("confidence", 0.0)),
|
||||||
|
risk_factors=data.get("risk_factors", []),
|
||||||
|
explanation=data.get("explanation", ""),
|
||||||
|
)
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Error analyzing phishing: {e}")
|
||||||
|
return PhishingAnalysisResult(explanation=f"Error analyzing URL: {str(e)}")
|
||||||
|
|
||||||
|
async def close(self) -> None:
|
||||||
|
"""Clean up resources."""
|
||||||
|
await self.client.close()
|
||||||
149
src/guardden/services/ai/base.py
Normal file
149
src/guardden/services/ai/base.py
Normal file
@@ -0,0 +1,149 @@
|
|||||||
|
"""Base classes for AI providers."""
|
||||||
|
|
||||||
|
from abc import ABC, abstractmethod
|
||||||
|
from dataclasses import dataclass, field
|
||||||
|
from enum import Enum
|
||||||
|
from typing import Literal
|
||||||
|
|
||||||
|
|
||||||
|
class ContentCategory(str, Enum):
|
||||||
|
"""Categories of problematic content."""
|
||||||
|
|
||||||
|
SAFE = "safe"
|
||||||
|
HARASSMENT = "harassment"
|
||||||
|
HATE_SPEECH = "hate_speech"
|
||||||
|
SEXUAL = "sexual"
|
||||||
|
VIOLENCE = "violence"
|
||||||
|
SELF_HARM = "self_harm"
|
||||||
|
SPAM = "spam"
|
||||||
|
SCAM = "scam"
|
||||||
|
MISINFORMATION = "misinformation"
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class ModerationResult:
|
||||||
|
"""Result of AI content moderation."""
|
||||||
|
|
||||||
|
is_flagged: bool = False
|
||||||
|
confidence: float = 0.0 # 0.0 to 1.0
|
||||||
|
categories: list[ContentCategory] = field(default_factory=list)
|
||||||
|
explanation: str = ""
|
||||||
|
suggested_action: Literal["none", "warn", "delete", "timeout", "ban"] = "none"
|
||||||
|
|
||||||
|
@property
|
||||||
|
def severity(self) -> int:
|
||||||
|
"""Get severity score 0-100 based on confidence and categories."""
|
||||||
|
if not self.is_flagged:
|
||||||
|
return 0
|
||||||
|
|
||||||
|
# Base severity from confidence
|
||||||
|
severity = int(self.confidence * 50)
|
||||||
|
|
||||||
|
# Add severity based on category
|
||||||
|
high_severity = {
|
||||||
|
ContentCategory.HATE_SPEECH,
|
||||||
|
ContentCategory.SELF_HARM,
|
||||||
|
ContentCategory.SCAM,
|
||||||
|
}
|
||||||
|
medium_severity = {
|
||||||
|
ContentCategory.HARASSMENT,
|
||||||
|
ContentCategory.VIOLENCE,
|
||||||
|
ContentCategory.SEXUAL,
|
||||||
|
}
|
||||||
|
|
||||||
|
for cat in self.categories:
|
||||||
|
if cat in high_severity:
|
||||||
|
severity += 30
|
||||||
|
elif cat in medium_severity:
|
||||||
|
severity += 20
|
||||||
|
else:
|
||||||
|
severity += 10
|
||||||
|
|
||||||
|
return min(severity, 100)
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class ImageAnalysisResult:
|
||||||
|
"""Result of AI image analysis."""
|
||||||
|
|
||||||
|
is_nsfw: bool = False
|
||||||
|
is_violent: bool = False
|
||||||
|
is_disturbing: bool = False
|
||||||
|
confidence: float = 0.0
|
||||||
|
description: str = ""
|
||||||
|
categories: list[str] = field(default_factory=list)
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class PhishingAnalysisResult:
|
||||||
|
"""Result of AI phishing/scam analysis."""
|
||||||
|
|
||||||
|
is_phishing: bool = False
|
||||||
|
confidence: float = 0.0
|
||||||
|
risk_factors: list[str] = field(default_factory=list)
|
||||||
|
explanation: str = ""
|
||||||
|
|
||||||
|
|
||||||
|
class AIProvider(ABC):
|
||||||
|
"""Abstract base class for AI providers."""
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
async def moderate_text(
|
||||||
|
self,
|
||||||
|
content: str,
|
||||||
|
context: str | None = None,
|
||||||
|
sensitivity: int = 50,
|
||||||
|
) -> ModerationResult:
|
||||||
|
"""
|
||||||
|
Analyze text content for policy violations.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
content: The text to analyze
|
||||||
|
context: Optional context about the conversation/server
|
||||||
|
sensitivity: 0-100, higher means more strict
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
ModerationResult with analysis
|
||||||
|
"""
|
||||||
|
pass
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
async def analyze_image(
|
||||||
|
self,
|
||||||
|
image_url: str,
|
||||||
|
sensitivity: int = 50,
|
||||||
|
) -> ImageAnalysisResult:
|
||||||
|
"""
|
||||||
|
Analyze an image for NSFW or inappropriate content.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
image_url: URL of the image to analyze
|
||||||
|
sensitivity: 0-100, higher means more strict
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
ImageAnalysisResult with analysis
|
||||||
|
"""
|
||||||
|
pass
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
async def analyze_phishing(
|
||||||
|
self,
|
||||||
|
url: str,
|
||||||
|
message_content: str | None = None,
|
||||||
|
) -> PhishingAnalysisResult:
|
||||||
|
"""
|
||||||
|
Analyze a URL for phishing/scam indicators.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
url: The URL to analyze
|
||||||
|
message_content: Optional full message for context
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
PhishingAnalysisResult with analysis
|
||||||
|
"""
|
||||||
|
pass
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
async def close(self) -> None:
|
||||||
|
"""Clean up resources."""
|
||||||
|
pass
|
||||||
67
src/guardden/services/ai/factory.py
Normal file
67
src/guardden/services/ai/factory.py
Normal file
@@ -0,0 +1,67 @@
|
|||||||
|
"""Factory for creating AI providers."""
|
||||||
|
|
||||||
|
import logging
|
||||||
|
from typing import Literal
|
||||||
|
|
||||||
|
from guardden.services.ai.base import AIProvider
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
class NullProvider(AIProvider):
|
||||||
|
"""Null provider that does nothing (for when AI is disabled)."""
|
||||||
|
|
||||||
|
async def moderate_text(self, content, context=None, sensitivity=50):
|
||||||
|
from guardden.services.ai.base import ModerationResult
|
||||||
|
|
||||||
|
return ModerationResult()
|
||||||
|
|
||||||
|
async def analyze_image(self, image_url, sensitivity=50):
|
||||||
|
from guardden.services.ai.base import ImageAnalysisResult
|
||||||
|
|
||||||
|
return ImageAnalysisResult()
|
||||||
|
|
||||||
|
async def analyze_phishing(self, url, message_content=None):
|
||||||
|
from guardden.services.ai.base import PhishingAnalysisResult
|
||||||
|
|
||||||
|
return PhishingAnalysisResult()
|
||||||
|
|
||||||
|
async def close(self):
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
def create_ai_provider(
|
||||||
|
provider: Literal["anthropic", "openai", "none"],
|
||||||
|
api_key: str | None = None,
|
||||||
|
) -> AIProvider:
|
||||||
|
"""
|
||||||
|
Create an AI provider instance.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
provider: The provider type to create
|
||||||
|
api_key: API key for the provider
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
AIProvider instance
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
ValueError: If provider is unknown or API key is missing
|
||||||
|
"""
|
||||||
|
if provider == "none":
|
||||||
|
logger.info("AI moderation disabled")
|
||||||
|
return NullProvider()
|
||||||
|
|
||||||
|
if not api_key:
|
||||||
|
raise ValueError(f"API key required for {provider} provider")
|
||||||
|
|
||||||
|
if provider == "anthropic":
|
||||||
|
from guardden.services.ai.anthropic_provider import AnthropicProvider
|
||||||
|
|
||||||
|
return AnthropicProvider(api_key)
|
||||||
|
|
||||||
|
if provider == "openai":
|
||||||
|
from guardden.services.ai.openai_provider import OpenAIProvider
|
||||||
|
|
||||||
|
return OpenAIProvider(api_key)
|
||||||
|
|
||||||
|
raise ValueError(f"Unknown AI provider: {provider}")
|
||||||
213
src/guardden/services/ai/openai_provider.py
Normal file
213
src/guardden/services/ai/openai_provider.py
Normal file
@@ -0,0 +1,213 @@
|
|||||||
|
"""OpenAI AI provider implementation."""
|
||||||
|
|
||||||
|
import logging
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
|
from guardden.services.ai.base import (
|
||||||
|
AIProvider,
|
||||||
|
ContentCategory,
|
||||||
|
ImageAnalysisResult,
|
||||||
|
ModerationResult,
|
||||||
|
PhishingAnalysisResult,
|
||||||
|
)
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
class OpenAIProvider(AIProvider):
|
||||||
|
"""AI provider using OpenAI's API."""
|
||||||
|
|
||||||
|
def __init__(self, api_key: str, model: str = "gpt-4o-mini") -> None:
|
||||||
|
"""
|
||||||
|
Initialize OpenAI provider.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
api_key: OpenAI API key
|
||||||
|
model: Model to use (default: gpt-4o-mini for speed/cost)
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
import openai
|
||||||
|
except ImportError:
|
||||||
|
raise ImportError("openai package required. Install with: pip install openai")
|
||||||
|
|
||||||
|
self.client = openai.AsyncOpenAI(api_key=api_key)
|
||||||
|
self.model = model
|
||||||
|
logger.info(f"Initialized OpenAI provider with model: {model}")
|
||||||
|
|
||||||
|
async def _call_api(
|
||||||
|
self,
|
||||||
|
system: str,
|
||||||
|
user_content: Any,
|
||||||
|
max_tokens: int = 500,
|
||||||
|
) -> str:
|
||||||
|
"""Make an API call to OpenAI."""
|
||||||
|
try:
|
||||||
|
response = await self.client.chat.completions.create(
|
||||||
|
model=self.model,
|
||||||
|
max_tokens=max_tokens,
|
||||||
|
messages=[
|
||||||
|
{"role": "system", "content": system},
|
||||||
|
{"role": "user", "content": user_content},
|
||||||
|
],
|
||||||
|
response_format={"type": "json_object"},
|
||||||
|
)
|
||||||
|
return response.choices[0].message.content or ""
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"OpenAI API error: {e}")
|
||||||
|
raise
|
||||||
|
|
||||||
|
def _parse_json_response(self, response: str) -> dict:
|
||||||
|
"""Parse JSON from response."""
|
||||||
|
import json
|
||||||
|
|
||||||
|
return json.loads(response)
|
||||||
|
|
||||||
|
async def moderate_text(
|
||||||
|
self,
|
||||||
|
content: str,
|
||||||
|
context: str | None = None,
|
||||||
|
sensitivity: int = 50,
|
||||||
|
) -> ModerationResult:
|
||||||
|
"""Analyze text content for policy violations."""
|
||||||
|
# First, use OpenAI's built-in moderation API for quick check
|
||||||
|
try:
|
||||||
|
mod_response = await self.client.moderations.create(input=content)
|
||||||
|
results = mod_response.results[0]
|
||||||
|
|
||||||
|
# Map OpenAI categories to our categories
|
||||||
|
category_mapping = {
|
||||||
|
"harassment": ContentCategory.HARASSMENT,
|
||||||
|
"harassment/threatening": ContentCategory.HARASSMENT,
|
||||||
|
"hate": ContentCategory.HATE_SPEECH,
|
||||||
|
"hate/threatening": ContentCategory.HATE_SPEECH,
|
||||||
|
"self-harm": ContentCategory.SELF_HARM,
|
||||||
|
"self-harm/intent": ContentCategory.SELF_HARM,
|
||||||
|
"self-harm/instructions": ContentCategory.SELF_HARM,
|
||||||
|
"sexual": ContentCategory.SEXUAL,
|
||||||
|
"sexual/minors": ContentCategory.SEXUAL,
|
||||||
|
"violence": ContentCategory.VIOLENCE,
|
||||||
|
"violence/graphic": ContentCategory.VIOLENCE,
|
||||||
|
}
|
||||||
|
|
||||||
|
flagged_categories = []
|
||||||
|
max_score = 0.0
|
||||||
|
|
||||||
|
for category, score in results.category_scores.model_dump().items():
|
||||||
|
if score > 0.5: # Threshold
|
||||||
|
if category in category_mapping:
|
||||||
|
flagged_categories.append(category_mapping[category])
|
||||||
|
max_score = max(max_score, score)
|
||||||
|
|
||||||
|
# Adjust threshold based on sensitivity
|
||||||
|
threshold = 0.3 + (0.4 * (100 - sensitivity) / 100) # 0.3 to 0.7
|
||||||
|
|
||||||
|
if results.flagged or max_score > threshold:
|
||||||
|
return ModerationResult(
|
||||||
|
is_flagged=True,
|
||||||
|
confidence=max_score,
|
||||||
|
categories=list(set(flagged_categories)),
|
||||||
|
explanation="Content flagged by moderation API",
|
||||||
|
suggested_action="delete" if max_score > 0.8 else "warn",
|
||||||
|
)
|
||||||
|
|
||||||
|
return ModerationResult(is_flagged=False, confidence=1.0 - max_score)
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Error moderating text: {e}")
|
||||||
|
return ModerationResult(
|
||||||
|
is_flagged=False,
|
||||||
|
explanation=f"Error analyzing content: {str(e)}",
|
||||||
|
)
|
||||||
|
|
||||||
|
async def analyze_image(
|
||||||
|
self,
|
||||||
|
image_url: str,
|
||||||
|
sensitivity: int = 50,
|
||||||
|
) -> ImageAnalysisResult:
|
||||||
|
"""Analyze an image for NSFW or inappropriate content."""
|
||||||
|
system = """Analyze this image for content moderation. Respond in JSON format:
|
||||||
|
{
|
||||||
|
"is_nsfw": true/false,
|
||||||
|
"is_violent": true/false,
|
||||||
|
"is_disturbing": true/false,
|
||||||
|
"confidence": 0.0-1.0,
|
||||||
|
"description": "Brief description",
|
||||||
|
"categories": ["category1"]
|
||||||
|
}"""
|
||||||
|
|
||||||
|
sensitivity_note = ""
|
||||||
|
if sensitivity < 30:
|
||||||
|
sensitivity_note = " Be lenient - only flag explicit content."
|
||||||
|
elif sensitivity > 70:
|
||||||
|
sensitivity_note = " Be strict - flag suggestive content."
|
||||||
|
|
||||||
|
try:
|
||||||
|
response = await self.client.chat.completions.create(
|
||||||
|
model="gpt-4o-mini", # Use vision-capable model
|
||||||
|
max_tokens=500,
|
||||||
|
messages=[
|
||||||
|
{"role": "system", "content": system + sensitivity_note},
|
||||||
|
{
|
||||||
|
"role": "user",
|
||||||
|
"content": [
|
||||||
|
{"type": "text", "text": "Analyze this image for moderation."},
|
||||||
|
{"type": "image_url", "image_url": {"url": image_url}},
|
||||||
|
],
|
||||||
|
},
|
||||||
|
],
|
||||||
|
response_format={"type": "json_object"},
|
||||||
|
)
|
||||||
|
|
||||||
|
data = self._parse_json_response(response.choices[0].message.content or "{}")
|
||||||
|
|
||||||
|
return ImageAnalysisResult(
|
||||||
|
is_nsfw=data.get("is_nsfw", False),
|
||||||
|
is_violent=data.get("is_violent", False),
|
||||||
|
is_disturbing=data.get("is_disturbing", False),
|
||||||
|
confidence=float(data.get("confidence", 0.0)),
|
||||||
|
description=data.get("description", ""),
|
||||||
|
categories=data.get("categories", []),
|
||||||
|
)
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Error analyzing image: {e}")
|
||||||
|
return ImageAnalysisResult(description=f"Error analyzing image: {str(e)}")
|
||||||
|
|
||||||
|
async def analyze_phishing(
|
||||||
|
self,
|
||||||
|
url: str,
|
||||||
|
message_content: str | None = None,
|
||||||
|
) -> PhishingAnalysisResult:
|
||||||
|
"""Analyze a URL for phishing/scam indicators."""
|
||||||
|
system = """Analyze the URL for phishing/scam indicators. Respond in JSON:
|
||||||
|
{
|
||||||
|
"is_phishing": true/false,
|
||||||
|
"confidence": 0.0-1.0,
|
||||||
|
"risk_factors": ["factor1"],
|
||||||
|
"explanation": "Brief explanation"
|
||||||
|
}
|
||||||
|
|
||||||
|
Check for: domain impersonation, urgency tactics, credential requests, too-good-to-be-true offers."""
|
||||||
|
|
||||||
|
user_message = f"URL: {url}"
|
||||||
|
if message_content:
|
||||||
|
user_message += f"\n\nMessage context: {message_content}"
|
||||||
|
|
||||||
|
try:
|
||||||
|
response = await self._call_api(system, user_message)
|
||||||
|
data = self._parse_json_response(response)
|
||||||
|
|
||||||
|
return PhishingAnalysisResult(
|
||||||
|
is_phishing=data.get("is_phishing", False),
|
||||||
|
confidence=float(data.get("confidence", 0.0)),
|
||||||
|
risk_factors=data.get("risk_factors", []),
|
||||||
|
explanation=data.get("explanation", ""),
|
||||||
|
)
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Error analyzing phishing: {e}")
|
||||||
|
return PhishingAnalysisResult(explanation=f"Error analyzing URL: {str(e)}")
|
||||||
|
|
||||||
|
async def close(self) -> None:
|
||||||
|
"""Clean up resources."""
|
||||||
|
await self.client.close()
|
||||||
301
src/guardden/services/automod.py
Normal file
301
src/guardden/services/automod.py
Normal file
@@ -0,0 +1,301 @@
|
|||||||
|
"""Automod service for content filtering and spam detection."""
|
||||||
|
|
||||||
|
import logging
|
||||||
|
import re
|
||||||
|
from collections import defaultdict
|
||||||
|
from dataclasses import dataclass, field
|
||||||
|
from datetime import datetime, timedelta, timezone
|
||||||
|
from typing import NamedTuple
|
||||||
|
|
||||||
|
import discord
|
||||||
|
|
||||||
|
from guardden.models import BannedWord
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
# Known scam/phishing patterns
|
||||||
|
SCAM_PATTERNS = [
|
||||||
|
# Discord scam patterns
|
||||||
|
r"discord(?:[-.]?(?:gift|nitro|free|claim|steam))[\w.-]*\.(?!com|gg)[a-z]{2,}",
|
||||||
|
r"(?:free|claim|get)[-.\s]?(?:discord[-.\s]?)?nitro",
|
||||||
|
r"(?:steam|discord)[-.\s]?community[-.\s]?(?:giveaway|gift)",
|
||||||
|
# Generic phishing
|
||||||
|
r"(?:verify|confirm)[-.\s]?(?:your)?[-.\s]?account",
|
||||||
|
r"(?:suspended|locked|limited)[-.\s]?account",
|
||||||
|
r"click[-.\s]?(?:here|this)[-.\s]?(?:to[-.\s]?)?(?:verify|claim|get)",
|
||||||
|
# Crypto scams
|
||||||
|
r"(?:free|claim|airdrop)[-.\s]?(?:crypto|bitcoin|eth|nft)",
|
||||||
|
r"(?:double|2x)[-.\s]?your[-.\s]?(?:crypto|bitcoin|eth)",
|
||||||
|
]
|
||||||
|
|
||||||
|
# Suspicious TLDs often used in phishing
|
||||||
|
SUSPICIOUS_TLDS = {
|
||||||
|
".xyz",
|
||||||
|
".top",
|
||||||
|
".club",
|
||||||
|
".work",
|
||||||
|
".click",
|
||||||
|
".link",
|
||||||
|
".info",
|
||||||
|
".ru",
|
||||||
|
".cn",
|
||||||
|
".tk",
|
||||||
|
".ml",
|
||||||
|
".ga",
|
||||||
|
".cf",
|
||||||
|
".gq",
|
||||||
|
}
|
||||||
|
|
||||||
|
# URL pattern for extraction
|
||||||
|
URL_PATTERN = re.compile(
|
||||||
|
r"https?://(?:[-\w.]|(?:%[\da-fA-F]{2}))+[^\s]*|"
|
||||||
|
r"(?:www\.)?[-\w]+\.(?:com|org|net|io|gg|co|me|tv|xyz|top|club|work|click|link|info|ru|cn)[^\s]*",
|
||||||
|
re.IGNORECASE,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class SpamRecord(NamedTuple):
|
||||||
|
"""Record of a message for spam tracking."""
|
||||||
|
|
||||||
|
content_hash: str
|
||||||
|
timestamp: datetime
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class UserSpamTracker:
|
||||||
|
"""Tracks spam behavior for a single user."""
|
||||||
|
|
||||||
|
messages: list[SpamRecord] = field(default_factory=list)
|
||||||
|
mention_count: int = 0
|
||||||
|
last_mention_time: datetime | None = None
|
||||||
|
duplicate_count: int = 0
|
||||||
|
last_action_time: datetime | None = None
|
||||||
|
|
||||||
|
def cleanup(self, max_age: timedelta = timedelta(minutes=1)) -> None:
|
||||||
|
"""Remove old messages from tracking."""
|
||||||
|
cutoff = datetime.now(timezone.utc) - max_age
|
||||||
|
self.messages = [m for m in self.messages if m.timestamp > cutoff]
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class AutomodResult:
|
||||||
|
"""Result of automod check."""
|
||||||
|
|
||||||
|
should_delete: bool = False
|
||||||
|
should_warn: bool = False
|
||||||
|
should_strike: bool = False
|
||||||
|
should_timeout: bool = False
|
||||||
|
timeout_duration: int = 0 # seconds
|
||||||
|
reason: str = ""
|
||||||
|
matched_filter: str = ""
|
||||||
|
|
||||||
|
|
||||||
|
class AutomodService:
|
||||||
|
"""Service for automatic content moderation."""
|
||||||
|
|
||||||
|
def __init__(self) -> None:
|
||||||
|
# Compile scam patterns
|
||||||
|
self._scam_patterns = [re.compile(p, re.IGNORECASE) for p in SCAM_PATTERNS]
|
||||||
|
|
||||||
|
# Per-guild, per-user spam tracking
|
||||||
|
# Structure: {guild_id: {user_id: UserSpamTracker}}
|
||||||
|
self._spam_trackers: dict[int, dict[int, UserSpamTracker]] = defaultdict(
|
||||||
|
lambda: defaultdict(UserSpamTracker)
|
||||||
|
)
|
||||||
|
|
||||||
|
# Spam thresholds
|
||||||
|
self.message_rate_limit = 5 # messages per window
|
||||||
|
self.message_rate_window = 5 # seconds
|
||||||
|
self.duplicate_threshold = 3 # same message count
|
||||||
|
self.mention_limit = 5 # mentions per message
|
||||||
|
self.mention_rate_limit = 10 # mentions per window
|
||||||
|
self.mention_rate_window = 60 # seconds
|
||||||
|
|
||||||
|
def _get_content_hash(self, content: str) -> str:
|
||||||
|
"""Get a normalized hash of message content for duplicate detection."""
|
||||||
|
# Normalize: lowercase, remove extra spaces, remove special chars
|
||||||
|
normalized = re.sub(r"[^\w\s]", "", content.lower())
|
||||||
|
normalized = re.sub(r"\s+", " ", normalized).strip()
|
||||||
|
return normalized
|
||||||
|
|
||||||
|
def check_banned_words(
|
||||||
|
self, content: str, banned_words: list[BannedWord]
|
||||||
|
) -> AutomodResult | None:
|
||||||
|
"""Check message against banned words list."""
|
||||||
|
content_lower = content.lower()
|
||||||
|
|
||||||
|
for banned in banned_words:
|
||||||
|
matched = False
|
||||||
|
|
||||||
|
if banned.is_regex:
|
||||||
|
try:
|
||||||
|
if re.search(banned.pattern, content, re.IGNORECASE):
|
||||||
|
matched = True
|
||||||
|
except re.error:
|
||||||
|
logger.warning(f"Invalid regex pattern: {banned.pattern}")
|
||||||
|
continue
|
||||||
|
else:
|
||||||
|
if banned.pattern.lower() in content_lower:
|
||||||
|
matched = True
|
||||||
|
|
||||||
|
if matched:
|
||||||
|
result = AutomodResult(
|
||||||
|
should_delete=True,
|
||||||
|
reason=banned.reason or f"Matched banned word filter",
|
||||||
|
matched_filter=f"banned_word:{banned.id}",
|
||||||
|
)
|
||||||
|
|
||||||
|
if banned.action == "warn":
|
||||||
|
result.should_warn = True
|
||||||
|
elif banned.action == "strike":
|
||||||
|
result.should_strike = True
|
||||||
|
|
||||||
|
return result
|
||||||
|
|
||||||
|
return None
|
||||||
|
|
||||||
|
def check_scam_links(self, content: str) -> AutomodResult | None:
|
||||||
|
"""Check message for scam/phishing patterns."""
|
||||||
|
# Check for known scam patterns
|
||||||
|
for pattern in self._scam_patterns:
|
||||||
|
if pattern.search(content):
|
||||||
|
return AutomodResult(
|
||||||
|
should_delete=True,
|
||||||
|
should_warn=True,
|
||||||
|
reason="Message matched known scam/phishing pattern",
|
||||||
|
matched_filter="scam_pattern",
|
||||||
|
)
|
||||||
|
|
||||||
|
# Check URLs for suspicious TLDs
|
||||||
|
urls = URL_PATTERN.findall(content)
|
||||||
|
for url in urls:
|
||||||
|
url_lower = url.lower()
|
||||||
|
for tld in SUSPICIOUS_TLDS:
|
||||||
|
if tld in url_lower:
|
||||||
|
# Additional check: is it trying to impersonate a known domain?
|
||||||
|
impersonation_keywords = [
|
||||||
|
"discord",
|
||||||
|
"steam",
|
||||||
|
"nitro",
|
||||||
|
"gift",
|
||||||
|
"free",
|
||||||
|
"login",
|
||||||
|
"verify",
|
||||||
|
]
|
||||||
|
if any(kw in url_lower for kw in impersonation_keywords):
|
||||||
|
return AutomodResult(
|
||||||
|
should_delete=True,
|
||||||
|
should_warn=True,
|
||||||
|
reason=f"Suspicious link detected: {url[:50]}",
|
||||||
|
matched_filter="suspicious_link",
|
||||||
|
)
|
||||||
|
|
||||||
|
return None
|
||||||
|
|
||||||
|
def check_spam(
|
||||||
|
self, message: discord.Message, anti_spam_enabled: bool = True
|
||||||
|
) -> AutomodResult | None:
|
||||||
|
"""Check message for spam behavior."""
|
||||||
|
if not anti_spam_enabled:
|
||||||
|
return None
|
||||||
|
|
||||||
|
guild_id = message.guild.id
|
||||||
|
user_id = message.author.id
|
||||||
|
tracker = self._spam_trackers[guild_id][user_id]
|
||||||
|
now = datetime.now(timezone.utc)
|
||||||
|
|
||||||
|
# Cleanup old records
|
||||||
|
tracker.cleanup()
|
||||||
|
|
||||||
|
# Check message rate
|
||||||
|
content_hash = self._get_content_hash(message.content)
|
||||||
|
tracker.messages.append(SpamRecord(content_hash, now))
|
||||||
|
|
||||||
|
# Rate limit check
|
||||||
|
recent_window = now - timedelta(seconds=self.message_rate_window)
|
||||||
|
recent_messages = [m for m in tracker.messages if m.timestamp > recent_window]
|
||||||
|
|
||||||
|
if len(recent_messages) > self.message_rate_limit:
|
||||||
|
return AutomodResult(
|
||||||
|
should_delete=True,
|
||||||
|
should_timeout=True,
|
||||||
|
timeout_duration=60, # 1 minute timeout
|
||||||
|
reason=f"Sending messages too fast ({len(recent_messages)} in {self.message_rate_window}s)",
|
||||||
|
matched_filter="rate_limit",
|
||||||
|
)
|
||||||
|
|
||||||
|
# Duplicate message check
|
||||||
|
duplicate_count = sum(1 for m in tracker.messages if m.content_hash == content_hash)
|
||||||
|
if duplicate_count >= self.duplicate_threshold:
|
||||||
|
return AutomodResult(
|
||||||
|
should_delete=True,
|
||||||
|
should_warn=True,
|
||||||
|
reason=f"Duplicate message detected ({duplicate_count} times)",
|
||||||
|
matched_filter="duplicate",
|
||||||
|
)
|
||||||
|
|
||||||
|
# Mass mention check
|
||||||
|
mention_count = len(message.mentions) + len(message.role_mentions)
|
||||||
|
if message.mention_everyone:
|
||||||
|
mention_count += 100 # Treat @everyone as many mentions
|
||||||
|
|
||||||
|
if mention_count > self.mention_limit:
|
||||||
|
return AutomodResult(
|
||||||
|
should_delete=True,
|
||||||
|
should_timeout=True,
|
||||||
|
timeout_duration=300, # 5 minute timeout
|
||||||
|
reason=f"Mass mentions detected ({mention_count} mentions)",
|
||||||
|
matched_filter="mass_mention",
|
||||||
|
)
|
||||||
|
|
||||||
|
return None
|
||||||
|
|
||||||
|
def check_invite_links(self, content: str, allow_invites: bool = True) -> AutomodResult | None:
|
||||||
|
"""Check for Discord invite links."""
|
||||||
|
if allow_invites:
|
||||||
|
return None
|
||||||
|
|
||||||
|
invite_pattern = re.compile(
|
||||||
|
r"(?:https?://)?(?:www\.)?(?:discord\.(?:gg|io|me|li)|discordapp\.com/invite)/[\w-]+",
|
||||||
|
re.IGNORECASE,
|
||||||
|
)
|
||||||
|
|
||||||
|
if invite_pattern.search(content):
|
||||||
|
return AutomodResult(
|
||||||
|
should_delete=True,
|
||||||
|
reason="Discord invite links are not allowed",
|
||||||
|
matched_filter="invite_link",
|
||||||
|
)
|
||||||
|
|
||||||
|
return None
|
||||||
|
|
||||||
|
def check_all_caps(
|
||||||
|
self, content: str, threshold: float = 0.7, min_length: int = 10
|
||||||
|
) -> AutomodResult | None:
|
||||||
|
"""Check for excessive caps usage."""
|
||||||
|
# Only check messages with enough letters
|
||||||
|
letters = [c for c in content if c.isalpha()]
|
||||||
|
if len(letters) < min_length:
|
||||||
|
return None
|
||||||
|
|
||||||
|
caps_count = sum(1 for c in letters if c.isupper())
|
||||||
|
caps_ratio = caps_count / len(letters)
|
||||||
|
|
||||||
|
if caps_ratio > threshold:
|
||||||
|
return AutomodResult(
|
||||||
|
should_delete=True,
|
||||||
|
reason="Excessive caps usage",
|
||||||
|
matched_filter="caps",
|
||||||
|
)
|
||||||
|
|
||||||
|
return None
|
||||||
|
|
||||||
|
def reset_user_tracker(self, guild_id: int, user_id: int) -> None:
|
||||||
|
"""Reset spam tracking for a user."""
|
||||||
|
if guild_id in self._spam_trackers:
|
||||||
|
self._spam_trackers[guild_id].pop(user_id, None)
|
||||||
|
|
||||||
|
def cleanup_guild(self, guild_id: int) -> None:
|
||||||
|
"""Remove all tracking data for a guild."""
|
||||||
|
self._spam_trackers.pop(guild_id, None)
|
||||||
99
src/guardden/services/database.py
Normal file
99
src/guardden/services/database.py
Normal file
@@ -0,0 +1,99 @@
|
|||||||
|
"""Database connection and session management."""
|
||||||
|
|
||||||
|
import logging
|
||||||
|
from contextlib import asynccontextmanager
|
||||||
|
from typing import AsyncGenerator
|
||||||
|
|
||||||
|
import asyncpg
|
||||||
|
from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker, create_async_engine
|
||||||
|
|
||||||
|
from guardden.config import Settings
|
||||||
|
from guardden.models import Base
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
class Database:
|
||||||
|
"""Manages database connections and sessions."""
|
||||||
|
|
||||||
|
def __init__(self, settings: Settings) -> None:
|
||||||
|
self.settings = settings
|
||||||
|
self._engine = None
|
||||||
|
self._session_factory = None
|
||||||
|
self._pool: asyncpg.Pool | None = None
|
||||||
|
|
||||||
|
async def connect(self) -> None:
|
||||||
|
"""Initialize database connection pool."""
|
||||||
|
db_url = self.settings.database_url.get_secret_value()
|
||||||
|
|
||||||
|
# Create SQLAlchemy async engine
|
||||||
|
# Convert postgresql:// to postgresql+asyncpg://
|
||||||
|
if db_url.startswith("postgresql://"):
|
||||||
|
sqlalchemy_url = db_url.replace("postgresql://", "postgresql+asyncpg://", 1)
|
||||||
|
else:
|
||||||
|
sqlalchemy_url = db_url
|
||||||
|
|
||||||
|
self._engine = create_async_engine(
|
||||||
|
sqlalchemy_url,
|
||||||
|
pool_size=self.settings.database_pool_min,
|
||||||
|
max_overflow=self.settings.database_pool_max - self.settings.database_pool_min,
|
||||||
|
echo=self.settings.log_level == "DEBUG",
|
||||||
|
)
|
||||||
|
|
||||||
|
self._session_factory = async_sessionmaker(
|
||||||
|
self._engine,
|
||||||
|
class_=AsyncSession,
|
||||||
|
expire_on_commit=False,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Also create a raw asyncpg pool for performance-critical operations
|
||||||
|
self._pool = await asyncpg.create_pool(
|
||||||
|
db_url,
|
||||||
|
min_size=self.settings.database_pool_min,
|
||||||
|
max_size=self.settings.database_pool_max,
|
||||||
|
)
|
||||||
|
|
||||||
|
logger.info("Database connection established")
|
||||||
|
|
||||||
|
async def disconnect(self) -> None:
|
||||||
|
"""Close all database connections."""
|
||||||
|
if self._pool:
|
||||||
|
await self._pool.close()
|
||||||
|
self._pool = None
|
||||||
|
|
||||||
|
if self._engine:
|
||||||
|
await self._engine.dispose()
|
||||||
|
self._engine = None
|
||||||
|
|
||||||
|
logger.info("Database connections closed")
|
||||||
|
|
||||||
|
async def create_tables(self) -> None:
|
||||||
|
"""Create all database tables."""
|
||||||
|
if not self._engine:
|
||||||
|
raise RuntimeError("Database not connected")
|
||||||
|
|
||||||
|
async with self._engine.begin() as conn:
|
||||||
|
await conn.run_sync(Base.metadata.create_all)
|
||||||
|
|
||||||
|
logger.info("Database tables created")
|
||||||
|
|
||||||
|
@asynccontextmanager
|
||||||
|
async def session(self) -> AsyncGenerator[AsyncSession, None]:
|
||||||
|
"""Get a database session context manager."""
|
||||||
|
if not self._session_factory:
|
||||||
|
raise RuntimeError("Database not connected")
|
||||||
|
|
||||||
|
async with self._session_factory() as session:
|
||||||
|
try:
|
||||||
|
yield session
|
||||||
|
await session.commit()
|
||||||
|
except Exception:
|
||||||
|
await session.rollback()
|
||||||
|
raise
|
||||||
|
|
||||||
|
@property
|
||||||
|
def pool(self) -> asyncpg.Pool:
|
||||||
|
"""Get the raw asyncpg connection pool."""
|
||||||
|
if not self._pool:
|
||||||
|
raise RuntimeError("Database not connected")
|
||||||
|
return self._pool
|
||||||
145
src/guardden/services/guild_config.py
Normal file
145
src/guardden/services/guild_config.py
Normal file
@@ -0,0 +1,145 @@
|
|||||||
|
"""Guild configuration service."""
|
||||||
|
|
||||||
|
import logging
|
||||||
|
from functools import lru_cache
|
||||||
|
|
||||||
|
import discord
|
||||||
|
from sqlalchemy import select
|
||||||
|
from sqlalchemy.orm import selectinload
|
||||||
|
|
||||||
|
from guardden.models import BannedWord, Guild, GuildSettings
|
||||||
|
from guardden.services.database import Database
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
class GuildConfigService:
|
||||||
|
"""Manages guild configurations with caching."""
|
||||||
|
|
||||||
|
def __init__(self, database: Database) -> None:
|
||||||
|
self.database = database
|
||||||
|
self._cache: dict[int, GuildSettings] = {}
|
||||||
|
|
||||||
|
async def get_config(self, guild_id: int) -> GuildSettings | None:
|
||||||
|
"""Get guild configuration, using cache if available."""
|
||||||
|
if guild_id in self._cache:
|
||||||
|
return self._cache[guild_id]
|
||||||
|
|
||||||
|
async with self.database.session() as session:
|
||||||
|
result = await session.execute(
|
||||||
|
select(GuildSettings).where(GuildSettings.guild_id == guild_id)
|
||||||
|
)
|
||||||
|
settings = result.scalar_one_or_none()
|
||||||
|
|
||||||
|
if settings:
|
||||||
|
self._cache[guild_id] = settings
|
||||||
|
|
||||||
|
return settings
|
||||||
|
|
||||||
|
async def get_guild(self, guild_id: int) -> Guild | None:
|
||||||
|
"""Get full guild data including settings and banned words."""
|
||||||
|
async with self.database.session() as session:
|
||||||
|
result = await session.execute(
|
||||||
|
select(Guild)
|
||||||
|
.options(selectinload(Guild.settings), selectinload(Guild.banned_words))
|
||||||
|
.where(Guild.id == guild_id)
|
||||||
|
)
|
||||||
|
return result.scalar_one_or_none()
|
||||||
|
|
||||||
|
async def create_guild(self, guild: discord.Guild) -> Guild:
|
||||||
|
"""Create a new guild entry with default settings."""
|
||||||
|
async with self.database.session() as session:
|
||||||
|
# Check if guild already exists
|
||||||
|
existing = await session.get(Guild, guild.id)
|
||||||
|
if existing:
|
||||||
|
return existing
|
||||||
|
|
||||||
|
# Create new guild
|
||||||
|
db_guild = Guild(
|
||||||
|
id=guild.id,
|
||||||
|
name=guild.name,
|
||||||
|
owner_id=guild.owner_id,
|
||||||
|
)
|
||||||
|
session.add(db_guild)
|
||||||
|
await session.flush()
|
||||||
|
|
||||||
|
# Create default settings
|
||||||
|
settings = GuildSettings(guild_id=guild.id)
|
||||||
|
session.add(settings)
|
||||||
|
|
||||||
|
await session.commit()
|
||||||
|
|
||||||
|
logger.info(f"Created guild config for {guild.name} (ID: {guild.id})")
|
||||||
|
return db_guild
|
||||||
|
|
||||||
|
async def update_settings(self, guild_id: int, **kwargs) -> GuildSettings | None:
|
||||||
|
"""Update guild settings."""
|
||||||
|
async with self.database.session() as session:
|
||||||
|
result = await session.execute(
|
||||||
|
select(GuildSettings).where(GuildSettings.guild_id == guild_id)
|
||||||
|
)
|
||||||
|
settings = result.scalar_one_or_none()
|
||||||
|
|
||||||
|
if not settings:
|
||||||
|
return None
|
||||||
|
|
||||||
|
for key, value in kwargs.items():
|
||||||
|
if hasattr(settings, key):
|
||||||
|
setattr(settings, key, value)
|
||||||
|
|
||||||
|
await session.commit()
|
||||||
|
|
||||||
|
# Invalidate cache
|
||||||
|
self._cache.pop(guild_id, None)
|
||||||
|
|
||||||
|
return settings
|
||||||
|
|
||||||
|
def invalidate_cache(self, guild_id: int) -> None:
|
||||||
|
"""Remove a guild from the cache."""
|
||||||
|
self._cache.pop(guild_id, None)
|
||||||
|
|
||||||
|
async def get_banned_words(self, guild_id: int) -> list[BannedWord]:
|
||||||
|
"""Get all banned words for a guild."""
|
||||||
|
async with self.database.session() as session:
|
||||||
|
result = await session.execute(
|
||||||
|
select(BannedWord).where(BannedWord.guild_id == guild_id)
|
||||||
|
)
|
||||||
|
return list(result.scalars().all())
|
||||||
|
|
||||||
|
async def add_banned_word(
|
||||||
|
self,
|
||||||
|
guild_id: int,
|
||||||
|
pattern: str,
|
||||||
|
added_by: int,
|
||||||
|
is_regex: bool = False,
|
||||||
|
action: str = "delete",
|
||||||
|
reason: str | None = None,
|
||||||
|
) -> BannedWord:
|
||||||
|
"""Add a banned word to a guild."""
|
||||||
|
async with self.database.session() as session:
|
||||||
|
banned_word = BannedWord(
|
||||||
|
guild_id=guild_id,
|
||||||
|
pattern=pattern,
|
||||||
|
is_regex=is_regex,
|
||||||
|
action=action,
|
||||||
|
reason=reason,
|
||||||
|
added_by=added_by,
|
||||||
|
)
|
||||||
|
session.add(banned_word)
|
||||||
|
await session.commit()
|
||||||
|
return banned_word
|
||||||
|
|
||||||
|
async def remove_banned_word(self, guild_id: int, word_id: int) -> bool:
|
||||||
|
"""Remove a banned word from a guild."""
|
||||||
|
async with self.database.session() as session:
|
||||||
|
result = await session.execute(
|
||||||
|
select(BannedWord).where(BannedWord.id == word_id, BannedWord.guild_id == guild_id)
|
||||||
|
)
|
||||||
|
word = result.scalar_one_or_none()
|
||||||
|
|
||||||
|
if word:
|
||||||
|
session.delete(word)
|
||||||
|
await session.commit()
|
||||||
|
return True
|
||||||
|
|
||||||
|
return False
|
||||||
300
src/guardden/services/ratelimit.py
Normal file
300
src/guardden/services/ratelimit.py
Normal file
@@ -0,0 +1,300 @@
|
|||||||
|
"""Rate limiting service for command and action throttling."""
|
||||||
|
|
||||||
|
import logging
|
||||||
|
from collections import defaultdict
|
||||||
|
from dataclasses import dataclass, field
|
||||||
|
from datetime import datetime, timedelta, timezone
|
||||||
|
from enum import Enum
|
||||||
|
from typing import Callable
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
class RateLimitScope(str, Enum):
|
||||||
|
"""Scope of rate limiting."""
|
||||||
|
|
||||||
|
USER = "user" # Per user globally
|
||||||
|
MEMBER = "member" # Per user per guild
|
||||||
|
CHANNEL = "channel" # Per channel
|
||||||
|
GUILD = "guild" # Per guild
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class RateLimitBucket:
|
||||||
|
"""Tracks rate limit state for a single bucket."""
|
||||||
|
|
||||||
|
max_requests: int
|
||||||
|
window_seconds: float
|
||||||
|
requests: list[datetime] = field(default_factory=list)
|
||||||
|
|
||||||
|
def cleanup(self) -> None:
|
||||||
|
"""Remove expired requests from tracking."""
|
||||||
|
cutoff = datetime.now(timezone.utc) - timedelta(seconds=self.window_seconds)
|
||||||
|
self.requests = [r for r in self.requests if r > cutoff]
|
||||||
|
|
||||||
|
def is_limited(self) -> bool:
|
||||||
|
"""Check if this bucket is rate limited."""
|
||||||
|
self.cleanup()
|
||||||
|
return len(self.requests) >= self.max_requests
|
||||||
|
|
||||||
|
def record(self) -> None:
|
||||||
|
"""Record a request."""
|
||||||
|
self.requests.append(datetime.now(timezone.utc))
|
||||||
|
|
||||||
|
def remaining(self) -> int:
|
||||||
|
"""Get remaining requests in current window."""
|
||||||
|
self.cleanup()
|
||||||
|
return max(0, self.max_requests - len(self.requests))
|
||||||
|
|
||||||
|
def reset_after(self) -> float:
|
||||||
|
"""Get seconds until rate limit resets."""
|
||||||
|
if not self.requests:
|
||||||
|
return 0.0
|
||||||
|
self.cleanup()
|
||||||
|
if not self.requests:
|
||||||
|
return 0.0
|
||||||
|
oldest = min(self.requests)
|
||||||
|
reset_time = oldest + timedelta(seconds=self.window_seconds)
|
||||||
|
remaining = (reset_time - datetime.now(timezone.utc)).total_seconds()
|
||||||
|
return max(0.0, remaining)
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class RateLimitConfig:
|
||||||
|
"""Configuration for a rate limit."""
|
||||||
|
|
||||||
|
max_requests: int
|
||||||
|
window_seconds: float
|
||||||
|
scope: RateLimitScope = RateLimitScope.MEMBER
|
||||||
|
|
||||||
|
def create_bucket(self) -> RateLimitBucket:
|
||||||
|
return RateLimitBucket(
|
||||||
|
max_requests=self.max_requests,
|
||||||
|
window_seconds=self.window_seconds,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class RateLimitResult:
|
||||||
|
"""Result of a rate limit check."""
|
||||||
|
|
||||||
|
is_limited: bool
|
||||||
|
remaining: int
|
||||||
|
reset_after: float
|
||||||
|
bucket_key: str
|
||||||
|
|
||||||
|
|
||||||
|
class RateLimiter:
|
||||||
|
"""General-purpose rate limiter."""
|
||||||
|
|
||||||
|
# Default rate limits for various actions
|
||||||
|
DEFAULT_LIMITS = {
|
||||||
|
"command": RateLimitConfig(5, 10, RateLimitScope.MEMBER), # 5 commands per 10s
|
||||||
|
"moderation": RateLimitConfig(10, 60, RateLimitScope.MEMBER), # 10 mod actions per minute
|
||||||
|
"verification": RateLimitConfig(3, 300, RateLimitScope.MEMBER), # 3 verifications per 5 min
|
||||||
|
"message": RateLimitConfig(10, 10, RateLimitScope.MEMBER), # 10 messages per 10s
|
||||||
|
"api_call": RateLimitConfig(
|
||||||
|
30, 60, RateLimitScope.GUILD
|
||||||
|
), # 30 API calls per minute per guild
|
||||||
|
}
|
||||||
|
|
||||||
|
def __init__(self) -> None:
|
||||||
|
# Buckets: {action: {bucket_key: RateLimitBucket}}
|
||||||
|
self._buckets: dict[str, dict[str, RateLimitBucket]] = defaultdict(dict)
|
||||||
|
self._configs: dict[str, RateLimitConfig] = dict(self.DEFAULT_LIMITS)
|
||||||
|
|
||||||
|
def configure(self, action: str, config: RateLimitConfig) -> None:
|
||||||
|
"""Configure rate limit for an action."""
|
||||||
|
self._configs[action] = config
|
||||||
|
# Clear existing buckets for this action
|
||||||
|
self._buckets[action].clear()
|
||||||
|
|
||||||
|
def _get_bucket_key(
|
||||||
|
self,
|
||||||
|
scope: RateLimitScope,
|
||||||
|
user_id: int | None = None,
|
||||||
|
guild_id: int | None = None,
|
||||||
|
channel_id: int | None = None,
|
||||||
|
) -> str:
|
||||||
|
"""Generate a bucket key based on scope."""
|
||||||
|
if scope == RateLimitScope.USER:
|
||||||
|
return f"user:{user_id}"
|
||||||
|
elif scope == RateLimitScope.MEMBER:
|
||||||
|
return f"member:{guild_id}:{user_id}"
|
||||||
|
elif scope == RateLimitScope.CHANNEL:
|
||||||
|
return f"channel:{channel_id}"
|
||||||
|
elif scope == RateLimitScope.GUILD:
|
||||||
|
return f"guild:{guild_id}"
|
||||||
|
return f"unknown:{user_id}:{guild_id}"
|
||||||
|
|
||||||
|
def check(
|
||||||
|
self,
|
||||||
|
action: str,
|
||||||
|
user_id: int | None = None,
|
||||||
|
guild_id: int | None = None,
|
||||||
|
channel_id: int | None = None,
|
||||||
|
) -> RateLimitResult:
|
||||||
|
"""
|
||||||
|
Check if an action is rate limited.
|
||||||
|
|
||||||
|
Does not record the request - use `acquire()` for that.
|
||||||
|
"""
|
||||||
|
config = self._configs.get(action)
|
||||||
|
if not config:
|
||||||
|
return RateLimitResult(
|
||||||
|
is_limited=False,
|
||||||
|
remaining=999,
|
||||||
|
reset_after=0,
|
||||||
|
bucket_key="",
|
||||||
|
)
|
||||||
|
|
||||||
|
bucket_key = self._get_bucket_key(config.scope, user_id, guild_id, channel_id)
|
||||||
|
bucket = self._buckets[action].get(bucket_key)
|
||||||
|
|
||||||
|
if not bucket:
|
||||||
|
return RateLimitResult(
|
||||||
|
is_limited=False,
|
||||||
|
remaining=config.max_requests,
|
||||||
|
reset_after=0,
|
||||||
|
bucket_key=bucket_key,
|
||||||
|
)
|
||||||
|
|
||||||
|
return RateLimitResult(
|
||||||
|
is_limited=bucket.is_limited(),
|
||||||
|
remaining=bucket.remaining(),
|
||||||
|
reset_after=bucket.reset_after(),
|
||||||
|
bucket_key=bucket_key,
|
||||||
|
)
|
||||||
|
|
||||||
|
def acquire(
|
||||||
|
self,
|
||||||
|
action: str,
|
||||||
|
user_id: int | None = None,
|
||||||
|
guild_id: int | None = None,
|
||||||
|
channel_id: int | None = None,
|
||||||
|
) -> RateLimitResult:
|
||||||
|
"""
|
||||||
|
Attempt to acquire a rate limit slot.
|
||||||
|
|
||||||
|
Records the request if not limited.
|
||||||
|
"""
|
||||||
|
config = self._configs.get(action)
|
||||||
|
if not config:
|
||||||
|
return RateLimitResult(
|
||||||
|
is_limited=False,
|
||||||
|
remaining=999,
|
||||||
|
reset_after=0,
|
||||||
|
bucket_key="",
|
||||||
|
)
|
||||||
|
|
||||||
|
bucket_key = self._get_bucket_key(config.scope, user_id, guild_id, channel_id)
|
||||||
|
|
||||||
|
if bucket_key not in self._buckets[action]:
|
||||||
|
self._buckets[action][bucket_key] = config.create_bucket()
|
||||||
|
|
||||||
|
bucket = self._buckets[action][bucket_key]
|
||||||
|
|
||||||
|
if bucket.is_limited():
|
||||||
|
return RateLimitResult(
|
||||||
|
is_limited=True,
|
||||||
|
remaining=0,
|
||||||
|
reset_after=bucket.reset_after(),
|
||||||
|
bucket_key=bucket_key,
|
||||||
|
)
|
||||||
|
|
||||||
|
bucket.record()
|
||||||
|
|
||||||
|
return RateLimitResult(
|
||||||
|
is_limited=False,
|
||||||
|
remaining=bucket.remaining(),
|
||||||
|
reset_after=bucket.reset_after(),
|
||||||
|
bucket_key=bucket_key,
|
||||||
|
)
|
||||||
|
|
||||||
|
def reset(
|
||||||
|
self,
|
||||||
|
action: str,
|
||||||
|
user_id: int | None = None,
|
||||||
|
guild_id: int | None = None,
|
||||||
|
channel_id: int | None = None,
|
||||||
|
) -> bool:
|
||||||
|
"""Reset rate limit for a specific bucket."""
|
||||||
|
config = self._configs.get(action)
|
||||||
|
if not config:
|
||||||
|
return False
|
||||||
|
|
||||||
|
bucket_key = self._get_bucket_key(config.scope, user_id, guild_id, channel_id)
|
||||||
|
return self._buckets[action].pop(bucket_key, None) is not None
|
||||||
|
|
||||||
|
def cleanup(self) -> int:
|
||||||
|
"""Clean up empty and expired buckets. Returns count removed."""
|
||||||
|
removed = 0
|
||||||
|
for action in list(self._buckets.keys()):
|
||||||
|
for key in list(self._buckets[action].keys()):
|
||||||
|
bucket = self._buckets[action][key]
|
||||||
|
bucket.cleanup()
|
||||||
|
if not bucket.requests:
|
||||||
|
del self._buckets[action][key]
|
||||||
|
removed += 1
|
||||||
|
return removed
|
||||||
|
|
||||||
|
|
||||||
|
# Global rate limiter instance
|
||||||
|
_rate_limiter: RateLimiter | None = None
|
||||||
|
|
||||||
|
|
||||||
|
def get_rate_limiter() -> RateLimiter:
|
||||||
|
"""Get or create the global rate limiter instance."""
|
||||||
|
global _rate_limiter
|
||||||
|
if _rate_limiter is None:
|
||||||
|
_rate_limiter = RateLimiter()
|
||||||
|
return _rate_limiter
|
||||||
|
|
||||||
|
|
||||||
|
def ratelimit(
|
||||||
|
action: str = "command",
|
||||||
|
max_requests: int | None = None,
|
||||||
|
window_seconds: float | None = None,
|
||||||
|
) -> Callable:
|
||||||
|
"""
|
||||||
|
Decorator for rate limiting commands.
|
||||||
|
|
||||||
|
Usage:
|
||||||
|
@ratelimit("moderation", max_requests=5, window_seconds=60)
|
||||||
|
async def my_command(self, ctx):
|
||||||
|
...
|
||||||
|
"""
|
||||||
|
|
||||||
|
def decorator(func: Callable) -> Callable:
|
||||||
|
async def wrapper(self, ctx, *args, **kwargs):
|
||||||
|
limiter = get_rate_limiter()
|
||||||
|
|
||||||
|
# Configure if custom limits provided
|
||||||
|
if max_requests is not None and window_seconds is not None:
|
||||||
|
limiter.configure(
|
||||||
|
action,
|
||||||
|
RateLimitConfig(max_requests, window_seconds, RateLimitScope.MEMBER),
|
||||||
|
)
|
||||||
|
|
||||||
|
result = limiter.acquire(
|
||||||
|
action,
|
||||||
|
user_id=ctx.author.id,
|
||||||
|
guild_id=ctx.guild.id if ctx.guild else None,
|
||||||
|
channel_id=ctx.channel.id,
|
||||||
|
)
|
||||||
|
|
||||||
|
if result.is_limited:
|
||||||
|
await ctx.send(
|
||||||
|
f"You're being rate limited. Try again in {result.reset_after:.1f} seconds.",
|
||||||
|
delete_after=5,
|
||||||
|
)
|
||||||
|
return
|
||||||
|
|
||||||
|
return await func(self, ctx, *args, **kwargs)
|
||||||
|
|
||||||
|
# Preserve function metadata
|
||||||
|
wrapper.__name__ = func.__name__
|
||||||
|
wrapper.__doc__ = func.__doc__
|
||||||
|
return wrapper
|
||||||
|
|
||||||
|
return decorator
|
||||||
300
src/guardden/services/verification.py
Normal file
300
src/guardden/services/verification.py
Normal file
@@ -0,0 +1,300 @@
|
|||||||
|
"""Verification service for new member challenges."""
|
||||||
|
|
||||||
|
import asyncio
|
||||||
|
import logging
|
||||||
|
import random
|
||||||
|
import string
|
||||||
|
from abc import ABC, abstractmethod
|
||||||
|
from dataclasses import dataclass, field
|
||||||
|
from datetime import datetime, timedelta, timezone
|
||||||
|
from enum import Enum
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
|
import discord
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
class ChallengeType(str, Enum):
|
||||||
|
"""Types of verification challenges."""
|
||||||
|
|
||||||
|
BUTTON = "button" # Simple button click
|
||||||
|
CAPTCHA = "captcha" # Text-based captcha
|
||||||
|
MATH = "math" # Simple math problem
|
||||||
|
EMOJI = "emoji" # Select correct emoji
|
||||||
|
QUESTIONS = "questions" # Custom questions
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class Challenge:
|
||||||
|
"""Represents a verification challenge."""
|
||||||
|
|
||||||
|
challenge_type: ChallengeType
|
||||||
|
question: str
|
||||||
|
answer: str
|
||||||
|
options: list[str] = field(default_factory=list) # For multiple choice
|
||||||
|
expires_at: datetime = field(
|
||||||
|
default_factory=lambda: datetime.now(timezone.utc) + timedelta(minutes=10)
|
||||||
|
)
|
||||||
|
attempts: int = 0
|
||||||
|
max_attempts: int = 3
|
||||||
|
|
||||||
|
@property
|
||||||
|
def is_expired(self) -> bool:
|
||||||
|
return datetime.now(timezone.utc) > self.expires_at
|
||||||
|
|
||||||
|
def check_answer(self, response: str) -> bool:
|
||||||
|
"""Check if the response is correct."""
|
||||||
|
self.attempts += 1
|
||||||
|
return response.strip().lower() == self.answer.lower()
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class PendingVerification:
|
||||||
|
"""Tracks a pending verification for a user."""
|
||||||
|
|
||||||
|
user_id: int
|
||||||
|
guild_id: int
|
||||||
|
challenge: Challenge
|
||||||
|
message_id: int | None = None
|
||||||
|
channel_id: int | None = None
|
||||||
|
created_at: datetime = field(default_factory=lambda: datetime.now(timezone.utc))
|
||||||
|
|
||||||
|
|
||||||
|
class ChallengeGenerator(ABC):
|
||||||
|
"""Abstract base class for challenge generators."""
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
def generate(self) -> Challenge:
|
||||||
|
"""Generate a new challenge."""
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
class ButtonChallengeGenerator(ChallengeGenerator):
|
||||||
|
"""Generates simple button click challenges."""
|
||||||
|
|
||||||
|
def generate(self) -> Challenge:
|
||||||
|
return Challenge(
|
||||||
|
challenge_type=ChallengeType.BUTTON,
|
||||||
|
question="Click the button below to verify you're human.",
|
||||||
|
answer="verified",
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class CaptchaChallengeGenerator(ChallengeGenerator):
|
||||||
|
"""Generates text-based captcha challenges."""
|
||||||
|
|
||||||
|
def __init__(self, length: int = 6) -> None:
|
||||||
|
self.length = length
|
||||||
|
|
||||||
|
def generate(self) -> Challenge:
|
||||||
|
# Generate random alphanumeric code (avoiding confusing chars)
|
||||||
|
chars = "ABCDEFGHJKLMNPQRSTUVWXYZ23456789"
|
||||||
|
code = "".join(random.choices(chars, k=self.length))
|
||||||
|
|
||||||
|
# Create visual representation with some obfuscation
|
||||||
|
visual = self._create_visual(code)
|
||||||
|
|
||||||
|
return Challenge(
|
||||||
|
challenge_type=ChallengeType.CAPTCHA,
|
||||||
|
question=f"Enter the code shown below:\n```\n{visual}\n```",
|
||||||
|
answer=code,
|
||||||
|
)
|
||||||
|
|
||||||
|
def _create_visual(self, code: str) -> str:
|
||||||
|
"""Create a simple text-based visual captcha."""
|
||||||
|
lines = []
|
||||||
|
# Add some noise characters
|
||||||
|
noise_chars = ".-*~^"
|
||||||
|
|
||||||
|
for _ in range(2):
|
||||||
|
lines.append("".join(random.choices(noise_chars, k=len(code) * 2)))
|
||||||
|
|
||||||
|
# Add the code with spacing
|
||||||
|
spaced = " ".join(code)
|
||||||
|
lines.append(spaced)
|
||||||
|
|
||||||
|
for _ in range(2):
|
||||||
|
lines.append("".join(random.choices(noise_chars, k=len(code) * 2)))
|
||||||
|
|
||||||
|
return "\n".join(lines)
|
||||||
|
|
||||||
|
|
||||||
|
class MathChallengeGenerator(ChallengeGenerator):
|
||||||
|
"""Generates simple math problem challenges."""
|
||||||
|
|
||||||
|
def generate(self) -> Challenge:
|
||||||
|
# Generate simple addition/subtraction/multiplication
|
||||||
|
operation = random.choice(["+", "-", "*"])
|
||||||
|
|
||||||
|
if operation == "*":
|
||||||
|
a = random.randint(2, 10)
|
||||||
|
b = random.randint(2, 10)
|
||||||
|
else:
|
||||||
|
a = random.randint(10, 50)
|
||||||
|
b = random.randint(1, 20)
|
||||||
|
|
||||||
|
if operation == "+":
|
||||||
|
answer = a + b
|
||||||
|
elif operation == "-":
|
||||||
|
# Ensure positive result
|
||||||
|
if b > a:
|
||||||
|
a, b = b, a
|
||||||
|
answer = a - b
|
||||||
|
else:
|
||||||
|
answer = a * b
|
||||||
|
|
||||||
|
return Challenge(
|
||||||
|
challenge_type=ChallengeType.MATH,
|
||||||
|
question=f"Solve this math problem: **{a} {operation} {b} = ?**",
|
||||||
|
answer=str(answer),
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class EmojiChallengeGenerator(ChallengeGenerator):
|
||||||
|
"""Generates emoji selection challenges."""
|
||||||
|
|
||||||
|
EMOJI_SETS = [
|
||||||
|
("animals", ["🐶", "🐱", "🐭", "🐹", "🐰", "🦊", "🐻", "🐼"]),
|
||||||
|
("fruits", ["🍎", "🍐", "🍊", "🍋", "🍌", "🍉", "🍇", "🍓"]),
|
||||||
|
("weather", ["☀️", "🌙", "⭐", "🌧️", "❄️", "🌈", "⚡", "🌪️"]),
|
||||||
|
("sports", ["⚽", "🏀", "🏈", "⚾", "🎾", "🏐", "🏉", "🎱"]),
|
||||||
|
]
|
||||||
|
|
||||||
|
def generate(self) -> Challenge:
|
||||||
|
category, emojis = random.choice(self.EMOJI_SETS)
|
||||||
|
target = random.choice(emojis)
|
||||||
|
|
||||||
|
# Create options with the target and some others
|
||||||
|
options = [target]
|
||||||
|
other_emojis = [e for e in emojis if e != target]
|
||||||
|
options.extend(random.sample(other_emojis, min(3, len(other_emojis))))
|
||||||
|
random.shuffle(options)
|
||||||
|
|
||||||
|
return Challenge(
|
||||||
|
challenge_type=ChallengeType.EMOJI,
|
||||||
|
question=f"Select the {self._emoji_name(target)} emoji:",
|
||||||
|
answer=target,
|
||||||
|
options=options,
|
||||||
|
)
|
||||||
|
|
||||||
|
def _emoji_name(self, emoji: str) -> str:
|
||||||
|
"""Get a description of the emoji."""
|
||||||
|
names = {
|
||||||
|
"🐶": "dog",
|
||||||
|
"🐱": "cat",
|
||||||
|
"🐭": "mouse",
|
||||||
|
"🐹": "hamster",
|
||||||
|
"🐰": "rabbit",
|
||||||
|
"🦊": "fox",
|
||||||
|
"🐻": "bear",
|
||||||
|
"🐼": "panda",
|
||||||
|
"🍎": "apple",
|
||||||
|
"🍐": "pear",
|
||||||
|
"🍊": "orange",
|
||||||
|
"🍋": "lemon",
|
||||||
|
"🍌": "banana",
|
||||||
|
"🍉": "watermelon",
|
||||||
|
"🍇": "grapes",
|
||||||
|
"🍓": "strawberry",
|
||||||
|
"☀️": "sun",
|
||||||
|
"🌙": "moon",
|
||||||
|
"⭐": "star",
|
||||||
|
"🌧️": "rain",
|
||||||
|
"❄️": "snowflake",
|
||||||
|
"🌈": "rainbow",
|
||||||
|
"⚡": "lightning",
|
||||||
|
"🌪️": "tornado",
|
||||||
|
"⚽": "soccer ball",
|
||||||
|
"🏀": "basketball",
|
||||||
|
"🏈": "football",
|
||||||
|
"⚾": "baseball",
|
||||||
|
"🎾": "tennis",
|
||||||
|
"🏐": "volleyball",
|
||||||
|
"🏉": "rugby",
|
||||||
|
"🎱": "pool ball",
|
||||||
|
}
|
||||||
|
return names.get(emoji, "correct")
|
||||||
|
|
||||||
|
|
||||||
|
class VerificationService:
|
||||||
|
"""Service for managing member verification."""
|
||||||
|
|
||||||
|
def __init__(self) -> None:
|
||||||
|
# Pending verifications: {(guild_id, user_id): PendingVerification}
|
||||||
|
self._pending: dict[tuple[int, int], PendingVerification] = {}
|
||||||
|
|
||||||
|
# Challenge generators
|
||||||
|
self._generators: dict[ChallengeType, ChallengeGenerator] = {
|
||||||
|
ChallengeType.BUTTON: ButtonChallengeGenerator(),
|
||||||
|
ChallengeType.CAPTCHA: CaptchaChallengeGenerator(),
|
||||||
|
ChallengeType.MATH: MathChallengeGenerator(),
|
||||||
|
ChallengeType.EMOJI: EmojiChallengeGenerator(),
|
||||||
|
}
|
||||||
|
|
||||||
|
def create_challenge(
|
||||||
|
self,
|
||||||
|
user_id: int,
|
||||||
|
guild_id: int,
|
||||||
|
challenge_type: ChallengeType = ChallengeType.BUTTON,
|
||||||
|
) -> PendingVerification:
|
||||||
|
"""Create a new verification challenge for a user."""
|
||||||
|
generator = self._generators.get(challenge_type)
|
||||||
|
if not generator:
|
||||||
|
generator = self._generators[ChallengeType.BUTTON]
|
||||||
|
|
||||||
|
challenge = generator.generate()
|
||||||
|
pending = PendingVerification(
|
||||||
|
user_id=user_id,
|
||||||
|
guild_id=guild_id,
|
||||||
|
challenge=challenge,
|
||||||
|
)
|
||||||
|
|
||||||
|
self._pending[(guild_id, user_id)] = pending
|
||||||
|
return pending
|
||||||
|
|
||||||
|
def get_pending(self, guild_id: int, user_id: int) -> PendingVerification | None:
|
||||||
|
"""Get a pending verification for a user."""
|
||||||
|
return self._pending.get((guild_id, user_id))
|
||||||
|
|
||||||
|
def verify(self, guild_id: int, user_id: int, response: str) -> tuple[bool, str]:
|
||||||
|
"""
|
||||||
|
Attempt to verify a user's response.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Tuple of (success, message)
|
||||||
|
"""
|
||||||
|
pending = self._pending.get((guild_id, user_id))
|
||||||
|
|
||||||
|
if not pending:
|
||||||
|
return False, "No pending verification found."
|
||||||
|
|
||||||
|
if pending.challenge.is_expired:
|
||||||
|
self._pending.pop((guild_id, user_id), None)
|
||||||
|
return False, "Verification expired. Please request a new one."
|
||||||
|
|
||||||
|
if pending.challenge.attempts >= pending.challenge.max_attempts:
|
||||||
|
self._pending.pop((guild_id, user_id), None)
|
||||||
|
return False, "Too many failed attempts. Please request a new verification."
|
||||||
|
|
||||||
|
if pending.challenge.check_answer(response):
|
||||||
|
self._pending.pop((guild_id, user_id), None)
|
||||||
|
return True, "Verification successful!"
|
||||||
|
|
||||||
|
remaining = pending.challenge.max_attempts - pending.challenge.attempts
|
||||||
|
return False, f"Incorrect. {remaining} attempt(s) remaining."
|
||||||
|
|
||||||
|
def cancel(self, guild_id: int, user_id: int) -> bool:
|
||||||
|
"""Cancel a pending verification."""
|
||||||
|
return self._pending.pop((guild_id, user_id), None) is not None
|
||||||
|
|
||||||
|
def cleanup_expired(self) -> int:
|
||||||
|
"""Remove expired verifications. Returns count of removed."""
|
||||||
|
expired = [key for key, pending in self._pending.items() if pending.challenge.is_expired]
|
||||||
|
for key in expired:
|
||||||
|
self._pending.pop(key, None)
|
||||||
|
return len(expired)
|
||||||
|
|
||||||
|
def get_pending_count(self, guild_id: int) -> int:
|
||||||
|
"""Get count of pending verifications for a guild."""
|
||||||
|
return sum(1 for (gid, _) in self._pending if gid == guild_id)
|
||||||
5
src/guardden/utils/__init__.py
Normal file
5
src/guardden/utils/__init__.py
Normal file
@@ -0,0 +1,5 @@
|
|||||||
|
"""Utility functions for GuardDen."""
|
||||||
|
|
||||||
|
from guardden.utils.logging import setup_logging
|
||||||
|
|
||||||
|
__all__ = ["setup_logging"]
|
||||||
27
src/guardden/utils/logging.py
Normal file
27
src/guardden/utils/logging.py
Normal file
@@ -0,0 +1,27 @@
|
|||||||
|
"""Logging configuration for GuardDen."""
|
||||||
|
|
||||||
|
import logging
|
||||||
|
import sys
|
||||||
|
from typing import Literal
|
||||||
|
|
||||||
|
|
||||||
|
def setup_logging(level: Literal["DEBUG", "INFO", "WARNING", "ERROR"] = "INFO") -> None:
|
||||||
|
"""Configure logging for the application."""
|
||||||
|
log_format = "%(asctime)s | %(levelname)-8s | %(name)s | %(message)s"
|
||||||
|
date_format = "%Y-%m-%d %H:%M:%S"
|
||||||
|
|
||||||
|
# Configure root logger
|
||||||
|
logging.basicConfig(
|
||||||
|
level=getattr(logging, level),
|
||||||
|
format=log_format,
|
||||||
|
datefmt=date_format,
|
||||||
|
handlers=[logging.StreamHandler(sys.stdout)],
|
||||||
|
)
|
||||||
|
|
||||||
|
# Reduce noise from third-party libraries
|
||||||
|
logging.getLogger("discord").setLevel(logging.WARNING)
|
||||||
|
logging.getLogger("discord.http").setLevel(logging.WARNING)
|
||||||
|
logging.getLogger("asyncio").setLevel(logging.WARNING)
|
||||||
|
logging.getLogger("sqlalchemy.engine").setLevel(
|
||||||
|
logging.DEBUG if level == "DEBUG" else logging.WARNING
|
||||||
|
)
|
||||||
1
tests/__init__.py
Normal file
1
tests/__init__.py
Normal file
@@ -0,0 +1 @@
|
|||||||
|
"""Tests for GuardDen."""
|
||||||
15
tests/conftest.py
Normal file
15
tests/conftest.py
Normal file
@@ -0,0 +1,15 @@
|
|||||||
|
"""Pytest fixtures for GuardDen tests."""
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def sample_guild_id() -> int:
|
||||||
|
"""Return a sample Discord guild ID."""
|
||||||
|
return 123456789012345678
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def sample_user_id() -> int:
|
||||||
|
"""Return a sample Discord user ID."""
|
||||||
|
return 987654321098765432
|
||||||
119
tests/test_ai.py
Normal file
119
tests/test_ai.py
Normal file
@@ -0,0 +1,119 @@
|
|||||||
|
"""Tests for AI services."""
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
from guardden.services.ai.base import ContentCategory, ModerationResult
|
||||||
|
from guardden.services.ai.factory import NullProvider, create_ai_provider
|
||||||
|
|
||||||
|
|
||||||
|
class TestModerationResult:
|
||||||
|
"""Tests for ModerationResult dataclass."""
|
||||||
|
|
||||||
|
def test_severity_not_flagged(self) -> None:
|
||||||
|
"""Test severity is 0 when not flagged."""
|
||||||
|
result = ModerationResult(is_flagged=False, confidence=0.9)
|
||||||
|
assert result.severity == 0
|
||||||
|
|
||||||
|
def test_severity_with_confidence(self) -> None:
|
||||||
|
"""Test severity includes confidence."""
|
||||||
|
result = ModerationResult(
|
||||||
|
is_flagged=True,
|
||||||
|
confidence=0.8,
|
||||||
|
categories=[],
|
||||||
|
)
|
||||||
|
# 0.8 * 50 = 40
|
||||||
|
assert result.severity == 40
|
||||||
|
|
||||||
|
def test_severity_high_category(self) -> None:
|
||||||
|
"""Test severity with high-severity category."""
|
||||||
|
result = ModerationResult(
|
||||||
|
is_flagged=True,
|
||||||
|
confidence=0.5,
|
||||||
|
categories=[ContentCategory.HATE_SPEECH],
|
||||||
|
)
|
||||||
|
# 0.5 * 50 + 30 = 55
|
||||||
|
assert result.severity == 55
|
||||||
|
|
||||||
|
def test_severity_medium_category(self) -> None:
|
||||||
|
"""Test severity with medium-severity category."""
|
||||||
|
result = ModerationResult(
|
||||||
|
is_flagged=True,
|
||||||
|
confidence=0.5,
|
||||||
|
categories=[ContentCategory.HARASSMENT],
|
||||||
|
)
|
||||||
|
# 0.5 * 50 + 20 = 45
|
||||||
|
assert result.severity == 45
|
||||||
|
|
||||||
|
def test_severity_multiple_categories(self) -> None:
|
||||||
|
"""Test severity with multiple categories."""
|
||||||
|
result = ModerationResult(
|
||||||
|
is_flagged=True,
|
||||||
|
confidence=0.5,
|
||||||
|
categories=[ContentCategory.HATE_SPEECH, ContentCategory.VIOLENCE],
|
||||||
|
)
|
||||||
|
# 0.5 * 50 + 30 + 20 = 75
|
||||||
|
assert result.severity == 75
|
||||||
|
|
||||||
|
def test_severity_capped_at_100(self) -> None:
|
||||||
|
"""Test severity is capped at 100."""
|
||||||
|
result = ModerationResult(
|
||||||
|
is_flagged=True,
|
||||||
|
confidence=1.0,
|
||||||
|
categories=[
|
||||||
|
ContentCategory.HATE_SPEECH,
|
||||||
|
ContentCategory.SELF_HARM,
|
||||||
|
ContentCategory.SCAM,
|
||||||
|
],
|
||||||
|
)
|
||||||
|
# Would be 50 + 30 + 30 + 30 = 140, capped to 100
|
||||||
|
assert result.severity == 100
|
||||||
|
|
||||||
|
|
||||||
|
class TestNullProvider:
|
||||||
|
"""Tests for NullProvider."""
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def provider(self) -> NullProvider:
|
||||||
|
return NullProvider()
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_moderate_text_returns_empty(self, provider: NullProvider) -> None:
|
||||||
|
"""Test moderate_text returns unflagged result."""
|
||||||
|
result = await provider.moderate_text("test content")
|
||||||
|
assert result.is_flagged is False
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_analyze_image_returns_empty(self, provider: NullProvider) -> None:
|
||||||
|
"""Test analyze_image returns empty result."""
|
||||||
|
result = await provider.analyze_image("http://example.com/image.jpg")
|
||||||
|
assert result.is_nsfw is False
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_analyze_phishing_returns_empty(self, provider: NullProvider) -> None:
|
||||||
|
"""Test analyze_phishing returns empty result."""
|
||||||
|
result = await provider.analyze_phishing("http://example.com")
|
||||||
|
assert result.is_phishing is False
|
||||||
|
|
||||||
|
|
||||||
|
class TestFactory:
|
||||||
|
"""Tests for AI provider factory."""
|
||||||
|
|
||||||
|
def test_create_null_provider(self) -> None:
|
||||||
|
"""Test creating null provider."""
|
||||||
|
provider = create_ai_provider("none")
|
||||||
|
assert isinstance(provider, NullProvider)
|
||||||
|
|
||||||
|
def test_create_anthropic_without_key(self) -> None:
|
||||||
|
"""Test creating anthropic provider without key raises error."""
|
||||||
|
with pytest.raises(ValueError, match="API key required"):
|
||||||
|
create_ai_provider("anthropic", None)
|
||||||
|
|
||||||
|
def test_create_openai_without_key(self) -> None:
|
||||||
|
"""Test creating openai provider without key raises error."""
|
||||||
|
with pytest.raises(ValueError, match="API key required"):
|
||||||
|
create_ai_provider("openai", None)
|
||||||
|
|
||||||
|
def test_create_unknown_provider(self) -> None:
|
||||||
|
"""Test creating unknown provider raises error."""
|
||||||
|
with pytest.raises(ValueError, match="Unknown AI provider"):
|
||||||
|
create_ai_provider("unknown", "key") # type: ignore
|
||||||
153
tests/test_automod.py
Normal file
153
tests/test_automod.py
Normal file
@@ -0,0 +1,153 @@
|
|||||||
|
"""Tests for the automod service."""
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
from guardden.models import BannedWord
|
||||||
|
from guardden.services.automod import AutomodService
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def automod() -> AutomodService:
|
||||||
|
"""Create an automod service instance."""
|
||||||
|
return AutomodService()
|
||||||
|
|
||||||
|
|
||||||
|
class TestBannedWords:
|
||||||
|
"""Tests for banned word filtering."""
|
||||||
|
|
||||||
|
def test_simple_match(self, automod: AutomodService) -> None:
|
||||||
|
"""Test simple text matching."""
|
||||||
|
banned = [_make_banned_word("badword")]
|
||||||
|
result = automod.check_banned_words("This contains badword in it", banned)
|
||||||
|
assert result is not None
|
||||||
|
assert result.should_delete
|
||||||
|
|
||||||
|
def test_case_insensitive(self, automod: AutomodService) -> None:
|
||||||
|
"""Test case insensitive matching."""
|
||||||
|
banned = [_make_banned_word("BadWord")]
|
||||||
|
result = automod.check_banned_words("this contains BADWORD here", banned)
|
||||||
|
assert result is not None
|
||||||
|
|
||||||
|
def test_no_match(self, automod: AutomodService) -> None:
|
||||||
|
"""Test no match returns None."""
|
||||||
|
banned = [_make_banned_word("badword")]
|
||||||
|
result = automod.check_banned_words("This is a clean message", banned)
|
||||||
|
assert result is None
|
||||||
|
|
||||||
|
def test_regex_pattern(self, automod: AutomodService) -> None:
|
||||||
|
"""Test regex pattern matching."""
|
||||||
|
banned = [_make_banned_word(r"bad\w+", is_regex=True)]
|
||||||
|
result = automod.check_banned_words("This is badword and badstuff", banned)
|
||||||
|
assert result is not None
|
||||||
|
|
||||||
|
def test_action_warn(self, automod: AutomodService) -> None:
|
||||||
|
"""Test warn action is set."""
|
||||||
|
banned = [_make_banned_word("badword", action="warn")]
|
||||||
|
result = automod.check_banned_words("badword", banned)
|
||||||
|
assert result is not None
|
||||||
|
assert result.should_warn
|
||||||
|
|
||||||
|
def test_action_strike(self, automod: AutomodService) -> None:
|
||||||
|
"""Test strike action is set."""
|
||||||
|
banned = [_make_banned_word("badword", action="strike")]
|
||||||
|
result = automod.check_banned_words("badword", banned)
|
||||||
|
assert result is not None
|
||||||
|
assert result.should_strike
|
||||||
|
|
||||||
|
|
||||||
|
class TestScamDetection:
|
||||||
|
"""Tests for scam/phishing detection."""
|
||||||
|
|
||||||
|
def test_discord_nitro_scam(self, automod: AutomodService) -> None:
|
||||||
|
"""Test detection of fake Discord Nitro links."""
|
||||||
|
result = automod.check_scam_links("Free nitro at discord-nitro.gift")
|
||||||
|
assert result is not None
|
||||||
|
assert result.should_delete
|
||||||
|
|
||||||
|
def test_steam_scam(self, automod: AutomodService) -> None:
|
||||||
|
"""Test detection of Steam scam patterns."""
|
||||||
|
result = automod.check_scam_links("Check out this steam-community-giveaway.xyz")
|
||||||
|
assert result is not None
|
||||||
|
|
||||||
|
def test_legitimate_discord_link(self, automod: AutomodService) -> None:
|
||||||
|
"""Test that legitimate Discord links pass."""
|
||||||
|
result = automod.check_scam_links("Join us at discord.gg/example")
|
||||||
|
assert result is None
|
||||||
|
|
||||||
|
def test_suspicious_tld_with_keyword(self, automod: AutomodService) -> None:
|
||||||
|
"""Test suspicious TLD with impersonation keyword."""
|
||||||
|
result = automod.check_scam_links("Visit discord-verify.xyz to claim")
|
||||||
|
assert result is not None
|
||||||
|
|
||||||
|
def test_normal_url(self, automod: AutomodService) -> None:
|
||||||
|
"""Test normal URLs pass."""
|
||||||
|
result = automod.check_scam_links("Check out https://github.com/example")
|
||||||
|
assert result is None
|
||||||
|
|
||||||
|
|
||||||
|
class TestInviteLinks:
|
||||||
|
"""Tests for Discord invite link detection."""
|
||||||
|
|
||||||
|
def test_discord_gg_invite(self, automod: AutomodService) -> None:
|
||||||
|
"""Test discord.gg invite detection."""
|
||||||
|
result = automod.check_invite_links("Join discord.gg/example", allow_invites=False)
|
||||||
|
assert result is not None
|
||||||
|
assert result.should_delete
|
||||||
|
|
||||||
|
def test_discordapp_invite(self, automod: AutomodService) -> None:
|
||||||
|
"""Test discordapp.com invite detection."""
|
||||||
|
result = automod.check_invite_links(
|
||||||
|
"Join https://discordapp.com/invite/abc123", allow_invites=False
|
||||||
|
)
|
||||||
|
assert result is not None
|
||||||
|
|
||||||
|
def test_allowed_invites(self, automod: AutomodService) -> None:
|
||||||
|
"""Test invites pass when allowed."""
|
||||||
|
result = automod.check_invite_links("Join discord.gg/example", allow_invites=True)
|
||||||
|
assert result is None
|
||||||
|
|
||||||
|
|
||||||
|
class TestCapsDetection:
|
||||||
|
"""Tests for excessive caps detection."""
|
||||||
|
|
||||||
|
def test_excessive_caps(self, automod: AutomodService) -> None:
|
||||||
|
"""Test detection of all caps message."""
|
||||||
|
result = automod.check_all_caps("THIS IS ALL CAPS MESSAGE HERE")
|
||||||
|
assert result is not None
|
||||||
|
|
||||||
|
def test_normal_caps(self, automod: AutomodService) -> None:
|
||||||
|
"""Test normal message passes."""
|
||||||
|
result = automod.check_all_caps("This is a Normal Message with Some Caps")
|
||||||
|
assert result is None
|
||||||
|
|
||||||
|
def test_short_message_ignored(self, automod: AutomodService) -> None:
|
||||||
|
"""Test short messages are ignored."""
|
||||||
|
result = automod.check_all_caps("HI THERE")
|
||||||
|
assert result is None
|
||||||
|
|
||||||
|
|
||||||
|
class MockBannedWord:
|
||||||
|
"""Mock BannedWord for testing without database."""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
pattern: str,
|
||||||
|
is_regex: bool = False,
|
||||||
|
action: str = "delete",
|
||||||
|
) -> None:
|
||||||
|
self.id = 1
|
||||||
|
self.guild_id = 123
|
||||||
|
self.pattern = pattern
|
||||||
|
self.is_regex = is_regex
|
||||||
|
self.action = action
|
||||||
|
self.reason = None
|
||||||
|
self.added_by = 456
|
||||||
|
|
||||||
|
|
||||||
|
def _make_banned_word(
|
||||||
|
pattern: str,
|
||||||
|
is_regex: bool = False,
|
||||||
|
action: str = "delete",
|
||||||
|
) -> MockBannedWord:
|
||||||
|
"""Create a mock BannedWord object for testing."""
|
||||||
|
return MockBannedWord(pattern, is_regex, action)
|
||||||
130
tests/test_ratelimit.py
Normal file
130
tests/test_ratelimit.py
Normal file
@@ -0,0 +1,130 @@
|
|||||||
|
"""Tests for rate limiting service."""
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
from guardden.services.ratelimit import (
|
||||||
|
RateLimitBucket,
|
||||||
|
RateLimitConfig,
|
||||||
|
RateLimiter,
|
||||||
|
RateLimitScope,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class TestRateLimitBucket:
|
||||||
|
"""Tests for RateLimitBucket."""
|
||||||
|
|
||||||
|
def test_not_limited_initially(self) -> None:
|
||||||
|
"""Test bucket is not limited when empty."""
|
||||||
|
bucket = RateLimitBucket(max_requests=3, window_seconds=60)
|
||||||
|
assert bucket.is_limited() is False
|
||||||
|
assert bucket.remaining() == 3
|
||||||
|
|
||||||
|
def test_limited_after_max_requests(self) -> None:
|
||||||
|
"""Test bucket is limited after max requests."""
|
||||||
|
bucket = RateLimitBucket(max_requests=3, window_seconds=60)
|
||||||
|
|
||||||
|
for _ in range(3):
|
||||||
|
bucket.record()
|
||||||
|
|
||||||
|
assert bucket.is_limited() is True
|
||||||
|
assert bucket.remaining() == 0
|
||||||
|
|
||||||
|
def test_remaining_decreases(self) -> None:
|
||||||
|
"""Test remaining count decreases."""
|
||||||
|
bucket = RateLimitBucket(max_requests=5, window_seconds=60)
|
||||||
|
|
||||||
|
bucket.record()
|
||||||
|
assert bucket.remaining() == 4
|
||||||
|
|
||||||
|
bucket.record()
|
||||||
|
assert bucket.remaining() == 3
|
||||||
|
|
||||||
|
|
||||||
|
class TestRateLimiter:
|
||||||
|
"""Tests for RateLimiter."""
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def limiter(self) -> RateLimiter:
|
||||||
|
return RateLimiter()
|
||||||
|
|
||||||
|
def test_check_not_limited(self, limiter: RateLimiter) -> None:
|
||||||
|
"""Test check returns not limited for new bucket."""
|
||||||
|
result = limiter.check("command", user_id=123, guild_id=456)
|
||||||
|
assert result.is_limited is False
|
||||||
|
|
||||||
|
def test_acquire_records_request(self, limiter: RateLimiter) -> None:
|
||||||
|
"""Test acquire records the request."""
|
||||||
|
# Configure a simple limit
|
||||||
|
limiter.configure(
|
||||||
|
"test",
|
||||||
|
RateLimitConfig(max_requests=2, window_seconds=60, scope=RateLimitScope.USER),
|
||||||
|
)
|
||||||
|
|
||||||
|
result1 = limiter.acquire("test", user_id=123)
|
||||||
|
assert result1.is_limited is False
|
||||||
|
assert result1.remaining == 1
|
||||||
|
|
||||||
|
result2 = limiter.acquire("test", user_id=123)
|
||||||
|
assert result2.is_limited is False
|
||||||
|
assert result2.remaining == 0
|
||||||
|
|
||||||
|
result3 = limiter.acquire("test", user_id=123)
|
||||||
|
assert result3.is_limited is True
|
||||||
|
|
||||||
|
def test_different_scopes(self, limiter: RateLimiter) -> None:
|
||||||
|
"""Test different scopes create different buckets."""
|
||||||
|
# User scope - shared across guilds
|
||||||
|
limiter.configure(
|
||||||
|
"user_action",
|
||||||
|
RateLimitConfig(max_requests=1, window_seconds=60, scope=RateLimitScope.USER),
|
||||||
|
)
|
||||||
|
|
||||||
|
limiter.acquire("user_action", user_id=123, guild_id=1)
|
||||||
|
result = limiter.acquire("user_action", user_id=123, guild_id=2)
|
||||||
|
assert result.is_limited is True # Same user, different guild
|
||||||
|
|
||||||
|
# Member scope - per guild
|
||||||
|
limiter.configure(
|
||||||
|
"member_action",
|
||||||
|
RateLimitConfig(max_requests=1, window_seconds=60, scope=RateLimitScope.MEMBER),
|
||||||
|
)
|
||||||
|
|
||||||
|
limiter.acquire("member_action", user_id=456, guild_id=1)
|
||||||
|
result = limiter.acquire("member_action", user_id=456, guild_id=2)
|
||||||
|
assert result.is_limited is False # Same user, different guild = different bucket
|
||||||
|
|
||||||
|
def test_reset(self, limiter: RateLimiter) -> None:
|
||||||
|
"""Test resetting a bucket."""
|
||||||
|
limiter.configure(
|
||||||
|
"test",
|
||||||
|
RateLimitConfig(max_requests=1, window_seconds=60, scope=RateLimitScope.USER),
|
||||||
|
)
|
||||||
|
|
||||||
|
limiter.acquire("test", user_id=123)
|
||||||
|
assert limiter.acquire("test", user_id=123).is_limited is True
|
||||||
|
|
||||||
|
limiter.reset("test", user_id=123)
|
||||||
|
assert limiter.acquire("test", user_id=123).is_limited is False
|
||||||
|
|
||||||
|
def test_unknown_action(self, limiter: RateLimiter) -> None:
|
||||||
|
"""Test unknown action returns not limited."""
|
||||||
|
result = limiter.acquire("unknown_action", user_id=123)
|
||||||
|
assert result.is_limited is False
|
||||||
|
assert result.remaining == 999
|
||||||
|
|
||||||
|
def test_guild_scope(self, limiter: RateLimiter) -> None:
|
||||||
|
"""Test guild-scoped rate limiting."""
|
||||||
|
limiter.configure(
|
||||||
|
"guild_action",
|
||||||
|
RateLimitConfig(max_requests=2, window_seconds=60, scope=RateLimitScope.GUILD),
|
||||||
|
)
|
||||||
|
|
||||||
|
# Different users in same guild share the limit
|
||||||
|
limiter.acquire("guild_action", user_id=1, guild_id=100)
|
||||||
|
limiter.acquire("guild_action", user_id=2, guild_id=100)
|
||||||
|
result = limiter.acquire("guild_action", user_id=3, guild_id=100)
|
||||||
|
assert result.is_limited is True
|
||||||
|
|
||||||
|
# Different guild is not limited
|
||||||
|
result = limiter.acquire("guild_action", user_id=1, guild_id=200)
|
||||||
|
assert result.is_limited is False
|
||||||
48
tests/test_utils.py
Normal file
48
tests/test_utils.py
Normal file
@@ -0,0 +1,48 @@
|
|||||||
|
"""Tests for utility functions."""
|
||||||
|
|
||||||
|
from datetime import timedelta
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
from guardden.cogs.moderation import parse_duration
|
||||||
|
|
||||||
|
|
||||||
|
class TestParseDuration:
|
||||||
|
"""Tests for the parse_duration function."""
|
||||||
|
|
||||||
|
def test_parse_seconds(self) -> None:
|
||||||
|
"""Test parsing seconds."""
|
||||||
|
assert parse_duration("30s") == timedelta(seconds=30)
|
||||||
|
assert parse_duration("1s") == timedelta(seconds=1)
|
||||||
|
|
||||||
|
def test_parse_minutes(self) -> None:
|
||||||
|
"""Test parsing minutes."""
|
||||||
|
assert parse_duration("5m") == timedelta(minutes=5)
|
||||||
|
assert parse_duration("30m") == timedelta(minutes=30)
|
||||||
|
|
||||||
|
def test_parse_hours(self) -> None:
|
||||||
|
"""Test parsing hours."""
|
||||||
|
assert parse_duration("1h") == timedelta(hours=1)
|
||||||
|
assert parse_duration("24h") == timedelta(hours=24)
|
||||||
|
|
||||||
|
def test_parse_days(self) -> None:
|
||||||
|
"""Test parsing days."""
|
||||||
|
assert parse_duration("7d") == timedelta(days=7)
|
||||||
|
assert parse_duration("1d") == timedelta(days=1)
|
||||||
|
|
||||||
|
def test_parse_weeks(self) -> None:
|
||||||
|
"""Test parsing weeks."""
|
||||||
|
assert parse_duration("2w") == timedelta(weeks=2)
|
||||||
|
assert parse_duration("1w") == timedelta(weeks=1)
|
||||||
|
|
||||||
|
def test_invalid_format(self) -> None:
|
||||||
|
"""Test invalid duration formats."""
|
||||||
|
assert parse_duration("invalid") is None
|
||||||
|
assert parse_duration("") is None
|
||||||
|
assert parse_duration("10") is None
|
||||||
|
assert parse_duration("abc") is None
|
||||||
|
|
||||||
|
def test_case_insensitive(self) -> None:
|
||||||
|
"""Test that parsing is case insensitive."""
|
||||||
|
assert parse_duration("1H") == timedelta(hours=1)
|
||||||
|
assert parse_duration("30M") == timedelta(minutes=30)
|
||||||
142
tests/test_verification.py
Normal file
142
tests/test_verification.py
Normal file
@@ -0,0 +1,142 @@
|
|||||||
|
"""Tests for verification service."""
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
from guardden.services.verification import (
|
||||||
|
ButtonChallengeGenerator,
|
||||||
|
CaptchaChallengeGenerator,
|
||||||
|
ChallengeType,
|
||||||
|
EmojiChallengeGenerator,
|
||||||
|
MathChallengeGenerator,
|
||||||
|
VerificationService,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class TestChallengeGenerators:
|
||||||
|
"""Tests for challenge generators."""
|
||||||
|
|
||||||
|
def test_button_challenge(self) -> None:
|
||||||
|
"""Test button challenge generation."""
|
||||||
|
gen = ButtonChallengeGenerator()
|
||||||
|
challenge = gen.generate()
|
||||||
|
|
||||||
|
assert challenge.challenge_type == ChallengeType.BUTTON
|
||||||
|
assert challenge.answer == "verified"
|
||||||
|
|
||||||
|
def test_captcha_challenge(self) -> None:
|
||||||
|
"""Test captcha challenge generation."""
|
||||||
|
gen = CaptchaChallengeGenerator(length=6)
|
||||||
|
challenge = gen.generate()
|
||||||
|
|
||||||
|
assert challenge.challenge_type == ChallengeType.CAPTCHA
|
||||||
|
assert len(challenge.answer) == 6
|
||||||
|
assert challenge.answer.isalnum()
|
||||||
|
|
||||||
|
def test_math_challenge(self) -> None:
|
||||||
|
"""Test math challenge generation."""
|
||||||
|
gen = MathChallengeGenerator()
|
||||||
|
challenge = gen.generate()
|
||||||
|
|
||||||
|
assert challenge.challenge_type == ChallengeType.MATH
|
||||||
|
# Answer should be a number
|
||||||
|
assert challenge.answer.lstrip("-").isdigit()
|
||||||
|
|
||||||
|
def test_emoji_challenge(self) -> None:
|
||||||
|
"""Test emoji challenge generation."""
|
||||||
|
gen = EmojiChallengeGenerator()
|
||||||
|
challenge = gen.generate()
|
||||||
|
|
||||||
|
assert challenge.challenge_type == ChallengeType.EMOJI
|
||||||
|
assert len(challenge.options) > 0
|
||||||
|
assert challenge.answer in challenge.options
|
||||||
|
|
||||||
|
|
||||||
|
class TestVerificationService:
|
||||||
|
"""Tests for verification service."""
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def service(self) -> VerificationService:
|
||||||
|
return VerificationService()
|
||||||
|
|
||||||
|
def test_create_challenge(self, service: VerificationService) -> None:
|
||||||
|
"""Test creating a challenge."""
|
||||||
|
pending = service.create_challenge(
|
||||||
|
user_id=123,
|
||||||
|
guild_id=456,
|
||||||
|
challenge_type=ChallengeType.BUTTON,
|
||||||
|
)
|
||||||
|
|
||||||
|
assert pending.user_id == 123
|
||||||
|
assert pending.guild_id == 456
|
||||||
|
assert pending.challenge.challenge_type == ChallengeType.BUTTON
|
||||||
|
|
||||||
|
def test_get_pending(self, service: VerificationService) -> None:
|
||||||
|
"""Test retrieving pending verification."""
|
||||||
|
service.create_challenge(123, 456, ChallengeType.BUTTON)
|
||||||
|
|
||||||
|
pending = service.get_pending(456, 123)
|
||||||
|
assert pending is not None
|
||||||
|
assert pending.user_id == 123
|
||||||
|
|
||||||
|
# Non-existent should return None
|
||||||
|
assert service.get_pending(456, 999) is None
|
||||||
|
|
||||||
|
def test_verify_correct(self, service: VerificationService) -> None:
|
||||||
|
"""Test successful verification."""
|
||||||
|
service.create_challenge(123, 456, ChallengeType.BUTTON)
|
||||||
|
|
||||||
|
success, message = service.verify(456, 123, "verified")
|
||||||
|
assert success is True
|
||||||
|
assert "successful" in message.lower()
|
||||||
|
|
||||||
|
# Should be removed after success
|
||||||
|
assert service.get_pending(456, 123) is None
|
||||||
|
|
||||||
|
def test_verify_incorrect(self, service: VerificationService) -> None:
|
||||||
|
"""Test failed verification."""
|
||||||
|
service.create_challenge(123, 456, ChallengeType.BUTTON)
|
||||||
|
|
||||||
|
success, message = service.verify(456, 123, "wrong")
|
||||||
|
assert success is False
|
||||||
|
assert "incorrect" in message.lower()
|
||||||
|
|
||||||
|
# Should still exist
|
||||||
|
assert service.get_pending(456, 123) is not None
|
||||||
|
|
||||||
|
def test_verify_max_attempts(self, service: VerificationService) -> None:
|
||||||
|
"""Test max attempts exceeded."""
|
||||||
|
service.create_challenge(123, 456, ChallengeType.BUTTON)
|
||||||
|
|
||||||
|
# Use up all attempts
|
||||||
|
for _ in range(3):
|
||||||
|
service.verify(456, 123, "wrong")
|
||||||
|
|
||||||
|
success, message = service.verify(456, 123, "verified")
|
||||||
|
assert success is False
|
||||||
|
assert "too many" in message.lower()
|
||||||
|
|
||||||
|
def test_verify_no_pending(self, service: VerificationService) -> None:
|
||||||
|
"""Test verification with no pending challenge."""
|
||||||
|
success, message = service.verify(456, 123, "verified")
|
||||||
|
assert success is False
|
||||||
|
assert "no pending" in message.lower()
|
||||||
|
|
||||||
|
def test_cancel(self, service: VerificationService) -> None:
|
||||||
|
"""Test canceling verification."""
|
||||||
|
service.create_challenge(123, 456, ChallengeType.BUTTON)
|
||||||
|
|
||||||
|
assert service.cancel(456, 123) is True
|
||||||
|
assert service.get_pending(456, 123) is None
|
||||||
|
|
||||||
|
# Cancel non-existent returns False
|
||||||
|
assert service.cancel(456, 123) is False
|
||||||
|
|
||||||
|
def test_pending_count(self, service: VerificationService) -> None:
|
||||||
|
"""Test pending count per guild."""
|
||||||
|
service.create_challenge(1, 456, ChallengeType.BUTTON)
|
||||||
|
service.create_challenge(2, 456, ChallengeType.BUTTON)
|
||||||
|
service.create_challenge(3, 789, ChallengeType.BUTTON)
|
||||||
|
|
||||||
|
assert service.get_pending_count(456) == 2
|
||||||
|
assert service.get_pending_count(789) == 1
|
||||||
|
assert service.get_pending_count(999) == 0
|
||||||
Reference in New Issue
Block a user