Implement GuardDen Discord moderation bot
Features: - Core moderation: warn, kick, ban, timeout, strike system - Automod: banned words filter, scam detection, anti-spam, link filtering - AI moderation: Claude/OpenAI integration, NSFW detection, phishing analysis - Verification system: button, captcha, math, emoji challenges - Rate limiting system with configurable scopes - Event logging: joins, leaves, message edits/deletes, voice activity - Per-guild configuration with caching - Docker deployment support Bug fixes applied: - Fixed await on session.delete() in guild_config.py - Fixed memory leak in AI moderation message tracking (use deque) - Added error handling to bot shutdown - Added error handling to timeout command - Removed unused Literal import - Added prefix validation - Added image analysis limit (3 per message) - Fixed test mock for SQLAlchemy model
This commit is contained in:
21
.env.example
Normal file
21
.env.example
Normal file
@@ -0,0 +1,21 @@
|
||||
# Discord Bot Configuration
|
||||
GUARDDEN_DISCORD_TOKEN=your_discord_bot_token_here
|
||||
GUARDDEN_DISCORD_PREFIX=!
|
||||
|
||||
# Database Configuration (for local development without Docker)
|
||||
GUARDDEN_DATABASE_URL=postgresql://guardden:guardden@localhost:5432/guardden
|
||||
|
||||
# Logging
|
||||
GUARDDEN_LOG_LEVEL=INFO
|
||||
|
||||
# AI Configuration (optional)
|
||||
# Options: none, anthropic, openai
|
||||
GUARDDEN_AI_PROVIDER=none
|
||||
|
||||
# Anthropic API key (required if AI_PROVIDER=anthropic)
|
||||
# Get your key at: https://console.anthropic.com/
|
||||
GUARDDEN_ANTHROPIC_API_KEY=
|
||||
|
||||
# OpenAI API key (required if AI_PROVIDER=openai)
|
||||
# Get your key at: https://platform.openai.com/api-keys
|
||||
GUARDDEN_OPENAI_API_KEY=
|
||||
56
.gitignore
vendored
Normal file
56
.gitignore
vendored
Normal file
@@ -0,0 +1,56 @@
|
||||
# Python
|
||||
__pycache__/
|
||||
*.py[cod]
|
||||
*$py.class
|
||||
*.so
|
||||
.Python
|
||||
build/
|
||||
develop-eggs/
|
||||
dist/
|
||||
downloads/
|
||||
eggs/
|
||||
.eggs/
|
||||
lib/
|
||||
lib64/
|
||||
parts/
|
||||
sdist/
|
||||
var/
|
||||
wheels/
|
||||
*.egg-info/
|
||||
.installed.cfg
|
||||
*.egg
|
||||
|
||||
# Virtual environments
|
||||
venv/
|
||||
ENV/
|
||||
env/
|
||||
.venv/
|
||||
|
||||
# IDE
|
||||
.idea/
|
||||
.vscode/
|
||||
*.swp
|
||||
*.swo
|
||||
*~
|
||||
|
||||
# Environment
|
||||
.env
|
||||
.env.local
|
||||
|
||||
# Data
|
||||
data/
|
||||
*.db
|
||||
*.sqlite3
|
||||
|
||||
# Logs
|
||||
*.log
|
||||
logs/
|
||||
|
||||
# Testing
|
||||
.coverage
|
||||
htmlcov/
|
||||
.pytest_cache/
|
||||
.mypy_cache/
|
||||
|
||||
# Docker
|
||||
docker-compose.override.yml
|
||||
103
CLAUDE.md
Normal file
103
CLAUDE.md
Normal file
@@ -0,0 +1,103 @@
|
||||
# CLAUDE.md
|
||||
|
||||
This file provides guidance to Claude Code (claude.ai/code) when working with code in this repository.
|
||||
|
||||
## Project Overview
|
||||
|
||||
GuardDen is a Discord moderation bot built with discord.py, PostgreSQL, and optional AI integration (Claude/OpenAI). Self-hosted with Docker support.
|
||||
|
||||
## Commands
|
||||
|
||||
```bash
|
||||
# Install dependencies
|
||||
pip install -e ".[dev,ai]"
|
||||
|
||||
# Run the bot
|
||||
python -m guardden
|
||||
|
||||
# Run tests
|
||||
pytest
|
||||
|
||||
# Run single test
|
||||
pytest tests/test_verification.py::TestVerificationService::test_verify_correct
|
||||
|
||||
# Lint and format
|
||||
ruff check src tests
|
||||
ruff format src tests
|
||||
|
||||
# Type checking
|
||||
mypy src
|
||||
|
||||
# Docker deployment
|
||||
docker compose up -d
|
||||
```
|
||||
|
||||
## Architecture
|
||||
|
||||
- `src/guardden/bot.py` - Main bot class (`GuardDen`) extending `commands.Bot`, manages lifecycle and services
|
||||
- `src/guardden/config.py` - Pydantic settings loaded from environment variables (prefix: `GUARDDEN_`)
|
||||
- `src/guardden/models/` - SQLAlchemy 2.0 async models for PostgreSQL
|
||||
- `src/guardden/services/` - Business logic (database, guild config, automod, AI, verification, rate limiting)
|
||||
- `src/guardden/cogs/` - Discord command groups (events, moderation, admin, automod, ai_moderation, verification)
|
||||
|
||||
## Key Patterns
|
||||
|
||||
- All database operations use async SQLAlchemy with `asyncpg`
|
||||
- Guild configurations are cached in `GuildConfigService._cache`
|
||||
- Discord snowflake IDs stored as `BigInteger` in PostgreSQL
|
||||
- Moderation actions logged to `ModerationLog` table with automatic strike escalation
|
||||
- Environment variables: `GUARDDEN_DISCORD_TOKEN`, `GUARDDEN_DATABASE_URL`
|
||||
|
||||
## Automod System
|
||||
|
||||
- `AutomodService` in `services/automod.py` handles rule-based content filtering
|
||||
- Checks run in order: banned words → scam links → spam → invite links
|
||||
- Spam tracking uses per-guild, per-user trackers with automatic cleanup
|
||||
- Scam detection uses compiled regex patterns in `SCAM_PATTERNS` list
|
||||
- Results return `AutomodResult` dataclass with actions to take
|
||||
|
||||
## AI Moderation System
|
||||
|
||||
- `services/ai/` contains provider abstraction and implementations
|
||||
- `AIProvider` base class defines interface: `moderate_text()`, `analyze_image()`, `analyze_phishing()`
|
||||
- `AnthropicProvider` and `OpenAIProvider` implement the interface
|
||||
- `NullProvider` used when AI is disabled (returns empty results)
|
||||
- Factory pattern via `create_ai_provider(provider, api_key)`
|
||||
- `ModerationResult` includes severity scoring based on confidence + category weights
|
||||
- Sensitivity setting (0-100) adjusts thresholds per guild
|
||||
|
||||
## Verification System
|
||||
|
||||
- `VerificationService` in `services/verification.py` manages challenges
|
||||
- Challenge types: button, captcha, math, emoji (via `ChallengeGenerator` classes)
|
||||
- `PendingVerification` tracks user challenges with expiry and attempt limits
|
||||
- Discord UI components in `cogs/verification.py`: `VerifyButton`, `EmojiButton`, `CaptchaModal`
|
||||
- Background task cleans up expired verifications every 5 minutes
|
||||
|
||||
## Rate Limiting System
|
||||
|
||||
- `RateLimiter` in `services/ratelimit.py` provides general-purpose rate limiting
|
||||
- Scopes: USER (global), MEMBER (per-guild), CHANNEL, GUILD
|
||||
- `@ratelimit()` decorator for easy command rate limiting
|
||||
- `get_rate_limiter()` returns singleton instance
|
||||
- Default limits configured for commands, moderation, verification, messages
|
||||
|
||||
## Adding New Cogs
|
||||
|
||||
1. Create file in `src/guardden/cogs/`
|
||||
2. Implement `setup(bot)` async function
|
||||
3. Add cog path to `_load_cogs()` in `bot.py`
|
||||
|
||||
## Adding New AI Provider
|
||||
|
||||
1. Create `src/guardden/services/ai/newprovider.py`
|
||||
2. Implement `AIProvider` abstract class
|
||||
3. Add to factory in `services/ai/factory.py`
|
||||
4. Add config option in `config.py`
|
||||
|
||||
## Adding New Challenge Type
|
||||
|
||||
1. Create new `ChallengeGenerator` subclass in `services/verification.py`
|
||||
2. Add to `ChallengeType` enum
|
||||
3. Register in `VerificationService._generators`
|
||||
4. Create corresponding UI components in `cogs/verification.py` if needed
|
||||
27
Dockerfile
Normal file
27
Dockerfile
Normal file
@@ -0,0 +1,27 @@
|
||||
FROM python:3.11-slim
|
||||
|
||||
WORKDIR /app
|
||||
|
||||
# Install system dependencies
|
||||
RUN apt-get update && apt-get install -y --no-install-recommends \
|
||||
gcc \
|
||||
libpq-dev \
|
||||
&& rm -rf /var/lib/apt/lists/*
|
||||
|
||||
# Copy dependency files
|
||||
COPY pyproject.toml ./
|
||||
|
||||
# Install Python dependencies
|
||||
RUN pip install --no-cache-dir -e .
|
||||
|
||||
# Copy application code
|
||||
COPY src/ ./src/
|
||||
COPY migrations/ ./migrations/
|
||||
COPY alembic.ini ./
|
||||
|
||||
# Create non-root user
|
||||
RUN useradd -m -u 1000 guardden && chown -R guardden:guardden /app
|
||||
USER guardden
|
||||
|
||||
# Run the bot
|
||||
CMD ["python", "-m", "guardden"]
|
||||
317
README.md
317
README.md
@@ -1,3 +1,320 @@
|
||||
# GuardDen
|
||||
|
||||
GuardDen is a comprehensive Discord moderation bot designed to protect your community while maintaining a warm, welcoming environment. Built with privacy and self-hosting in mind, GuardDen combines AI-powered content filtering with traditional moderation tools to create a safe space for your members.
|
||||
|
||||
## Features
|
||||
|
||||
### Core Moderation
|
||||
- **Warn, Kick, Ban, Timeout** - Standard moderation commands with logging
|
||||
- **Strike System** - Configurable point-based system with automatic escalation
|
||||
- **Moderation History** - Track all actions taken against users
|
||||
- **Bulk Message Deletion** - Purge up to 100 messages at once
|
||||
|
||||
### Automod
|
||||
- **Banned Words Filter** - Block words/phrases with regex support
|
||||
- **Scam Detection** - Automatic detection of phishing/scam links
|
||||
- **Anti-Spam** - Rate limiting, duplicate detection, mass mention protection
|
||||
- **Link Filtering** - Block Discord invites and suspicious URLs
|
||||
|
||||
### AI Moderation
|
||||
- **Text Analysis** - AI-powered content moderation using Claude or GPT
|
||||
- **NSFW Image Detection** - Automatic flagging of inappropriate images
|
||||
- **Phishing Analysis** - AI-enhanced detection of scam URLs
|
||||
- **Configurable Sensitivity** - Adjust strictness per server (0-100)
|
||||
|
||||
### Verification System
|
||||
- **Multiple Challenge Types** - Button, captcha, math problems, emoji selection
|
||||
- **Automatic New Member Verification** - Challenge users on join
|
||||
- **Configurable Verified Role** - Auto-assign role on successful verification
|
||||
- **Rate Limited** - Prevents verification spam
|
||||
|
||||
### Logging
|
||||
- Member joins/leaves
|
||||
- Message edits and deletions
|
||||
- Voice channel activity
|
||||
- Ban/unban events
|
||||
- All moderation actions
|
||||
|
||||
## Quick Start
|
||||
|
||||
### Prerequisites
|
||||
- Python 3.11+
|
||||
- PostgreSQL 15+
|
||||
- Discord Bot Token ([Discord Developer Portal](https://discord.com/developers/applications))
|
||||
- (Optional) Anthropic or OpenAI API key for AI features
|
||||
|
||||
### Docker Deployment (Recommended)
|
||||
|
||||
1. Clone the repository:
|
||||
```bash
|
||||
git clone https://github.com/yourusername/guardden.git
|
||||
cd guardden
|
||||
```
|
||||
|
||||
2. Create your environment file:
|
||||
```bash
|
||||
cp .env.example .env
|
||||
# Edit .env and add your Discord token
|
||||
```
|
||||
|
||||
3. Start with Docker Compose:
|
||||
```bash
|
||||
docker compose up -d
|
||||
```
|
||||
|
||||
### Local Development
|
||||
|
||||
1. Create a virtual environment:
|
||||
```bash
|
||||
python -m venv venv
|
||||
source venv/bin/activate # On Windows: venv\Scripts\activate
|
||||
```
|
||||
|
||||
2. Install dependencies:
|
||||
```bash
|
||||
pip install -e ".[dev,ai]"
|
||||
```
|
||||
|
||||
3. Set up environment variables:
|
||||
```bash
|
||||
cp .env.example .env
|
||||
# Edit .env with your configuration
|
||||
```
|
||||
|
||||
4. Start PostgreSQL (or use Docker):
|
||||
```bash
|
||||
docker compose up db -d
|
||||
```
|
||||
|
||||
5. Run the bot:
|
||||
```bash
|
||||
python -m guardden
|
||||
```
|
||||
|
||||
## Configuration
|
||||
|
||||
### Environment Variables
|
||||
|
||||
| Variable | Description | Default |
|
||||
|----------|-------------|---------|
|
||||
| `GUARDDEN_DISCORD_TOKEN` | Your Discord bot token | Required |
|
||||
| `GUARDDEN_DISCORD_PREFIX` | Default command prefix | `!` |
|
||||
| `GUARDDEN_DATABASE_URL` | PostgreSQL connection URL | `postgresql://guardden:guardden@localhost:5432/guardden` |
|
||||
| `GUARDDEN_LOG_LEVEL` | Logging level | `INFO` |
|
||||
| `GUARDDEN_AI_PROVIDER` | AI provider (anthropic/openai/none) | `none` |
|
||||
| `GUARDDEN_ANTHROPIC_API_KEY` | Anthropic API key (if using Claude) | - |
|
||||
| `GUARDDEN_OPENAI_API_KEY` | OpenAI API key (if using GPT) | - |
|
||||
|
||||
### Per-Guild Settings
|
||||
|
||||
Each server can configure:
|
||||
- Command prefix
|
||||
- Log channels (general and moderation)
|
||||
- Welcome channel
|
||||
- Mute role and verified role
|
||||
- Automod toggles (spam, links, banned words)
|
||||
- Strike action thresholds
|
||||
- AI moderation settings (enabled, sensitivity, NSFW detection)
|
||||
- Verification settings (type, enabled)
|
||||
|
||||
## Commands
|
||||
|
||||
### Moderation
|
||||
|
||||
| Command | Permission | Description |
|
||||
|---------|------------|-------------|
|
||||
| `!warn <user> [reason]` | Kick Members | Warn a user |
|
||||
| `!strike <user> [points] [reason]` | Kick Members | Add strikes to a user |
|
||||
| `!strikes <user>` | Kick Members | View user's strikes |
|
||||
| `!timeout <user> <duration> [reason]` | Moderate Members | Timeout a user (e.g., 1h, 30m, 7d) |
|
||||
| `!untimeout <user>` | Moderate Members | Remove timeout |
|
||||
| `!kick <user> [reason]` | Kick Members | Kick a user |
|
||||
| `!ban <user> [reason]` | Ban Members | Ban a user |
|
||||
| `!unban <user_id> [reason]` | Ban Members | Unban a user by ID |
|
||||
| `!purge <amount>` | Manage Messages | Delete multiple messages (max 100) |
|
||||
| `!modlogs <user>` | Kick Members | View moderation history |
|
||||
|
||||
### Configuration (Admin only)
|
||||
|
||||
| Command | Description |
|
||||
|---------|-------------|
|
||||
| `!config` | View current configuration |
|
||||
| `!config prefix <prefix>` | Set command prefix |
|
||||
| `!config logchannel [#channel]` | Set general log channel |
|
||||
| `!config modlogchannel [#channel]` | Set moderation log channel |
|
||||
| `!config welcomechannel [#channel]` | Set welcome channel |
|
||||
| `!config muterole [@role]` | Set mute role |
|
||||
| `!config automod <true/false>` | Toggle automod |
|
||||
| `!config antispam <true/false>` | Toggle anti-spam |
|
||||
| `!config linkfilter <true/false>` | Toggle link filtering |
|
||||
|
||||
### Banned Words
|
||||
|
||||
| Command | Description |
|
||||
|---------|-------------|
|
||||
| `!bannedwords` | List all banned words |
|
||||
| `!bannedwords add <word> [action] [is_regex]` | Add a banned word |
|
||||
| `!bannedwords remove <id>` | Remove a banned word by ID |
|
||||
|
||||
### Automod
|
||||
|
||||
| Command | Description |
|
||||
|---------|-------------|
|
||||
| `!automod` | View automod status |
|
||||
| `!automod test <text>` | Test text against filters |
|
||||
|
||||
### AI Moderation (Admin only)
|
||||
|
||||
| Command | Description |
|
||||
|---------|-------------|
|
||||
| `!ai` | View AI moderation settings |
|
||||
| `!ai enable` | Enable AI moderation |
|
||||
| `!ai disable` | Disable AI moderation |
|
||||
| `!ai sensitivity <0-100>` | Set AI sensitivity level |
|
||||
| `!ai nsfw <true/false>` | Toggle NSFW image detection |
|
||||
| `!ai analyze <text>` | Test AI analysis on text |
|
||||
|
||||
### Verification (Admin only)
|
||||
|
||||
| Command | Description |
|
||||
|---------|-------------|
|
||||
| `!verify` | Request verification (for users) |
|
||||
| `!verify setup` | View verification setup status |
|
||||
| `!verify enable` | Enable verification for new members |
|
||||
| `!verify disable` | Disable verification |
|
||||
| `!verify role @role` | Set the verified role |
|
||||
| `!verify type <type>` | Set verification type (button/captcha/math/emoji) |
|
||||
| `!verify test [type]` | Test a verification challenge |
|
||||
| `!verify reset @user` | Reset verification for a user |
|
||||
|
||||
## Project Structure
|
||||
|
||||
```
|
||||
guardden/
|
||||
├── src/guardden/
|
||||
│ ├── bot.py # Main bot class
|
||||
│ ├── config.py # Settings management
|
||||
│ ├── cogs/ # Discord command groups
|
||||
│ │ ├── admin.py # Configuration commands
|
||||
│ │ ├── ai_moderation.py # AI-powered moderation
|
||||
│ │ ├── automod.py # Automatic moderation
|
||||
│ │ ├── events.py # Event logging
|
||||
│ │ ├── moderation.py # Moderation commands
|
||||
│ │ └── verification.py # Member verification
|
||||
│ ├── models/ # Database models
|
||||
│ │ ├── guild.py # Guild settings, banned words
|
||||
│ │ └── moderation.py # Logs, strikes, notes
|
||||
│ └── services/ # Business logic
|
||||
│ ├── ai/ # AI provider implementations
|
||||
│ ├── automod.py # Content filtering
|
||||
│ ├── database.py # DB connections
|
||||
│ ├── guild_config.py # Config caching
|
||||
│ ├── ratelimit.py # Rate limiting
|
||||
│ └── verification.py # Verification challenges
|
||||
├── tests/ # Test suite
|
||||
├── migrations/ # Database migrations
|
||||
├── docker-compose.yml # Docker deployment
|
||||
└── pyproject.toml # Dependencies
|
||||
```
|
||||
|
||||
## Verification System
|
||||
|
||||
GuardDen includes a verification system to protect your server from bots and raids.
|
||||
|
||||
### Challenge Types
|
||||
|
||||
| Type | Description |
|
||||
|------|-------------|
|
||||
| `button` | Simple button click (default, easiest) |
|
||||
| `captcha` | Text-based captcha code entry |
|
||||
| `math` | Solve a simple math problem |
|
||||
| `emoji` | Select the correct emoji from options |
|
||||
|
||||
### Setup
|
||||
|
||||
1. Create a verified role in your server
|
||||
2. Configure the role permissions (verified members get full access)
|
||||
3. Set up verification:
|
||||
```
|
||||
!verify role @Verified
|
||||
!verify type captcha
|
||||
!verify enable
|
||||
```
|
||||
|
||||
### How It Works
|
||||
|
||||
1. New member joins the server
|
||||
2. Bot sends verification challenge via DM (or channel if DMs disabled)
|
||||
3. Member completes the challenge
|
||||
4. Bot assigns the verified role
|
||||
5. Member gains access to the server
|
||||
|
||||
## AI Moderation
|
||||
|
||||
GuardDen supports AI-powered content moderation using either Anthropic's Claude or OpenAI's GPT models.
|
||||
|
||||
### Setup
|
||||
|
||||
1. Set the AI provider in your environment:
|
||||
```bash
|
||||
GUARDDEN_AI_PROVIDER=anthropic # or "openai"
|
||||
GUARDDEN_ANTHROPIC_API_KEY=sk-ant-... # if using Claude
|
||||
GUARDDEN_OPENAI_API_KEY=sk-... # if using OpenAI
|
||||
```
|
||||
|
||||
2. Enable AI moderation per server:
|
||||
```
|
||||
!ai enable
|
||||
!ai sensitivity 50 # 0=lenient, 100=strict
|
||||
!ai nsfw true # Enable NSFW image detection
|
||||
```
|
||||
|
||||
### Content Categories
|
||||
|
||||
The AI analyzes content for:
|
||||
- **Harassment** - Personal attacks, bullying
|
||||
- **Hate Speech** - Discrimination, slurs
|
||||
- **Sexual Content** - Explicit material
|
||||
- **Violence** - Threats, graphic content
|
||||
- **Self-Harm** - Suicide/self-injury content
|
||||
- **Scams** - Phishing, fraud attempts
|
||||
- **Spam** - Promotional, low-quality content
|
||||
|
||||
### How It Works
|
||||
|
||||
1. Messages are analyzed by the AI provider
|
||||
2. Results include confidence scores and severity ratings
|
||||
3. Actions are taken based on guild sensitivity settings
|
||||
4. All AI actions are logged to the mod log channel
|
||||
|
||||
## Development
|
||||
|
||||
### Running Tests
|
||||
|
||||
```bash
|
||||
pytest
|
||||
pytest -v # Verbose output
|
||||
pytest tests/test_automod.py # Specific file
|
||||
pytest -k "test_scam" # Filter by name
|
||||
```
|
||||
|
||||
### Code Quality
|
||||
|
||||
```bash
|
||||
ruff check src tests # Linting
|
||||
ruff format src tests # Formatting
|
||||
mypy src # Type checking
|
||||
```
|
||||
|
||||
## License
|
||||
|
||||
MIT License - see LICENSE file for details.
|
||||
|
||||
## Roadmap
|
||||
|
||||
- [x] AI-powered content moderation (Claude/OpenAI integration)
|
||||
- [x] NSFW image detection
|
||||
- [x] Verification/captcha system
|
||||
- [x] Rate limiting
|
||||
- [ ] Voice channel moderation
|
||||
- [ ] Web dashboard
|
||||
|
||||
43
alembic.ini
Normal file
43
alembic.ini
Normal file
@@ -0,0 +1,43 @@
|
||||
[alembic]
|
||||
script_location = migrations
|
||||
prepend_sys_path = .
|
||||
version_path_separator = os
|
||||
|
||||
# Use async driver
|
||||
sqlalchemy.url = postgresql+asyncpg://guardden:guardden@localhost:5432/guardden
|
||||
|
||||
[post_write_hooks]
|
||||
|
||||
[loggers]
|
||||
keys = root,sqlalchemy,alembic
|
||||
|
||||
[handlers]
|
||||
keys = console
|
||||
|
||||
[formatters]
|
||||
keys = generic
|
||||
|
||||
[logger_root]
|
||||
level = WARN
|
||||
handlers = console
|
||||
qualname =
|
||||
|
||||
[logger_sqlalchemy]
|
||||
level = WARN
|
||||
handlers =
|
||||
qualname = sqlalchemy.engine
|
||||
|
||||
[logger_alembic]
|
||||
level = INFO
|
||||
handlers =
|
||||
qualname = alembic
|
||||
|
||||
[handler_console]
|
||||
class = StreamHandler
|
||||
args = (sys.stderr,)
|
||||
level = NOTSET
|
||||
formatter = generic
|
||||
|
||||
[formatter_generic]
|
||||
format = %(levelname)-5.5s [%(name)s] %(message)s
|
||||
datefmt = %H:%M:%S
|
||||
44
docker-compose.yml
Normal file
44
docker-compose.yml
Normal file
@@ -0,0 +1,44 @@
|
||||
services:
|
||||
bot:
|
||||
build: .
|
||||
container_name: guardden-bot
|
||||
restart: unless-stopped
|
||||
depends_on:
|
||||
db:
|
||||
condition: service_healthy
|
||||
environment:
|
||||
- GUARDDEN_DISCORD_TOKEN=${GUARDDEN_DISCORD_TOKEN}
|
||||
- GUARDDEN_DATABASE_URL=postgresql://guardden:guardden@db:5432/guardden
|
||||
- GUARDDEN_LOG_LEVEL=${GUARDDEN_LOG_LEVEL:-INFO}
|
||||
- GUARDDEN_AI_PROVIDER=${GUARDDEN_AI_PROVIDER:-none}
|
||||
- GUARDDEN_ANTHROPIC_API_KEY=${GUARDDEN_ANTHROPIC_API_KEY:-}
|
||||
- GUARDDEN_OPENAI_API_KEY=${GUARDDEN_OPENAI_API_KEY:-}
|
||||
volumes:
|
||||
- ./data:/app/data
|
||||
networks:
|
||||
- guardden
|
||||
|
||||
db:
|
||||
image: postgres:15-alpine
|
||||
container_name: guardden-db
|
||||
restart: unless-stopped
|
||||
environment:
|
||||
- POSTGRES_USER=guardden
|
||||
- POSTGRES_PASSWORD=guardden
|
||||
- POSTGRES_DB=guardden
|
||||
volumes:
|
||||
- postgres_data:/var/lib/postgresql/data
|
||||
healthcheck:
|
||||
test: ["CMD-SHELL", "pg_isready -U guardden -d guardden"]
|
||||
interval: 5s
|
||||
timeout: 5s
|
||||
retries: 5
|
||||
networks:
|
||||
- guardden
|
||||
|
||||
networks:
|
||||
guardden:
|
||||
driver: bridge
|
||||
|
||||
volumes:
|
||||
postgres_data:
|
||||
65
migrations/env.py
Normal file
65
migrations/env.py
Normal file
@@ -0,0 +1,65 @@
|
||||
"""Alembic environment configuration."""
|
||||
|
||||
import asyncio
|
||||
from logging.config import fileConfig
|
||||
|
||||
from alembic import context
|
||||
from sqlalchemy import pool
|
||||
from sqlalchemy.engine import Connection
|
||||
from sqlalchemy.ext.asyncio import async_engine_from_config
|
||||
|
||||
from guardden.models import Base
|
||||
|
||||
config = context.config
|
||||
|
||||
if config.config_file_name is not None:
|
||||
fileConfig(config.config_file_name)
|
||||
|
||||
target_metadata = Base.metadata
|
||||
|
||||
|
||||
def run_migrations_offline() -> None:
|
||||
"""Run migrations in 'offline' mode."""
|
||||
url = config.get_main_option("sqlalchemy.url")
|
||||
context.configure(
|
||||
url=url,
|
||||
target_metadata=target_metadata,
|
||||
literal_binds=True,
|
||||
dialect_opts={"paramstyle": "named"},
|
||||
)
|
||||
|
||||
with context.begin_transaction():
|
||||
context.run_migrations()
|
||||
|
||||
|
||||
def do_run_migrations(connection: Connection) -> None:
|
||||
"""Run migrations with the given connection."""
|
||||
context.configure(connection=connection, target_metadata=target_metadata)
|
||||
|
||||
with context.begin_transaction():
|
||||
context.run_migrations()
|
||||
|
||||
|
||||
async def run_async_migrations() -> None:
|
||||
"""Run migrations in 'online' mode with async engine."""
|
||||
connectable = async_engine_from_config(
|
||||
config.get_section(config.config_ini_section, {}),
|
||||
prefix="sqlalchemy.",
|
||||
poolclass=pool.NullPool,
|
||||
)
|
||||
|
||||
async with connectable.connect() as connection:
|
||||
await connection.run_sync(do_run_migrations)
|
||||
|
||||
await connectable.dispose()
|
||||
|
||||
|
||||
def run_migrations_online() -> None:
|
||||
"""Run migrations in 'online' mode."""
|
||||
asyncio.run(run_async_migrations())
|
||||
|
||||
|
||||
if context.is_offline_mode():
|
||||
run_migrations_offline()
|
||||
else:
|
||||
run_migrations_online()
|
||||
26
migrations/script.py.mako
Normal file
26
migrations/script.py.mako
Normal file
@@ -0,0 +1,26 @@
|
||||
"""${message}
|
||||
|
||||
Revision ID: ${up_revision}
|
||||
Revises: ${down_revision | comma,n}
|
||||
Create Date: ${create_date}
|
||||
|
||||
"""
|
||||
from typing import Sequence, Union
|
||||
|
||||
from alembic import op
|
||||
import sqlalchemy as sa
|
||||
${imports if imports else ""}
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision: str = ${repr(up_revision)}
|
||||
down_revision: Union[str, None] = ${repr(down_revision)}
|
||||
branch_labels: Union[str, Sequence[str], None] = ${repr(branch_labels)}
|
||||
depends_on: Union[str, Sequence[str], None] = ${repr(depends_on)}
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
${upgrades if upgrades else "pass"}
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
${downgrades if downgrades else "pass"}
|
||||
98
pyproject.toml
Normal file
98
pyproject.toml
Normal file
@@ -0,0 +1,98 @@
|
||||
[build-system]
|
||||
requires = ["setuptools>=61.0", "wheel"]
|
||||
build-backend = "setuptools.build_meta"
|
||||
|
||||
[project]
|
||||
name = "guardden"
|
||||
version = "0.1.0"
|
||||
description = "A comprehensive Discord moderation bot with AI-powered content filtering"
|
||||
readme = "README.md"
|
||||
license = {text = "MIT"}
|
||||
requires-python = ">=3.11"
|
||||
authors = [
|
||||
{name = "GuardDen Team"}
|
||||
]
|
||||
classifiers = [
|
||||
"Development Status :: 3 - Alpha",
|
||||
"Intended Audience :: Developers",
|
||||
"License :: OSI Approved :: MIT License",
|
||||
"Programming Language :: Python :: 3.11",
|
||||
"Programming Language :: Python :: 3.12",
|
||||
]
|
||||
|
||||
dependencies = [
|
||||
"discord.py>=2.3.0",
|
||||
"asyncpg>=0.29.0",
|
||||
"pydantic>=2.5.0",
|
||||
"pydantic-settings>=2.1.0",
|
||||
"aiohttp>=3.9.0",
|
||||
"python-dotenv>=1.0.0",
|
||||
"alembic>=1.13.0",
|
||||
"sqlalchemy>=2.0.0",
|
||||
]
|
||||
|
||||
[project.optional-dependencies]
|
||||
dev = [
|
||||
"pytest>=7.4.0",
|
||||
"pytest-asyncio>=0.23.0",
|
||||
"pytest-cov>=4.1.0",
|
||||
"ruff>=0.1.0",
|
||||
"mypy>=1.7.0",
|
||||
"pre-commit>=3.6.0",
|
||||
]
|
||||
ai = [
|
||||
"anthropic>=0.18.0",
|
||||
"openai>=1.10.0",
|
||||
"pillow>=10.2.0",
|
||||
]
|
||||
voice = [
|
||||
"speechrecognition>=3.10.0",
|
||||
"pydub>=0.25.0",
|
||||
]
|
||||
|
||||
[project.scripts]
|
||||
guardden = "guardden.__main__:main"
|
||||
|
||||
[tool.setuptools.packages.find]
|
||||
where = ["src"]
|
||||
|
||||
[tool.ruff]
|
||||
target-version = "py311"
|
||||
line-length = 100
|
||||
src = ["src", "tests"]
|
||||
|
||||
[tool.ruff.lint]
|
||||
select = [
|
||||
"E", # pycodestyle errors
|
||||
"W", # pycodestyle warnings
|
||||
"F", # Pyflakes
|
||||
"I", # isort
|
||||
"B", # flake8-bugbear
|
||||
"C4", # flake8-comprehensions
|
||||
"UP", # pyupgrade
|
||||
"ARG", # flake8-unused-arguments
|
||||
"SIM", # flake8-simplify
|
||||
]
|
||||
ignore = [
|
||||
"E501", # line too long (handled by formatter)
|
||||
"B008", # do not perform function calls in argument defaults
|
||||
]
|
||||
|
||||
[tool.ruff.lint.isort]
|
||||
known-first-party = ["guardden"]
|
||||
|
||||
[tool.pytest.ini_options]
|
||||
asyncio_mode = "auto"
|
||||
testpaths = ["tests"]
|
||||
addopts = "-v --tb=short"
|
||||
|
||||
[tool.mypy]
|
||||
python_version = "3.11"
|
||||
strict = true
|
||||
warn_return_any = true
|
||||
warn_unused_ignores = true
|
||||
plugins = ["pydantic.mypy"]
|
||||
|
||||
[[tool.mypy.overrides]]
|
||||
module = ["discord.*", "asyncpg.*"]
|
||||
ignore_missing_imports = true
|
||||
3
src/guardden/__init__.py
Normal file
3
src/guardden/__init__.py
Normal file
@@ -0,0 +1,3 @@
|
||||
"""GuardDen - A comprehensive Discord moderation bot."""
|
||||
|
||||
__version__ = "0.1.0"
|
||||
40
src/guardden/__main__.py
Normal file
40
src/guardden/__main__.py
Normal file
@@ -0,0 +1,40 @@
|
||||
"""Entry point for GuardDen bot."""
|
||||
|
||||
import asyncio
|
||||
import logging
|
||||
import sys
|
||||
|
||||
from guardden.bot import GuardDen
|
||||
from guardden.config import get_settings
|
||||
from guardden.utils import setup_logging
|
||||
|
||||
|
||||
def main() -> None:
|
||||
"""Run the GuardDen bot."""
|
||||
try:
|
||||
settings = get_settings()
|
||||
except Exception as e:
|
||||
print(f"Failed to load configuration: {e}", file=sys.stderr)
|
||||
print("Make sure GUARDDEN_DISCORD_TOKEN is set.", file=sys.stderr)
|
||||
sys.exit(1)
|
||||
|
||||
setup_logging(settings.log_level)
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
bot = GuardDen(settings)
|
||||
|
||||
async def runner() -> None:
|
||||
async with bot:
|
||||
await bot.start(settings.discord_token.get_secret_value())
|
||||
|
||||
try:
|
||||
asyncio.run(runner())
|
||||
except KeyboardInterrupt:
|
||||
logger.info("Received keyboard interrupt, shutting down...")
|
||||
except Exception as e:
|
||||
logger.exception(f"Fatal error: {e}")
|
||||
sys.exit(1)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
131
src/guardden/bot.py
Normal file
131
src/guardden/bot.py
Normal file
@@ -0,0 +1,131 @@
|
||||
"""Main bot class for GuardDen."""
|
||||
|
||||
import logging
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
import discord
|
||||
from discord.ext import commands
|
||||
|
||||
from guardden.config import Settings
|
||||
from guardden.services.ai import AIProvider, create_ai_provider
|
||||
from guardden.services.database import Database
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from guardden.services.guild_config import GuildConfigService
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class GuardDen(commands.Bot):
|
||||
"""The main GuardDen Discord bot."""
|
||||
|
||||
def __init__(self, settings: Settings) -> None:
|
||||
self.settings = settings
|
||||
|
||||
intents = discord.Intents.default()
|
||||
intents.message_content = True
|
||||
intents.members = True
|
||||
intents.voice_states = True
|
||||
|
||||
super().__init__(
|
||||
command_prefix=self._get_prefix,
|
||||
intents=intents,
|
||||
help_command=commands.DefaultHelpCommand(),
|
||||
)
|
||||
|
||||
# Services
|
||||
self.database = Database(settings)
|
||||
self.guild_config: "GuildConfigService | None" = None
|
||||
self.ai_provider: AIProvider | None = None
|
||||
|
||||
async def _get_prefix(self, bot: "GuardDen", message: discord.Message) -> list[str]:
|
||||
"""Get the command prefix for a guild."""
|
||||
if not message.guild:
|
||||
return [self.settings.discord_prefix]
|
||||
|
||||
if self.guild_config:
|
||||
config = await self.guild_config.get_config(message.guild.id)
|
||||
if config:
|
||||
return [config.prefix]
|
||||
|
||||
return [self.settings.discord_prefix]
|
||||
|
||||
async def setup_hook(self) -> None:
|
||||
"""Called when the bot is starting up."""
|
||||
logger.info("Starting GuardDen setup...")
|
||||
|
||||
# Connect to database
|
||||
await self.database.connect()
|
||||
await self.database.create_tables()
|
||||
|
||||
# Initialize services
|
||||
from guardden.services.guild_config import GuildConfigService
|
||||
|
||||
self.guild_config = GuildConfigService(self.database)
|
||||
|
||||
# Initialize AI provider
|
||||
api_key = None
|
||||
if self.settings.ai_provider == "anthropic" and self.settings.anthropic_api_key:
|
||||
api_key = self.settings.anthropic_api_key.get_secret_value()
|
||||
elif self.settings.ai_provider == "openai" and self.settings.openai_api_key:
|
||||
api_key = self.settings.openai_api_key.get_secret_value()
|
||||
|
||||
self.ai_provider = create_ai_provider(self.settings.ai_provider, api_key)
|
||||
|
||||
# Load cogs
|
||||
await self._load_cogs()
|
||||
|
||||
logger.info("GuardDen setup complete")
|
||||
|
||||
async def _load_cogs(self) -> None:
|
||||
"""Load all cog extensions."""
|
||||
cogs = [
|
||||
"guardden.cogs.events",
|
||||
"guardden.cogs.moderation",
|
||||
"guardden.cogs.admin",
|
||||
"guardden.cogs.automod",
|
||||
"guardden.cogs.ai_moderation",
|
||||
"guardden.cogs.verification",
|
||||
]
|
||||
|
||||
for cog in cogs:
|
||||
try:
|
||||
await self.load_extension(cog)
|
||||
logger.info(f"Loaded cog: {cog}")
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to load cog {cog}: {e}")
|
||||
|
||||
async def on_ready(self) -> None:
|
||||
"""Called when the bot is fully connected and ready."""
|
||||
if self.user:
|
||||
logger.info(f"Logged in as {self.user} (ID: {self.user.id})")
|
||||
logger.info(f"Connected to {len(self.guilds)} guild(s)")
|
||||
|
||||
# Set presence
|
||||
activity = discord.Activity(
|
||||
type=discord.ActivityType.watching,
|
||||
name="over your community",
|
||||
)
|
||||
await self.change_presence(activity=activity)
|
||||
|
||||
async def close(self) -> None:
|
||||
"""Clean up when shutting down."""
|
||||
logger.info("Shutting down GuardDen...")
|
||||
if self.ai_provider:
|
||||
try:
|
||||
await self.ai_provider.close()
|
||||
except Exception as e:
|
||||
logger.error(f"Error closing AI provider: {e}")
|
||||
await self.database.disconnect()
|
||||
await super().close()
|
||||
|
||||
async def on_guild_join(self, guild: discord.Guild) -> None:
|
||||
"""Called when the bot joins a new guild."""
|
||||
logger.info(f"Joined guild: {guild.name} (ID: {guild.id})")
|
||||
|
||||
if self.guild_config:
|
||||
await self.guild_config.create_guild(guild)
|
||||
|
||||
async def on_guild_remove(self, guild: discord.Guild) -> None:
|
||||
"""Called when the bot is removed from a guild."""
|
||||
logger.info(f"Removed from guild: {guild.name} (ID: {guild.id})")
|
||||
1
src/guardden/cogs/__init__.py
Normal file
1
src/guardden/cogs/__init__.py
Normal file
@@ -0,0 +1 @@
|
||||
"""Discord cogs for GuardDen."""
|
||||
255
src/guardden/cogs/admin.py
Normal file
255
src/guardden/cogs/admin.py
Normal file
@@ -0,0 +1,255 @@
|
||||
"""Admin commands for bot configuration."""
|
||||
|
||||
import logging
|
||||
from typing import Literal
|
||||
|
||||
import discord
|
||||
from discord.ext import commands
|
||||
|
||||
from guardden.bot import GuardDen
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class Admin(commands.Cog):
|
||||
"""Administrative commands for bot configuration."""
|
||||
|
||||
def __init__(self, bot: GuardDen) -> None:
|
||||
self.bot = bot
|
||||
|
||||
async def cog_check(self, ctx: commands.Context) -> bool:
|
||||
"""Ensure only administrators can use these commands."""
|
||||
if not ctx.guild:
|
||||
return False
|
||||
return ctx.author.guild_permissions.administrator
|
||||
|
||||
@commands.group(name="config", invoke_without_command=True)
|
||||
@commands.guild_only()
|
||||
async def config(self, ctx: commands.Context) -> None:
|
||||
"""View or modify bot configuration."""
|
||||
config = await self.bot.guild_config.get_config(ctx.guild.id)
|
||||
|
||||
if not config:
|
||||
await ctx.send("No configuration found. Run a config command to initialize.")
|
||||
return
|
||||
|
||||
embed = discord.Embed(
|
||||
title=f"Configuration for {ctx.guild.name}",
|
||||
color=discord.Color.blue(),
|
||||
)
|
||||
|
||||
# General settings
|
||||
embed.add_field(name="Prefix", value=f"`{config.prefix}`", inline=True)
|
||||
embed.add_field(name="Locale", value=config.locale, inline=True)
|
||||
embed.add_field(name="\u200b", value="\u200b", inline=True)
|
||||
|
||||
# Channels
|
||||
log_ch = ctx.guild.get_channel(config.log_channel_id) if config.log_channel_id else None
|
||||
mod_log_ch = (
|
||||
ctx.guild.get_channel(config.mod_log_channel_id) if config.mod_log_channel_id else None
|
||||
)
|
||||
welcome_ch = (
|
||||
ctx.guild.get_channel(config.welcome_channel_id) if config.welcome_channel_id else None
|
||||
)
|
||||
|
||||
embed.add_field(
|
||||
name="Log Channel", value=log_ch.mention if log_ch else "Not set", inline=True
|
||||
)
|
||||
embed.add_field(
|
||||
name="Mod Log Channel",
|
||||
value=mod_log_ch.mention if mod_log_ch else "Not set",
|
||||
inline=True,
|
||||
)
|
||||
embed.add_field(
|
||||
name="Welcome Channel",
|
||||
value=welcome_ch.mention if welcome_ch else "Not set",
|
||||
inline=True,
|
||||
)
|
||||
|
||||
# Features
|
||||
features = []
|
||||
if config.automod_enabled:
|
||||
features.append("AutoMod")
|
||||
if config.anti_spam_enabled:
|
||||
features.append("Anti-Spam")
|
||||
if config.link_filter_enabled:
|
||||
features.append("Link Filter")
|
||||
if config.ai_moderation_enabled:
|
||||
features.append("AI Moderation")
|
||||
if config.verification_enabled:
|
||||
features.append("Verification")
|
||||
|
||||
embed.add_field(
|
||||
name="Enabled Features",
|
||||
value=", ".join(features) if features else "None",
|
||||
inline=False,
|
||||
)
|
||||
|
||||
await ctx.send(embed=embed)
|
||||
|
||||
@config.command(name="prefix")
|
||||
@commands.guild_only()
|
||||
async def config_prefix(self, ctx: commands.Context, prefix: str) -> None:
|
||||
"""Set the command prefix for this server."""
|
||||
if not prefix or not prefix.strip():
|
||||
await ctx.send("Prefix cannot be empty or whitespace only.")
|
||||
return
|
||||
|
||||
if len(prefix) > 10:
|
||||
await ctx.send("Prefix must be 10 characters or less.")
|
||||
return
|
||||
|
||||
await self.bot.guild_config.update_settings(ctx.guild.id, prefix=prefix)
|
||||
await ctx.send(f"Command prefix set to `{prefix}`")
|
||||
|
||||
@config.command(name="logchannel")
|
||||
@commands.guild_only()
|
||||
async def config_log_channel(
|
||||
self, ctx: commands.Context, channel: discord.TextChannel | None = None
|
||||
) -> None:
|
||||
"""Set the channel for general event logs."""
|
||||
channel_id = channel.id if channel else None
|
||||
await self.bot.guild_config.update_settings(ctx.guild.id, log_channel_id=channel_id)
|
||||
|
||||
if channel:
|
||||
await ctx.send(f"Log channel set to {channel.mention}")
|
||||
else:
|
||||
await ctx.send("Log channel has been disabled.")
|
||||
|
||||
@config.command(name="modlogchannel")
|
||||
@commands.guild_only()
|
||||
async def config_mod_log_channel(
|
||||
self, ctx: commands.Context, channel: discord.TextChannel | None = None
|
||||
) -> None:
|
||||
"""Set the channel for moderation action logs."""
|
||||
channel_id = channel.id if channel else None
|
||||
await self.bot.guild_config.update_settings(ctx.guild.id, mod_log_channel_id=channel_id)
|
||||
|
||||
if channel:
|
||||
await ctx.send(f"Moderation log channel set to {channel.mention}")
|
||||
else:
|
||||
await ctx.send("Moderation log channel has been disabled.")
|
||||
|
||||
@config.command(name="welcomechannel")
|
||||
@commands.guild_only()
|
||||
async def config_welcome_channel(
|
||||
self, ctx: commands.Context, channel: discord.TextChannel | None = None
|
||||
) -> None:
|
||||
"""Set the welcome channel for new members."""
|
||||
channel_id = channel.id if channel else None
|
||||
await self.bot.guild_config.update_settings(ctx.guild.id, welcome_channel_id=channel_id)
|
||||
|
||||
if channel:
|
||||
await ctx.send(f"Welcome channel set to {channel.mention}")
|
||||
else:
|
||||
await ctx.send("Welcome channel has been disabled.")
|
||||
|
||||
@config.command(name="muterole")
|
||||
@commands.guild_only()
|
||||
async def config_mute_role(
|
||||
self, ctx: commands.Context, role: discord.Role | None = None
|
||||
) -> None:
|
||||
"""Set the role to assign when muting members."""
|
||||
role_id = role.id if role else None
|
||||
await self.bot.guild_config.update_settings(ctx.guild.id, mute_role_id=role_id)
|
||||
|
||||
if role:
|
||||
await ctx.send(f"Mute role set to {role.mention}")
|
||||
else:
|
||||
await ctx.send("Mute role has been cleared.")
|
||||
|
||||
@config.command(name="automod")
|
||||
@commands.guild_only()
|
||||
async def config_automod(self, ctx: commands.Context, enabled: bool) -> None:
|
||||
"""Enable or disable automod features."""
|
||||
await self.bot.guild_config.update_settings(ctx.guild.id, automod_enabled=enabled)
|
||||
status = "enabled" if enabled else "disabled"
|
||||
await ctx.send(f"AutoMod has been {status}.")
|
||||
|
||||
@config.command(name="antispam")
|
||||
@commands.guild_only()
|
||||
async def config_antispam(self, ctx: commands.Context, enabled: bool) -> None:
|
||||
"""Enable or disable anti-spam protection."""
|
||||
await self.bot.guild_config.update_settings(ctx.guild.id, anti_spam_enabled=enabled)
|
||||
status = "enabled" if enabled else "disabled"
|
||||
await ctx.send(f"Anti-spam has been {status}.")
|
||||
|
||||
@config.command(name="linkfilter")
|
||||
@commands.guild_only()
|
||||
async def config_linkfilter(self, ctx: commands.Context, enabled: bool) -> None:
|
||||
"""Enable or disable link filtering."""
|
||||
await self.bot.guild_config.update_settings(ctx.guild.id, link_filter_enabled=enabled)
|
||||
status = "enabled" if enabled else "disabled"
|
||||
await ctx.send(f"Link filter has been {status}.")
|
||||
|
||||
@commands.group(name="bannedwords", aliases=["bw"], invoke_without_command=True)
|
||||
@commands.guild_only()
|
||||
async def banned_words(self, ctx: commands.Context) -> None:
|
||||
"""Manage banned words list."""
|
||||
words = await self.bot.guild_config.get_banned_words(ctx.guild.id)
|
||||
|
||||
if not words:
|
||||
await ctx.send("No banned words configured.")
|
||||
return
|
||||
|
||||
embed = discord.Embed(
|
||||
title="Banned Words",
|
||||
color=discord.Color.red(),
|
||||
)
|
||||
|
||||
for word in words[:25]: # Discord embed limit
|
||||
word_type = "Regex" if word.is_regex else "Text"
|
||||
embed.add_field(
|
||||
name=f"#{word.id}: {word.pattern[:30]}",
|
||||
value=f"Type: {word_type} | Action: {word.action}",
|
||||
inline=True,
|
||||
)
|
||||
|
||||
if len(words) > 25:
|
||||
embed.set_footer(text=f"Showing 25 of {len(words)} banned words")
|
||||
|
||||
await ctx.send(embed=embed)
|
||||
|
||||
@banned_words.command(name="add")
|
||||
@commands.guild_only()
|
||||
async def banned_words_add(
|
||||
self,
|
||||
ctx: commands.Context,
|
||||
pattern: str,
|
||||
action: Literal["delete", "warn", "strike"] = "delete",
|
||||
is_regex: bool = False,
|
||||
) -> None:
|
||||
"""Add a banned word or pattern."""
|
||||
word = await self.bot.guild_config.add_banned_word(
|
||||
guild_id=ctx.guild.id,
|
||||
pattern=pattern,
|
||||
added_by=ctx.author.id,
|
||||
is_regex=is_regex,
|
||||
action=action,
|
||||
)
|
||||
|
||||
word_type = "regex pattern" if is_regex else "word"
|
||||
await ctx.send(f"Added banned {word_type}: `{pattern}` (ID: {word.id}, Action: {action})")
|
||||
|
||||
@banned_words.command(name="remove", aliases=["delete"])
|
||||
@commands.guild_only()
|
||||
async def banned_words_remove(self, ctx: commands.Context, word_id: int) -> None:
|
||||
"""Remove a banned word by ID."""
|
||||
success = await self.bot.guild_config.remove_banned_word(ctx.guild.id, word_id)
|
||||
|
||||
if success:
|
||||
await ctx.send(f"Removed banned word #{word_id}")
|
||||
else:
|
||||
await ctx.send(f"Banned word #{word_id} not found.")
|
||||
|
||||
@commands.command(name="sync")
|
||||
@commands.is_owner()
|
||||
async def sync_commands(self, ctx: commands.Context) -> None:
|
||||
"""Sync slash commands (bot owner only)."""
|
||||
await self.bot.tree.sync()
|
||||
await ctx.send("Slash commands synced.")
|
||||
|
||||
|
||||
async def setup(bot: GuardDen) -> None:
|
||||
"""Load the Admin cog."""
|
||||
await bot.add_cog(Admin(bot))
|
||||
366
src/guardden/cogs/ai_moderation.py
Normal file
366
src/guardden/cogs/ai_moderation.py
Normal file
@@ -0,0 +1,366 @@
|
||||
"""AI-powered moderation cog."""
|
||||
|
||||
import logging
|
||||
import re
|
||||
from collections import deque
|
||||
from datetime import datetime, timedelta, timezone
|
||||
|
||||
import discord
|
||||
from discord.ext import commands
|
||||
|
||||
from guardden.bot import GuardDen
|
||||
from guardden.services.ai.base import ContentCategory, ModerationResult
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# URL pattern for extraction
|
||||
URL_PATTERN = re.compile(
|
||||
r"https?://(?:[-\w.]|(?:%[\da-fA-F]{2}))+[^\s]*",
|
||||
re.IGNORECASE,
|
||||
)
|
||||
|
||||
|
||||
class AIModeration(commands.Cog):
|
||||
"""AI-powered content moderation."""
|
||||
|
||||
def __init__(self, bot: GuardDen) -> None:
|
||||
self.bot = bot
|
||||
# Track recently analyzed messages to avoid duplicates (deque auto-removes oldest)
|
||||
self._analyzed_messages: deque[int] = deque(maxlen=1000)
|
||||
|
||||
def _should_analyze(self, message: discord.Message) -> bool:
|
||||
"""Determine if a message should be analyzed by AI."""
|
||||
# Skip if already analyzed
|
||||
if message.id in self._analyzed_messages:
|
||||
return False
|
||||
|
||||
# Skip short messages
|
||||
if len(message.content) < 20 and not message.attachments:
|
||||
return False
|
||||
|
||||
# Skip messages from bots
|
||||
if message.author.bot:
|
||||
return False
|
||||
|
||||
return True
|
||||
|
||||
def _track_message(self, message_id: int) -> None:
|
||||
"""Track that a message has been analyzed."""
|
||||
self._analyzed_messages.append(message_id)
|
||||
|
||||
async def _handle_ai_result(
|
||||
self,
|
||||
message: discord.Message,
|
||||
result: ModerationResult,
|
||||
analysis_type: str,
|
||||
) -> None:
|
||||
"""Handle the result of AI analysis."""
|
||||
if not result.is_flagged:
|
||||
return
|
||||
|
||||
config = await self.bot.guild_config.get_config(message.guild.id)
|
||||
if not config:
|
||||
return
|
||||
|
||||
# Check if severity meets threshold based on sensitivity
|
||||
# Higher sensitivity = lower threshold needed to trigger
|
||||
threshold = 100 - config.ai_sensitivity # e.g., sensitivity 70 = threshold 30
|
||||
if result.severity < threshold:
|
||||
logger.debug(
|
||||
f"AI flagged content but below threshold: "
|
||||
f"severity={result.severity}, threshold={threshold}"
|
||||
)
|
||||
return
|
||||
|
||||
# Determine action based on suggested action and severity
|
||||
should_delete = result.suggested_action in ("delete", "timeout", "ban")
|
||||
should_timeout = result.suggested_action in ("timeout", "ban") and result.severity > 70
|
||||
|
||||
# Delete message if needed
|
||||
if should_delete:
|
||||
try:
|
||||
await message.delete()
|
||||
except discord.Forbidden:
|
||||
logger.warning(f"Cannot delete message: missing permissions")
|
||||
except discord.NotFound:
|
||||
pass
|
||||
|
||||
# Timeout user for severe violations
|
||||
if should_timeout and isinstance(message.author, discord.Member):
|
||||
timeout_duration = 300 if result.severity < 90 else 3600 # 5 min or 1 hour
|
||||
try:
|
||||
await message.author.timeout(
|
||||
timedelta(seconds=timeout_duration),
|
||||
reason=f"AI Moderation: {result.explanation[:100]}",
|
||||
)
|
||||
except discord.Forbidden:
|
||||
pass
|
||||
|
||||
# Log to mod channel
|
||||
await self._log_ai_action(message, result, analysis_type)
|
||||
|
||||
# Notify user
|
||||
try:
|
||||
embed = discord.Embed(
|
||||
title=f"Message Flagged in {message.guild.name}",
|
||||
description=result.explanation,
|
||||
color=discord.Color.red(),
|
||||
timestamp=datetime.now(timezone.utc),
|
||||
)
|
||||
embed.add_field(
|
||||
name="Categories",
|
||||
value=", ".join(cat.value for cat in result.categories) or "Unknown",
|
||||
)
|
||||
if should_timeout:
|
||||
embed.add_field(name="Action", value="You have been timed out")
|
||||
await message.author.send(embed=embed)
|
||||
except discord.Forbidden:
|
||||
pass
|
||||
|
||||
async def _log_ai_action(
|
||||
self,
|
||||
message: discord.Message,
|
||||
result: ModerationResult,
|
||||
analysis_type: str,
|
||||
) -> None:
|
||||
"""Log an AI moderation action."""
|
||||
config = await self.bot.guild_config.get_config(message.guild.id)
|
||||
if not config or not config.mod_log_channel_id:
|
||||
return
|
||||
|
||||
channel = message.guild.get_channel(config.mod_log_channel_id)
|
||||
if not channel or not isinstance(channel, discord.TextChannel):
|
||||
return
|
||||
|
||||
embed = discord.Embed(
|
||||
title=f"AI Moderation - {analysis_type}",
|
||||
color=discord.Color.red(),
|
||||
timestamp=datetime.now(timezone.utc),
|
||||
)
|
||||
embed.set_author(
|
||||
name=str(message.author),
|
||||
icon_url=message.author.display_avatar.url,
|
||||
)
|
||||
|
||||
embed.add_field(name="Confidence", value=f"{result.confidence:.0%}", inline=True)
|
||||
embed.add_field(name="Severity", value=f"{result.severity}/100", inline=True)
|
||||
embed.add_field(name="Action", value=result.suggested_action, inline=True)
|
||||
|
||||
categories = ", ".join(cat.value for cat in result.categories)
|
||||
embed.add_field(name="Categories", value=categories or "None", inline=False)
|
||||
embed.add_field(name="Explanation", value=result.explanation[:500], inline=False)
|
||||
|
||||
if message.content:
|
||||
content = (
|
||||
message.content[:500] + "..." if len(message.content) > 500 else message.content
|
||||
)
|
||||
embed.add_field(name="Content", value=f"```{content}```", inline=False)
|
||||
|
||||
embed.set_footer(text=f"User ID: {message.author.id} | Channel: #{message.channel.name}")
|
||||
|
||||
await channel.send(embed=embed)
|
||||
|
||||
@commands.Cog.listener()
|
||||
async def on_message(self, message: discord.Message) -> None:
|
||||
"""Analyze messages with AI moderation."""
|
||||
if not message.guild:
|
||||
return
|
||||
|
||||
# Check if AI moderation is enabled for this guild
|
||||
config = await self.bot.guild_config.get_config(message.guild.id)
|
||||
if not config or not config.ai_moderation_enabled:
|
||||
return
|
||||
|
||||
# Skip users with manage_messages permission
|
||||
if isinstance(message.author, discord.Member):
|
||||
if message.author.guild_permissions.manage_messages:
|
||||
return
|
||||
|
||||
if not self._should_analyze(message):
|
||||
return
|
||||
|
||||
self._track_message(message.id)
|
||||
|
||||
# Analyze text content
|
||||
if message.content and len(message.content) >= 20:
|
||||
result = await self.bot.ai_provider.moderate_text(
|
||||
content=message.content,
|
||||
context=f"Discord server: {message.guild.name}, channel: {message.channel.name}",
|
||||
sensitivity=config.ai_sensitivity,
|
||||
)
|
||||
|
||||
if result.is_flagged:
|
||||
await self._handle_ai_result(message, result, "Text Analysis")
|
||||
return # Don't continue if already flagged
|
||||
|
||||
# Analyze images if NSFW detection is enabled (limit to 3 per message)
|
||||
if config.nsfw_detection_enabled and message.attachments:
|
||||
images_analyzed = 0
|
||||
for attachment in message.attachments:
|
||||
if images_analyzed >= 3:
|
||||
break
|
||||
if attachment.content_type and attachment.content_type.startswith("image/"):
|
||||
images_analyzed += 1
|
||||
image_result = await self.bot.ai_provider.analyze_image(
|
||||
image_url=attachment.url,
|
||||
sensitivity=config.ai_sensitivity,
|
||||
)
|
||||
|
||||
if (
|
||||
image_result.is_nsfw
|
||||
or image_result.is_violent
|
||||
or image_result.is_disturbing
|
||||
):
|
||||
# Convert to ModerationResult format
|
||||
categories = []
|
||||
if image_result.is_nsfw:
|
||||
categories.append(ContentCategory.SEXUAL)
|
||||
if image_result.is_violent:
|
||||
categories.append(ContentCategory.VIOLENCE)
|
||||
|
||||
result = ModerationResult(
|
||||
is_flagged=True,
|
||||
confidence=image_result.confidence,
|
||||
categories=categories,
|
||||
explanation=image_result.description,
|
||||
suggested_action="delete",
|
||||
)
|
||||
await self._handle_ai_result(message, result, "Image Analysis")
|
||||
return
|
||||
|
||||
# Analyze URLs for phishing
|
||||
urls = URL_PATTERN.findall(message.content)
|
||||
for url in urls[:3]: # Limit to first 3 URLs
|
||||
phishing_result = await self.bot.ai_provider.analyze_phishing(
|
||||
url=url,
|
||||
message_content=message.content,
|
||||
)
|
||||
|
||||
if phishing_result.is_phishing and phishing_result.confidence > 0.7:
|
||||
result = ModerationResult(
|
||||
is_flagged=True,
|
||||
confidence=phishing_result.confidence,
|
||||
categories=[ContentCategory.SCAM],
|
||||
explanation=phishing_result.explanation,
|
||||
suggested_action="delete",
|
||||
)
|
||||
await self._handle_ai_result(message, result, "Phishing Detection")
|
||||
return
|
||||
|
||||
@commands.group(name="ai", invoke_without_command=True)
|
||||
@commands.has_permissions(administrator=True)
|
||||
@commands.guild_only()
|
||||
async def ai_cmd(self, ctx: commands.Context) -> None:
|
||||
"""View AI moderation settings."""
|
||||
config = await self.bot.guild_config.get_config(ctx.guild.id)
|
||||
|
||||
embed = discord.Embed(
|
||||
title="AI Moderation Settings",
|
||||
color=discord.Color.blue(),
|
||||
)
|
||||
|
||||
embed.add_field(
|
||||
name="AI Moderation",
|
||||
value="✅ Enabled" if config and config.ai_moderation_enabled else "❌ Disabled",
|
||||
inline=True,
|
||||
)
|
||||
embed.add_field(
|
||||
name="NSFW Detection",
|
||||
value="✅ Enabled" if config and config.nsfw_detection_enabled else "❌ Disabled",
|
||||
inline=True,
|
||||
)
|
||||
embed.add_field(
|
||||
name="Sensitivity",
|
||||
value=f"{config.ai_sensitivity}/100" if config else "50/100",
|
||||
inline=True,
|
||||
)
|
||||
embed.add_field(
|
||||
name="AI Provider",
|
||||
value=self.bot.settings.ai_provider.capitalize(),
|
||||
inline=True,
|
||||
)
|
||||
|
||||
await ctx.send(embed=embed)
|
||||
|
||||
@ai_cmd.command(name="enable")
|
||||
@commands.has_permissions(administrator=True)
|
||||
@commands.guild_only()
|
||||
async def ai_enable(self, ctx: commands.Context) -> None:
|
||||
"""Enable AI moderation."""
|
||||
if self.bot.settings.ai_provider == "none":
|
||||
await ctx.send(
|
||||
"AI moderation is not configured. Set `GUARDDEN_AI_PROVIDER` and API key."
|
||||
)
|
||||
return
|
||||
|
||||
await self.bot.guild_config.update_settings(ctx.guild.id, ai_moderation_enabled=True)
|
||||
await ctx.send("✅ AI moderation enabled.")
|
||||
|
||||
@ai_cmd.command(name="disable")
|
||||
@commands.has_permissions(administrator=True)
|
||||
@commands.guild_only()
|
||||
async def ai_disable(self, ctx: commands.Context) -> None:
|
||||
"""Disable AI moderation."""
|
||||
await self.bot.guild_config.update_settings(ctx.guild.id, ai_moderation_enabled=False)
|
||||
await ctx.send("❌ AI moderation disabled.")
|
||||
|
||||
@ai_cmd.command(name="sensitivity")
|
||||
@commands.has_permissions(administrator=True)
|
||||
@commands.guild_only()
|
||||
async def ai_sensitivity(self, ctx: commands.Context, level: int) -> None:
|
||||
"""Set AI sensitivity level (0-100). Higher = more strict."""
|
||||
if not 0 <= level <= 100:
|
||||
await ctx.send("Sensitivity must be between 0 and 100.")
|
||||
return
|
||||
|
||||
await self.bot.guild_config.update_settings(ctx.guild.id, ai_sensitivity=level)
|
||||
await ctx.send(f"AI sensitivity set to {level}/100.")
|
||||
|
||||
@ai_cmd.command(name="nsfw")
|
||||
@commands.has_permissions(administrator=True)
|
||||
@commands.guild_only()
|
||||
async def ai_nsfw(self, ctx: commands.Context, enabled: bool) -> None:
|
||||
"""Enable or disable NSFW image detection."""
|
||||
await self.bot.guild_config.update_settings(ctx.guild.id, nsfw_detection_enabled=enabled)
|
||||
status = "enabled" if enabled else "disabled"
|
||||
await ctx.send(f"NSFW detection {status}.")
|
||||
|
||||
@ai_cmd.command(name="analyze")
|
||||
@commands.has_permissions(administrator=True)
|
||||
@commands.guild_only()
|
||||
async def ai_analyze(self, ctx: commands.Context, *, text: str) -> None:
|
||||
"""Test AI analysis on text (does not take action)."""
|
||||
if self.bot.settings.ai_provider == "none":
|
||||
await ctx.send("AI moderation is not configured.")
|
||||
return
|
||||
|
||||
async with ctx.typing():
|
||||
result = await self.bot.ai_provider.moderate_text(
|
||||
content=text,
|
||||
context=f"Test analysis in {ctx.guild.name}",
|
||||
sensitivity=50,
|
||||
)
|
||||
|
||||
embed = discord.Embed(
|
||||
title="AI Analysis Result",
|
||||
color=discord.Color.red() if result.is_flagged else discord.Color.green(),
|
||||
)
|
||||
|
||||
embed.add_field(name="Flagged", value="Yes" if result.is_flagged else "No", inline=True)
|
||||
embed.add_field(name="Confidence", value=f"{result.confidence:.0%}", inline=True)
|
||||
embed.add_field(name="Severity", value=f"{result.severity}/100", inline=True)
|
||||
embed.add_field(name="Suggested Action", value=result.suggested_action, inline=True)
|
||||
|
||||
if result.categories:
|
||||
categories = ", ".join(cat.value for cat in result.categories)
|
||||
embed.add_field(name="Categories", value=categories, inline=False)
|
||||
|
||||
if result.explanation:
|
||||
embed.add_field(name="Explanation", value=result.explanation[:1000], inline=False)
|
||||
|
||||
await ctx.send(embed=embed)
|
||||
|
||||
|
||||
async def setup(bot: GuardDen) -> None:
|
||||
"""Load the AI Moderation cog."""
|
||||
await bot.add_cog(AIModeration(bot))
|
||||
267
src/guardden/cogs/automod.py
Normal file
267
src/guardden/cogs/automod.py
Normal file
@@ -0,0 +1,267 @@
|
||||
"""Automod cog for automatic content moderation."""
|
||||
|
||||
import logging
|
||||
from datetime import datetime, timedelta, timezone
|
||||
|
||||
import discord
|
||||
from discord.ext import commands
|
||||
|
||||
from guardden.bot import GuardDen
|
||||
from guardden.services.automod import AutomodResult, AutomodService
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class Automod(commands.Cog):
|
||||
"""Automatic content moderation."""
|
||||
|
||||
def __init__(self, bot: GuardDen) -> None:
|
||||
self.bot = bot
|
||||
self.automod = AutomodService()
|
||||
|
||||
async def _handle_violation(
|
||||
self,
|
||||
message: discord.Message,
|
||||
result: AutomodResult,
|
||||
) -> None:
|
||||
"""Handle an automod violation."""
|
||||
# Delete the message
|
||||
if result.should_delete:
|
||||
try:
|
||||
await message.delete()
|
||||
except discord.Forbidden:
|
||||
logger.warning(f"Cannot delete message in {message.guild}: missing permissions")
|
||||
except discord.NotFound:
|
||||
pass # Already deleted
|
||||
|
||||
# Apply timeout
|
||||
if result.should_timeout and result.timeout_duration > 0:
|
||||
try:
|
||||
await message.author.timeout(
|
||||
timedelta(seconds=result.timeout_duration),
|
||||
reason=f"Automod: {result.reason}",
|
||||
)
|
||||
except discord.Forbidden:
|
||||
logger.warning(f"Cannot timeout {message.author}: missing permissions")
|
||||
|
||||
# Log the action
|
||||
await self._log_automod_action(message, result)
|
||||
|
||||
# Notify the user via DM
|
||||
try:
|
||||
embed = discord.Embed(
|
||||
title=f"Message Removed in {message.guild.name}",
|
||||
description=result.reason,
|
||||
color=discord.Color.orange(),
|
||||
timestamp=datetime.now(timezone.utc),
|
||||
)
|
||||
if result.should_timeout:
|
||||
embed.add_field(
|
||||
name="Timeout",
|
||||
value=f"You have been timed out for {result.timeout_duration} seconds.",
|
||||
)
|
||||
await message.author.send(embed=embed)
|
||||
except discord.Forbidden:
|
||||
pass # User has DMs disabled
|
||||
|
||||
async def _log_automod_action(
|
||||
self,
|
||||
message: discord.Message,
|
||||
result: AutomodResult,
|
||||
) -> None:
|
||||
"""Log an automod action to the mod log channel."""
|
||||
config = await self.bot.guild_config.get_config(message.guild.id)
|
||||
if not config or not config.mod_log_channel_id:
|
||||
return
|
||||
|
||||
channel = message.guild.get_channel(config.mod_log_channel_id)
|
||||
if not channel or not isinstance(channel, discord.TextChannel):
|
||||
return
|
||||
|
||||
embed = discord.Embed(
|
||||
title="Automod Action",
|
||||
color=discord.Color.orange(),
|
||||
timestamp=datetime.now(timezone.utc),
|
||||
)
|
||||
embed.set_author(
|
||||
name=str(message.author),
|
||||
icon_url=message.author.display_avatar.url,
|
||||
)
|
||||
embed.add_field(name="Filter", value=result.matched_filter, inline=True)
|
||||
embed.add_field(name="Channel", value=message.channel.mention, inline=True)
|
||||
embed.add_field(name="Reason", value=result.reason, inline=False)
|
||||
|
||||
if message.content:
|
||||
content = (
|
||||
message.content[:500] + "..." if len(message.content) > 500 else message.content
|
||||
)
|
||||
embed.add_field(name="Message Content", value=f"```{content}```", inline=False)
|
||||
|
||||
actions = []
|
||||
if result.should_delete:
|
||||
actions.append("Message deleted")
|
||||
if result.should_warn:
|
||||
actions.append("User warned")
|
||||
if result.should_strike:
|
||||
actions.append("Strike added")
|
||||
if result.should_timeout:
|
||||
actions.append(f"Timeout ({result.timeout_duration}s)")
|
||||
|
||||
embed.add_field(name="Actions Taken", value=", ".join(actions) or "None", inline=False)
|
||||
embed.set_footer(text=f"User ID: {message.author.id}")
|
||||
|
||||
await channel.send(embed=embed)
|
||||
|
||||
@commands.Cog.listener()
|
||||
async def on_message(self, message: discord.Message) -> None:
|
||||
"""Check all messages for automod violations."""
|
||||
# Ignore DMs, bots, and empty messages
|
||||
if not message.guild or message.author.bot or not message.content:
|
||||
return
|
||||
|
||||
# Ignore users with manage_messages permission
|
||||
if isinstance(message.author, discord.Member):
|
||||
if message.author.guild_permissions.manage_messages:
|
||||
return
|
||||
|
||||
# Get guild config
|
||||
config = await self.bot.guild_config.get_config(message.guild.id)
|
||||
if not config or not config.automod_enabled:
|
||||
return
|
||||
|
||||
result: AutomodResult | None = None
|
||||
|
||||
# Check banned words
|
||||
banned_words = await self.bot.guild_config.get_banned_words(message.guild.id)
|
||||
if banned_words:
|
||||
result = self.automod.check_banned_words(message.content, banned_words)
|
||||
|
||||
# Check scam links (if link filter enabled)
|
||||
if not result and config.link_filter_enabled:
|
||||
result = self.automod.check_scam_links(message.content)
|
||||
|
||||
# Check spam
|
||||
if not result and config.anti_spam_enabled:
|
||||
result = self.automod.check_spam(message, anti_spam_enabled=True)
|
||||
|
||||
# Check invite links (if link filter enabled)
|
||||
if not result and config.link_filter_enabled:
|
||||
result = self.automod.check_invite_links(message.content, allow_invites=False)
|
||||
|
||||
# Handle violation if found
|
||||
if result:
|
||||
logger.info(
|
||||
f"Automod triggered in {message.guild.name}: "
|
||||
f"{result.matched_filter} by {message.author}"
|
||||
)
|
||||
await self._handle_violation(message, result)
|
||||
|
||||
@commands.Cog.listener()
|
||||
async def on_message_edit(self, before: discord.Message, after: discord.Message) -> None:
|
||||
"""Check edited messages for automod violations."""
|
||||
# Only check if content changed
|
||||
if before.content == after.content:
|
||||
return
|
||||
|
||||
# Reuse on_message logic
|
||||
await self.on_message(after)
|
||||
|
||||
@commands.group(name="automod", invoke_without_command=True)
|
||||
@commands.has_permissions(administrator=True)
|
||||
@commands.guild_only()
|
||||
async def automod_cmd(self, ctx: commands.Context) -> None:
|
||||
"""View automod status and configuration."""
|
||||
config = await self.bot.guild_config.get_config(ctx.guild.id)
|
||||
|
||||
embed = discord.Embed(
|
||||
title="Automod Configuration",
|
||||
color=discord.Color.blue(),
|
||||
)
|
||||
|
||||
embed.add_field(
|
||||
name="Automod Enabled",
|
||||
value="✅ Yes" if config and config.automod_enabled else "❌ No",
|
||||
inline=True,
|
||||
)
|
||||
embed.add_field(
|
||||
name="Anti-Spam",
|
||||
value="✅ Yes" if config and config.anti_spam_enabled else "❌ No",
|
||||
inline=True,
|
||||
)
|
||||
embed.add_field(
|
||||
name="Link Filter",
|
||||
value="✅ Yes" if config and config.link_filter_enabled else "❌ No",
|
||||
inline=True,
|
||||
)
|
||||
|
||||
# Show thresholds
|
||||
embed.add_field(
|
||||
name="Rate Limit",
|
||||
value=f"{self.automod.message_rate_limit} msgs / {self.automod.message_rate_window}s",
|
||||
inline=True,
|
||||
)
|
||||
embed.add_field(
|
||||
name="Duplicate Threshold",
|
||||
value=f"{self.automod.duplicate_threshold} same messages",
|
||||
inline=True,
|
||||
)
|
||||
embed.add_field(
|
||||
name="Mention Limit",
|
||||
value=f"{self.automod.mention_limit} per message",
|
||||
inline=True,
|
||||
)
|
||||
|
||||
banned_words = await self.bot.guild_config.get_banned_words(ctx.guild.id)
|
||||
embed.add_field(
|
||||
name="Banned Words",
|
||||
value=f"{len(banned_words)} configured",
|
||||
inline=True,
|
||||
)
|
||||
|
||||
await ctx.send(embed=embed)
|
||||
|
||||
@automod_cmd.command(name="test")
|
||||
@commands.has_permissions(administrator=True)
|
||||
@commands.guild_only()
|
||||
async def automod_test(self, ctx: commands.Context, *, text: str) -> None:
|
||||
"""Test a message against automod filters (does not take action)."""
|
||||
config = await self.bot.guild_config.get_config(ctx.guild.id)
|
||||
results = []
|
||||
|
||||
# Check banned words
|
||||
banned_words = await self.bot.guild_config.get_banned_words(ctx.guild.id)
|
||||
result = self.automod.check_banned_words(text, banned_words)
|
||||
if result:
|
||||
results.append(f"**Banned Words**: {result.reason}")
|
||||
|
||||
# Check scam links
|
||||
result = self.automod.check_scam_links(text)
|
||||
if result:
|
||||
results.append(f"**Scam Detection**: {result.reason}")
|
||||
|
||||
# Check invite links
|
||||
result = self.automod.check_invite_links(text, allow_invites=False)
|
||||
if result:
|
||||
results.append(f"**Invite Links**: {result.reason}")
|
||||
|
||||
# Check caps
|
||||
result = self.automod.check_all_caps(text)
|
||||
if result:
|
||||
results.append(f"**Excessive Caps**: {result.reason}")
|
||||
|
||||
embed = discord.Embed(
|
||||
title="Automod Test Results",
|
||||
color=discord.Color.red() if results else discord.Color.green(),
|
||||
)
|
||||
|
||||
if results:
|
||||
embed.description = "\n".join(results)
|
||||
else:
|
||||
embed.description = "✅ No violations detected"
|
||||
|
||||
await ctx.send(embed=embed)
|
||||
|
||||
|
||||
async def setup(bot: GuardDen) -> None:
|
||||
"""Load the Automod cog."""
|
||||
await bot.add_cog(Automod(bot))
|
||||
237
src/guardden/cogs/events.py
Normal file
237
src/guardden/cogs/events.py
Normal file
@@ -0,0 +1,237 @@
|
||||
"""Event handlers for logging and monitoring."""
|
||||
|
||||
import logging
|
||||
from datetime import datetime, timezone
|
||||
|
||||
import discord
|
||||
from discord.ext import commands
|
||||
|
||||
from guardden.bot import GuardDen
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class Events(commands.Cog):
|
||||
"""Handles Discord events for logging and monitoring."""
|
||||
|
||||
def __init__(self, bot: GuardDen) -> None:
|
||||
self.bot = bot
|
||||
|
||||
@commands.Cog.listener()
|
||||
async def on_member_join(self, member: discord.Member) -> None:
|
||||
"""Called when a member joins a guild."""
|
||||
logger.debug(f"Member joined: {member} in {member.guild}")
|
||||
|
||||
config = await self.bot.guild_config.get_config(member.guild.id)
|
||||
if not config or not config.log_channel_id:
|
||||
return
|
||||
|
||||
channel = member.guild.get_channel(config.log_channel_id)
|
||||
if not channel or not isinstance(channel, discord.TextChannel):
|
||||
return
|
||||
|
||||
embed = discord.Embed(
|
||||
title="Member Joined",
|
||||
description=f"{member.mention} ({member})",
|
||||
color=discord.Color.green(),
|
||||
timestamp=datetime.now(timezone.utc),
|
||||
)
|
||||
embed.set_thumbnail(url=member.display_avatar.url)
|
||||
embed.add_field(
|
||||
name="Account Created", value=discord.utils.format_dt(member.created_at, "R")
|
||||
)
|
||||
embed.add_field(name="Member ID", value=str(member.id))
|
||||
|
||||
await channel.send(embed=embed)
|
||||
|
||||
@commands.Cog.listener()
|
||||
async def on_member_remove(self, member: discord.Member) -> None:
|
||||
"""Called when a member leaves a guild."""
|
||||
logger.debug(f"Member left: {member} from {member.guild}")
|
||||
|
||||
config = await self.bot.guild_config.get_config(member.guild.id)
|
||||
if not config or not config.log_channel_id:
|
||||
return
|
||||
|
||||
channel = member.guild.get_channel(config.log_channel_id)
|
||||
if not channel or not isinstance(channel, discord.TextChannel):
|
||||
return
|
||||
|
||||
embed = discord.Embed(
|
||||
title="Member Left",
|
||||
description=f"{member} ({member.id})",
|
||||
color=discord.Color.orange(),
|
||||
timestamp=datetime.now(timezone.utc),
|
||||
)
|
||||
embed.set_thumbnail(url=member.display_avatar.url)
|
||||
|
||||
if member.joined_at:
|
||||
embed.add_field(name="Joined", value=discord.utils.format_dt(member.joined_at, "R"))
|
||||
|
||||
roles = [r.mention for r in member.roles if r != member.guild.default_role]
|
||||
if roles:
|
||||
embed.add_field(name="Roles", value=", ".join(roles[:10]), inline=False)
|
||||
|
||||
await channel.send(embed=embed)
|
||||
|
||||
@commands.Cog.listener()
|
||||
async def on_message_delete(self, message: discord.Message) -> None:
|
||||
"""Called when a message is deleted."""
|
||||
if message.author.bot or not message.guild:
|
||||
return
|
||||
|
||||
config = await self.bot.guild_config.get_config(message.guild.id)
|
||||
if not config or not config.log_channel_id:
|
||||
return
|
||||
|
||||
channel = message.guild.get_channel(config.log_channel_id)
|
||||
if not channel or not isinstance(channel, discord.TextChannel):
|
||||
return
|
||||
|
||||
embed = discord.Embed(
|
||||
title="Message Deleted",
|
||||
description=f"In {message.channel.mention}",
|
||||
color=discord.Color.red(),
|
||||
timestamp=datetime.now(timezone.utc),
|
||||
)
|
||||
embed.set_author(name=str(message.author), icon_url=message.author.display_avatar.url)
|
||||
|
||||
if message.content:
|
||||
content = message.content[:1024] if len(message.content) > 1024 else message.content
|
||||
embed.add_field(name="Content", value=content, inline=False)
|
||||
|
||||
if message.attachments:
|
||||
attachments = "\n".join(a.filename for a in message.attachments)
|
||||
embed.add_field(name="Attachments", value=attachments, inline=False)
|
||||
|
||||
embed.set_footer(text=f"Author ID: {message.author.id} | Message ID: {message.id}")
|
||||
|
||||
await channel.send(embed=embed)
|
||||
|
||||
@commands.Cog.listener()
|
||||
async def on_message_edit(self, before: discord.Message, after: discord.Message) -> None:
|
||||
"""Called when a message is edited."""
|
||||
if before.author.bot or not before.guild:
|
||||
return
|
||||
|
||||
if before.content == after.content:
|
||||
return
|
||||
|
||||
config = await self.bot.guild_config.get_config(before.guild.id)
|
||||
if not config or not config.log_channel_id:
|
||||
return
|
||||
|
||||
channel = before.guild.get_channel(config.log_channel_id)
|
||||
if not channel or not isinstance(channel, discord.TextChannel):
|
||||
return
|
||||
|
||||
embed = discord.Embed(
|
||||
title="Message Edited",
|
||||
description=f"In {before.channel.mention} | [Jump to message]({after.jump_url})",
|
||||
color=discord.Color.blue(),
|
||||
timestamp=datetime.now(timezone.utc),
|
||||
)
|
||||
embed.set_author(name=str(before.author), icon_url=before.author.display_avatar.url)
|
||||
|
||||
before_content = before.content[:1024] if len(before.content) > 1024 else before.content
|
||||
after_content = after.content[:1024] if len(after.content) > 1024 else after.content
|
||||
|
||||
embed.add_field(name="Before", value=before_content or "*empty*", inline=False)
|
||||
embed.add_field(name="After", value=after_content or "*empty*", inline=False)
|
||||
embed.set_footer(text=f"Author ID: {before.author.id}")
|
||||
|
||||
await channel.send(embed=embed)
|
||||
|
||||
@commands.Cog.listener()
|
||||
async def on_voice_state_update(
|
||||
self,
|
||||
member: discord.Member,
|
||||
before: discord.VoiceState,
|
||||
after: discord.VoiceState,
|
||||
) -> None:
|
||||
"""Called when a member's voice state changes."""
|
||||
if member.bot:
|
||||
return
|
||||
|
||||
config = await self.bot.guild_config.get_config(member.guild.id)
|
||||
if not config or not config.log_channel_id:
|
||||
return
|
||||
|
||||
channel = member.guild.get_channel(config.log_channel_id)
|
||||
if not channel or not isinstance(channel, discord.TextChannel):
|
||||
return
|
||||
|
||||
embed = None
|
||||
|
||||
if before.channel is None and after.channel is not None:
|
||||
embed = discord.Embed(
|
||||
title="Voice Channel Joined",
|
||||
description=f"{member.mention} joined {after.channel.mention}",
|
||||
color=discord.Color.green(),
|
||||
timestamp=datetime.now(timezone.utc),
|
||||
)
|
||||
elif before.channel is not None and after.channel is None:
|
||||
embed = discord.Embed(
|
||||
title="Voice Channel Left",
|
||||
description=f"{member.mention} left {before.channel.mention}",
|
||||
color=discord.Color.orange(),
|
||||
timestamp=datetime.now(timezone.utc),
|
||||
)
|
||||
elif before.channel != after.channel and before.channel and after.channel:
|
||||
embed = discord.Embed(
|
||||
title="Voice Channel Moved",
|
||||
description=f"{member.mention} moved from {before.channel.mention} to {after.channel.mention}",
|
||||
color=discord.Color.blue(),
|
||||
timestamp=datetime.now(timezone.utc),
|
||||
)
|
||||
|
||||
if embed:
|
||||
embed.set_author(name=str(member), icon_url=member.display_avatar.url)
|
||||
await channel.send(embed=embed)
|
||||
|
||||
@commands.Cog.listener()
|
||||
async def on_member_ban(self, guild: discord.Guild, user: discord.User) -> None:
|
||||
"""Called when a user is banned."""
|
||||
config = await self.bot.guild_config.get_config(guild.id)
|
||||
if not config or not config.mod_log_channel_id:
|
||||
return
|
||||
|
||||
channel = guild.get_channel(config.mod_log_channel_id)
|
||||
if not channel or not isinstance(channel, discord.TextChannel):
|
||||
return
|
||||
|
||||
embed = discord.Embed(
|
||||
title="Member Banned",
|
||||
description=f"{user} ({user.id})",
|
||||
color=discord.Color.dark_red(),
|
||||
timestamp=datetime.now(timezone.utc),
|
||||
)
|
||||
embed.set_thumbnail(url=user.display_avatar.url)
|
||||
|
||||
await channel.send(embed=embed)
|
||||
|
||||
@commands.Cog.listener()
|
||||
async def on_member_unban(self, guild: discord.Guild, user: discord.User) -> None:
|
||||
"""Called when a user is unbanned."""
|
||||
config = await self.bot.guild_config.get_config(guild.id)
|
||||
if not config or not config.mod_log_channel_id:
|
||||
return
|
||||
|
||||
channel = guild.get_channel(config.mod_log_channel_id)
|
||||
if not channel or not isinstance(channel, discord.TextChannel):
|
||||
return
|
||||
|
||||
embed = discord.Embed(
|
||||
title="Member Unbanned",
|
||||
description=f"{user} ({user.id})",
|
||||
color=discord.Color.green(),
|
||||
timestamp=datetime.now(timezone.utc),
|
||||
)
|
||||
embed.set_thumbnail(url=user.display_avatar.url)
|
||||
|
||||
await channel.send(embed=embed)
|
||||
|
||||
|
||||
async def setup(bot: GuardDen) -> None:
|
||||
"""Load the Events cog."""
|
||||
await bot.add_cog(Events(bot))
|
||||
466
src/guardden/cogs/moderation.py
Normal file
466
src/guardden/cogs/moderation.py
Normal file
@@ -0,0 +1,466 @@
|
||||
"""Moderation commands and automod features."""
|
||||
|
||||
import logging
|
||||
import re
|
||||
from datetime import datetime, timedelta, timezone
|
||||
|
||||
import discord
|
||||
from discord.ext import commands
|
||||
from sqlalchemy import func, select
|
||||
|
||||
from guardden.bot import GuardDen
|
||||
from guardden.models import ModerationLog, Strike
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def parse_duration(duration_str: str) -> timedelta | None:
|
||||
"""Parse a duration string like '1h', '30m', '7d' into a timedelta."""
|
||||
match = re.match(r"^(\d+)([smhdw])$", duration_str.lower())
|
||||
if not match:
|
||||
return None
|
||||
|
||||
amount = int(match.group(1))
|
||||
unit = match.group(2)
|
||||
|
||||
units = {
|
||||
"s": timedelta(seconds=amount),
|
||||
"m": timedelta(minutes=amount),
|
||||
"h": timedelta(hours=amount),
|
||||
"d": timedelta(days=amount),
|
||||
"w": timedelta(weeks=amount),
|
||||
}
|
||||
|
||||
return units.get(unit)
|
||||
|
||||
|
||||
class Moderation(commands.Cog):
|
||||
"""Moderation commands for server management."""
|
||||
|
||||
def __init__(self, bot: GuardDen) -> None:
|
||||
self.bot = bot
|
||||
|
||||
async def _log_action(
|
||||
self,
|
||||
guild: discord.Guild,
|
||||
target: discord.Member | discord.User,
|
||||
moderator: discord.Member | discord.User,
|
||||
action: str,
|
||||
reason: str | None = None,
|
||||
duration: int | None = None,
|
||||
channel: discord.TextChannel | None = None,
|
||||
message: discord.Message | None = None,
|
||||
is_automatic: bool = False,
|
||||
) -> None:
|
||||
"""Log a moderation action to the database."""
|
||||
expires_at = None
|
||||
if duration:
|
||||
expires_at = datetime.now(timezone.utc) + timedelta(seconds=duration)
|
||||
|
||||
async with self.bot.database.session() as session:
|
||||
log_entry = ModerationLog(
|
||||
guild_id=guild.id,
|
||||
target_id=target.id,
|
||||
target_name=str(target),
|
||||
moderator_id=moderator.id,
|
||||
moderator_name=str(moderator),
|
||||
action=action,
|
||||
reason=reason,
|
||||
duration=duration,
|
||||
expires_at=expires_at,
|
||||
channel_id=channel.id if channel else None,
|
||||
message_id=message.id if message else None,
|
||||
message_content=message.content if message else None,
|
||||
is_automatic=is_automatic,
|
||||
)
|
||||
session.add(log_entry)
|
||||
|
||||
async def _get_strike_count(self, guild_id: int, user_id: int) -> int:
|
||||
"""Get the total active strike count for a user."""
|
||||
async with self.bot.database.session() as session:
|
||||
result = await session.execute(
|
||||
select(func.sum(Strike.points)).where(
|
||||
Strike.guild_id == guild_id,
|
||||
Strike.user_id == user_id,
|
||||
Strike.is_active == True,
|
||||
)
|
||||
)
|
||||
total = result.scalar()
|
||||
return total or 0
|
||||
|
||||
async def _add_strike(
|
||||
self,
|
||||
guild: discord.Guild,
|
||||
user: discord.Member,
|
||||
moderator: discord.Member | discord.User,
|
||||
reason: str,
|
||||
points: int = 1,
|
||||
) -> int:
|
||||
"""Add a strike to a user and return their new total."""
|
||||
async with self.bot.database.session() as session:
|
||||
strike = Strike(
|
||||
guild_id=guild.id,
|
||||
user_id=user.id,
|
||||
user_name=str(user),
|
||||
moderator_id=moderator.id,
|
||||
reason=reason,
|
||||
points=points,
|
||||
)
|
||||
session.add(strike)
|
||||
|
||||
return await self._get_strike_count(guild.id, user.id)
|
||||
|
||||
@commands.command(name="warn")
|
||||
@commands.has_permissions(kick_members=True)
|
||||
@commands.guild_only()
|
||||
async def warn(
|
||||
self, ctx: commands.Context, member: discord.Member, *, reason: str = "No reason provided"
|
||||
) -> None:
|
||||
"""Warn a member."""
|
||||
if member.top_role >= ctx.author.top_role and ctx.author != ctx.guild.owner:
|
||||
await ctx.send("You cannot warn someone with a higher or equal role.")
|
||||
return
|
||||
|
||||
await self._log_action(ctx.guild, member, ctx.author, "warn", reason)
|
||||
|
||||
embed = discord.Embed(
|
||||
title="Warning Issued",
|
||||
description=f"{member.mention} has been warned.",
|
||||
color=discord.Color.yellow(),
|
||||
timestamp=datetime.now(timezone.utc),
|
||||
)
|
||||
embed.add_field(name="Reason", value=reason, inline=False)
|
||||
embed.set_footer(text=f"Moderator: {ctx.author}")
|
||||
|
||||
await ctx.send(embed=embed)
|
||||
|
||||
# Try to DM the user
|
||||
try:
|
||||
dm_embed = discord.Embed(
|
||||
title=f"Warning in {ctx.guild.name}",
|
||||
description=f"You have been warned.",
|
||||
color=discord.Color.yellow(),
|
||||
)
|
||||
dm_embed.add_field(name="Reason", value=reason)
|
||||
await member.send(embed=dm_embed)
|
||||
except discord.Forbidden:
|
||||
pass
|
||||
|
||||
@commands.command(name="strike")
|
||||
@commands.has_permissions(kick_members=True)
|
||||
@commands.guild_only()
|
||||
async def strike(
|
||||
self,
|
||||
ctx: commands.Context,
|
||||
member: discord.Member,
|
||||
points: int = 1,
|
||||
*,
|
||||
reason: str = "No reason provided",
|
||||
) -> None:
|
||||
"""Add a strike to a member."""
|
||||
if member.top_role >= ctx.author.top_role and ctx.author != ctx.guild.owner:
|
||||
await ctx.send("You cannot strike someone with a higher or equal role.")
|
||||
return
|
||||
|
||||
total_strikes = await self._add_strike(ctx.guild, member, ctx.author, reason, points)
|
||||
await self._log_action(ctx.guild, member, ctx.author, "strike", reason)
|
||||
|
||||
embed = discord.Embed(
|
||||
title="Strike Added",
|
||||
description=f"{member.mention} has received {points} strike(s).",
|
||||
color=discord.Color.orange(),
|
||||
timestamp=datetime.now(timezone.utc),
|
||||
)
|
||||
embed.add_field(name="Reason", value=reason, inline=False)
|
||||
embed.add_field(name="Total Strikes", value=str(total_strikes))
|
||||
embed.set_footer(text=f"Moderator: {ctx.author}")
|
||||
|
||||
await ctx.send(embed=embed)
|
||||
|
||||
# Check for automatic actions based on strike thresholds
|
||||
config = await self.bot.guild_config.get_config(ctx.guild.id)
|
||||
if config and config.strike_actions:
|
||||
for threshold, action_config in sorted(
|
||||
config.strike_actions.items(), key=lambda x: int(x[0]), reverse=True
|
||||
):
|
||||
if total_strikes >= int(threshold):
|
||||
action = action_config.get("action")
|
||||
if action == "ban":
|
||||
await ctx.invoke(
|
||||
self.ban, member=member, reason=f"Automatic: {total_strikes} strikes"
|
||||
)
|
||||
elif action == "kick":
|
||||
await ctx.invoke(
|
||||
self.kick, member=member, reason=f"Automatic: {total_strikes} strikes"
|
||||
)
|
||||
elif action == "timeout":
|
||||
duration = action_config.get("duration", 3600)
|
||||
await ctx.invoke(
|
||||
self.timeout,
|
||||
member=member,
|
||||
duration=f"{duration}s",
|
||||
reason=f"Automatic: {total_strikes} strikes",
|
||||
)
|
||||
break
|
||||
|
||||
@commands.command(name="strikes")
|
||||
@commands.has_permissions(kick_members=True)
|
||||
@commands.guild_only()
|
||||
async def strikes(self, ctx: commands.Context, member: discord.Member) -> None:
|
||||
"""View strikes for a member."""
|
||||
async with self.bot.database.session() as session:
|
||||
result = await session.execute(
|
||||
select(Strike)
|
||||
.where(
|
||||
Strike.guild_id == ctx.guild.id,
|
||||
Strike.user_id == member.id,
|
||||
Strike.is_active == True,
|
||||
)
|
||||
.order_by(Strike.created_at.desc())
|
||||
.limit(10)
|
||||
)
|
||||
user_strikes = result.scalars().all()
|
||||
|
||||
total = await self._get_strike_count(ctx.guild.id, member.id)
|
||||
|
||||
embed = discord.Embed(
|
||||
title=f"Strikes for {member}",
|
||||
description=f"Total active strikes: **{total}**",
|
||||
color=discord.Color.orange(),
|
||||
)
|
||||
|
||||
if user_strikes:
|
||||
for strike in user_strikes:
|
||||
embed.add_field(
|
||||
name=f"Strike #{strike.id} ({strike.points} pts)",
|
||||
value=f"{strike.reason}\n*{strike.created_at.strftime('%Y-%m-%d')}*",
|
||||
inline=False,
|
||||
)
|
||||
else:
|
||||
embed.description = f"{member.mention} has no active strikes."
|
||||
|
||||
await ctx.send(embed=embed)
|
||||
|
||||
@commands.command(name="timeout", aliases=["mute"])
|
||||
@commands.has_permissions(moderate_members=True)
|
||||
@commands.guild_only()
|
||||
async def timeout(
|
||||
self,
|
||||
ctx: commands.Context,
|
||||
member: discord.Member,
|
||||
duration: str = "1h",
|
||||
*,
|
||||
reason: str = "No reason provided",
|
||||
) -> None:
|
||||
"""Timeout a member (e.g., !timeout @user 1h Spamming)."""
|
||||
if member.top_role >= ctx.author.top_role and ctx.author != ctx.guild.owner:
|
||||
await ctx.send("You cannot timeout someone with a higher or equal role.")
|
||||
return
|
||||
|
||||
delta = parse_duration(duration)
|
||||
if not delta:
|
||||
await ctx.send("Invalid duration. Use format like: 30m, 1h, 7d")
|
||||
return
|
||||
|
||||
if delta > timedelta(days=28):
|
||||
await ctx.send("Timeout duration cannot exceed 28 days.")
|
||||
return
|
||||
|
||||
try:
|
||||
await member.timeout(delta, reason=f"{ctx.author}: {reason}")
|
||||
except discord.Forbidden:
|
||||
await ctx.send("I don't have permission to timeout this user.")
|
||||
return
|
||||
except discord.HTTPException as e:
|
||||
await ctx.send(f"Failed to timeout user: {e}")
|
||||
return
|
||||
|
||||
await self._log_action(
|
||||
ctx.guild, member, ctx.author, "timeout", reason, int(delta.total_seconds())
|
||||
)
|
||||
|
||||
embed = discord.Embed(
|
||||
title="Member Timed Out",
|
||||
description=f"{member.mention} has been timed out for {duration}.",
|
||||
color=discord.Color.orange(),
|
||||
timestamp=datetime.now(timezone.utc),
|
||||
)
|
||||
embed.add_field(name="Reason", value=reason, inline=False)
|
||||
embed.set_footer(text=f"Moderator: {ctx.author}")
|
||||
|
||||
await ctx.send(embed=embed)
|
||||
|
||||
@commands.command(name="untimeout", aliases=["unmute"])
|
||||
@commands.has_permissions(moderate_members=True)
|
||||
@commands.guild_only()
|
||||
async def untimeout(
|
||||
self, ctx: commands.Context, member: discord.Member, *, reason: str = "No reason provided"
|
||||
) -> None:
|
||||
"""Remove timeout from a member."""
|
||||
await member.timeout(None, reason=f"{ctx.author}: {reason}")
|
||||
await self._log_action(ctx.guild, member, ctx.author, "unmute", reason)
|
||||
|
||||
embed = discord.Embed(
|
||||
title="Timeout Removed",
|
||||
description=f"{member.mention}'s timeout has been removed.",
|
||||
color=discord.Color.green(),
|
||||
timestamp=datetime.now(timezone.utc),
|
||||
)
|
||||
embed.add_field(name="Reason", value=reason, inline=False)
|
||||
embed.set_footer(text=f"Moderator: {ctx.author}")
|
||||
|
||||
await ctx.send(embed=embed)
|
||||
|
||||
@commands.command(name="kick")
|
||||
@commands.has_permissions(kick_members=True)
|
||||
@commands.guild_only()
|
||||
async def kick(
|
||||
self, ctx: commands.Context, member: discord.Member, *, reason: str = "No reason provided"
|
||||
) -> None:
|
||||
"""Kick a member from the server."""
|
||||
if member.top_role >= ctx.author.top_role and ctx.author != ctx.guild.owner:
|
||||
await ctx.send("You cannot kick someone with a higher or equal role.")
|
||||
return
|
||||
|
||||
# Try to DM the user before kicking
|
||||
try:
|
||||
dm_embed = discord.Embed(
|
||||
title=f"Kicked from {ctx.guild.name}",
|
||||
description=f"You have been kicked from the server.",
|
||||
color=discord.Color.red(),
|
||||
)
|
||||
dm_embed.add_field(name="Reason", value=reason)
|
||||
await member.send(embed=dm_embed)
|
||||
except discord.Forbidden:
|
||||
pass
|
||||
|
||||
await member.kick(reason=f"{ctx.author}: {reason}")
|
||||
await self._log_action(ctx.guild, member, ctx.author, "kick", reason)
|
||||
|
||||
embed = discord.Embed(
|
||||
title="Member Kicked",
|
||||
description=f"{member} has been kicked from the server.",
|
||||
color=discord.Color.red(),
|
||||
timestamp=datetime.now(timezone.utc),
|
||||
)
|
||||
embed.add_field(name="Reason", value=reason, inline=False)
|
||||
embed.set_footer(text=f"Moderator: {ctx.author}")
|
||||
|
||||
await ctx.send(embed=embed)
|
||||
|
||||
@commands.command(name="ban")
|
||||
@commands.has_permissions(ban_members=True)
|
||||
@commands.guild_only()
|
||||
async def ban(
|
||||
self,
|
||||
ctx: commands.Context,
|
||||
member: discord.Member | discord.User,
|
||||
*,
|
||||
reason: str = "No reason provided",
|
||||
) -> None:
|
||||
"""Ban a member from the server."""
|
||||
if isinstance(member, discord.Member):
|
||||
if member.top_role >= ctx.author.top_role and ctx.author != ctx.guild.owner:
|
||||
await ctx.send("You cannot ban someone with a higher or equal role.")
|
||||
return
|
||||
|
||||
# Try to DM the user before banning
|
||||
try:
|
||||
dm_embed = discord.Embed(
|
||||
title=f"Banned from {ctx.guild.name}",
|
||||
description=f"You have been banned from the server.",
|
||||
color=discord.Color.dark_red(),
|
||||
)
|
||||
dm_embed.add_field(name="Reason", value=reason)
|
||||
await member.send(embed=dm_embed)
|
||||
except discord.Forbidden:
|
||||
pass
|
||||
|
||||
await ctx.guild.ban(member, reason=f"{ctx.author}: {reason}", delete_message_days=0)
|
||||
await self._log_action(ctx.guild, member, ctx.author, "ban", reason)
|
||||
|
||||
embed = discord.Embed(
|
||||
title="Member Banned",
|
||||
description=f"{member} has been banned from the server.",
|
||||
color=discord.Color.dark_red(),
|
||||
timestamp=datetime.now(timezone.utc),
|
||||
)
|
||||
embed.add_field(name="Reason", value=reason, inline=False)
|
||||
embed.set_footer(text=f"Moderator: {ctx.author}")
|
||||
|
||||
await ctx.send(embed=embed)
|
||||
|
||||
@commands.command(name="unban")
|
||||
@commands.has_permissions(ban_members=True)
|
||||
@commands.guild_only()
|
||||
async def unban(
|
||||
self, ctx: commands.Context, user_id: int, *, reason: str = "No reason provided"
|
||||
) -> None:
|
||||
"""Unban a user by their ID."""
|
||||
try:
|
||||
user = await self.bot.fetch_user(user_id)
|
||||
await ctx.guild.unban(user, reason=f"{ctx.author}: {reason}")
|
||||
await self._log_action(ctx.guild, user, ctx.author, "unban", reason)
|
||||
|
||||
embed = discord.Embed(
|
||||
title="User Unbanned",
|
||||
description=f"{user} has been unbanned.",
|
||||
color=discord.Color.green(),
|
||||
timestamp=datetime.now(timezone.utc),
|
||||
)
|
||||
embed.add_field(name="Reason", value=reason, inline=False)
|
||||
embed.set_footer(text=f"Moderator: {ctx.author}")
|
||||
|
||||
await ctx.send(embed=embed)
|
||||
|
||||
except discord.NotFound:
|
||||
await ctx.send("User not found or not banned.")
|
||||
except discord.Forbidden:
|
||||
await ctx.send("I don't have permission to unban this user.")
|
||||
|
||||
@commands.command(name="purge", aliases=["clear"])
|
||||
@commands.has_permissions(manage_messages=True)
|
||||
@commands.guild_only()
|
||||
async def purge(self, ctx: commands.Context, amount: int) -> None:
|
||||
"""Delete multiple messages at once (max 100)."""
|
||||
if amount < 1 or amount > 100:
|
||||
await ctx.send("Please specify a number between 1 and 100.")
|
||||
return
|
||||
|
||||
deleted = await ctx.channel.purge(limit=amount + 1) # +1 to include the command message
|
||||
|
||||
msg = await ctx.send(f"Deleted {len(deleted) - 1} message(s).")
|
||||
await msg.delete(delay=3)
|
||||
|
||||
@commands.command(name="modlogs", aliases=["history"])
|
||||
@commands.has_permissions(kick_members=True)
|
||||
@commands.guild_only()
|
||||
async def modlogs(self, ctx: commands.Context, member: discord.Member | discord.User) -> None:
|
||||
"""View moderation history for a user."""
|
||||
async with self.bot.database.session() as session:
|
||||
result = await session.execute(
|
||||
select(ModerationLog)
|
||||
.where(ModerationLog.guild_id == ctx.guild.id, ModerationLog.target_id == member.id)
|
||||
.order_by(ModerationLog.created_at.desc())
|
||||
.limit(10)
|
||||
)
|
||||
logs = result.scalars().all()
|
||||
|
||||
embed = discord.Embed(
|
||||
title=f"Moderation History for {member}",
|
||||
color=discord.Color.blue(),
|
||||
)
|
||||
|
||||
if logs:
|
||||
for log in logs:
|
||||
value = f"**Reason:** {log.reason or 'None'}\n**By:** {log.moderator_name}\n*{log.created_at.strftime('%Y-%m-%d %H:%M')}*"
|
||||
embed.add_field(name=f"{log.action.upper()} (#{log.id})", value=value, inline=False)
|
||||
else:
|
||||
embed.description = "No moderation history found."
|
||||
|
||||
await ctx.send(embed=embed)
|
||||
|
||||
|
||||
async def setup(bot: GuardDen) -> None:
|
||||
"""Load the Moderation cog."""
|
||||
await bot.add_cog(Moderation(bot))
|
||||
423
src/guardden/cogs/verification.py
Normal file
423
src/guardden/cogs/verification.py
Normal file
@@ -0,0 +1,423 @@
|
||||
"""Verification cog for new member verification."""
|
||||
|
||||
import logging
|
||||
from datetime import datetime, timezone
|
||||
|
||||
import discord
|
||||
from discord import ui
|
||||
from discord.ext import commands, tasks
|
||||
|
||||
from guardden.bot import GuardDen
|
||||
from guardden.services.verification import (
|
||||
ChallengeType,
|
||||
PendingVerification,
|
||||
VerificationService,
|
||||
)
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class VerifyButton(ui.Button["VerificationView"]):
|
||||
"""Button for simple verification."""
|
||||
|
||||
def __init__(self) -> None:
|
||||
super().__init__(
|
||||
style=discord.ButtonStyle.success,
|
||||
label="Verify",
|
||||
custom_id="verify_button",
|
||||
)
|
||||
|
||||
async def callback(self, interaction: discord.Interaction) -> None:
|
||||
if self.view is None:
|
||||
return
|
||||
|
||||
success, message = await self.view.cog.complete_verification(
|
||||
interaction.guild.id,
|
||||
interaction.user.id,
|
||||
"verified",
|
||||
)
|
||||
|
||||
if success:
|
||||
await interaction.response.send_message(message, ephemeral=True)
|
||||
# Disable the button
|
||||
self.disabled = True
|
||||
self.label = "Verified"
|
||||
await interaction.message.edit(view=self.view)
|
||||
else:
|
||||
await interaction.response.send_message(message, ephemeral=True)
|
||||
|
||||
|
||||
class EmojiButton(ui.Button["EmojiVerificationView"]):
|
||||
"""Button for emoji selection verification."""
|
||||
|
||||
def __init__(self, emoji: str, row: int = 0) -> None:
|
||||
super().__init__(
|
||||
style=discord.ButtonStyle.secondary,
|
||||
label=emoji,
|
||||
custom_id=f"emoji_{emoji}",
|
||||
row=row,
|
||||
)
|
||||
self.emoji_value = emoji
|
||||
|
||||
async def callback(self, interaction: discord.Interaction) -> None:
|
||||
if self.view is None:
|
||||
return
|
||||
|
||||
success, message = await self.view.cog.complete_verification(
|
||||
interaction.guild.id,
|
||||
interaction.user.id,
|
||||
self.emoji_value,
|
||||
)
|
||||
|
||||
if success:
|
||||
await interaction.response.send_message(message, ephemeral=True)
|
||||
# Disable all buttons
|
||||
for item in self.view.children:
|
||||
if isinstance(item, ui.Button):
|
||||
item.disabled = True
|
||||
await interaction.message.edit(view=self.view)
|
||||
else:
|
||||
await interaction.response.send_message(message, ephemeral=True)
|
||||
|
||||
|
||||
class VerificationView(ui.View):
|
||||
"""View for button verification."""
|
||||
|
||||
def __init__(self, cog: "Verification", timeout: float = 600) -> None:
|
||||
super().__init__(timeout=timeout)
|
||||
self.cog = cog
|
||||
self.add_item(VerifyButton())
|
||||
|
||||
|
||||
class EmojiVerificationView(ui.View):
|
||||
"""View for emoji selection verification."""
|
||||
|
||||
def __init__(self, cog: "Verification", options: list[str], timeout: float = 600) -> None:
|
||||
super().__init__(timeout=timeout)
|
||||
self.cog = cog
|
||||
for i, emoji in enumerate(options):
|
||||
self.add_item(EmojiButton(emoji, row=i // 4))
|
||||
|
||||
|
||||
class CaptchaModal(ui.Modal):
|
||||
"""Modal for captcha/math input."""
|
||||
|
||||
answer = ui.TextInput(
|
||||
label="Your Answer",
|
||||
placeholder="Enter the answer here...",
|
||||
max_length=50,
|
||||
)
|
||||
|
||||
def __init__(self, cog: "Verification", title: str = "Verification") -> None:
|
||||
super().__init__(title=title)
|
||||
self.cog = cog
|
||||
|
||||
async def on_submit(self, interaction: discord.Interaction) -> None:
|
||||
success, message = await self.cog.complete_verification(
|
||||
interaction.guild.id,
|
||||
interaction.user.id,
|
||||
self.answer.value,
|
||||
)
|
||||
await interaction.response.send_message(message, ephemeral=True)
|
||||
|
||||
|
||||
class AnswerButton(ui.Button["AnswerView"]):
|
||||
"""Button to open the answer modal."""
|
||||
|
||||
def __init__(self) -> None:
|
||||
super().__init__(
|
||||
style=discord.ButtonStyle.primary,
|
||||
label="Submit Answer",
|
||||
custom_id="submit_answer",
|
||||
)
|
||||
|
||||
async def callback(self, interaction: discord.Interaction) -> None:
|
||||
if self.view is None:
|
||||
return
|
||||
modal = CaptchaModal(self.view.cog)
|
||||
await interaction.response.send_modal(modal)
|
||||
|
||||
|
||||
class AnswerView(ui.View):
|
||||
"""View with button to open answer modal."""
|
||||
|
||||
def __init__(self, cog: "Verification", timeout: float = 600) -> None:
|
||||
super().__init__(timeout=timeout)
|
||||
self.cog = cog
|
||||
self.add_item(AnswerButton())
|
||||
|
||||
|
||||
class Verification(commands.Cog):
|
||||
"""Member verification system."""
|
||||
|
||||
def __init__(self, bot: GuardDen) -> None:
|
||||
self.bot = bot
|
||||
self.service = VerificationService()
|
||||
self.cleanup_task.start()
|
||||
|
||||
def cog_unload(self) -> None:
|
||||
self.cleanup_task.cancel()
|
||||
|
||||
@tasks.loop(minutes=5)
|
||||
async def cleanup_task(self) -> None:
|
||||
"""Periodically clean up expired verifications."""
|
||||
count = self.service.cleanup_expired()
|
||||
if count > 0:
|
||||
logger.debug(f"Cleaned up {count} expired verifications")
|
||||
|
||||
@cleanup_task.before_loop
|
||||
async def before_cleanup(self) -> None:
|
||||
await self.bot.wait_until_ready()
|
||||
|
||||
async def complete_verification(
|
||||
self, guild_id: int, user_id: int, response: str
|
||||
) -> tuple[bool, str]:
|
||||
"""Complete a verification and assign role if successful."""
|
||||
success, message = self.service.verify(guild_id, user_id, response)
|
||||
|
||||
if success:
|
||||
# Assign verified role
|
||||
guild = self.bot.get_guild(guild_id)
|
||||
if guild:
|
||||
member = guild.get_member(user_id)
|
||||
config = await self.bot.guild_config.get_config(guild_id)
|
||||
|
||||
if member and config and config.verified_role_id:
|
||||
role = guild.get_role(config.verified_role_id)
|
||||
if role:
|
||||
try:
|
||||
await member.add_roles(role, reason="Verification completed")
|
||||
logger.info(f"Verified {member} in {guild.name}")
|
||||
except discord.Forbidden:
|
||||
logger.warning(f"Cannot assign verified role in {guild.name}")
|
||||
|
||||
return success, message
|
||||
|
||||
async def send_verification(
|
||||
self,
|
||||
member: discord.Member,
|
||||
channel: discord.TextChannel,
|
||||
challenge_type: ChallengeType,
|
||||
) -> None:
|
||||
"""Send a verification challenge to a member."""
|
||||
pending = self.service.create_challenge(
|
||||
user_id=member.id,
|
||||
guild_id=member.guild.id,
|
||||
challenge_type=challenge_type,
|
||||
)
|
||||
|
||||
embed = discord.Embed(
|
||||
title="Verification Required",
|
||||
description=pending.challenge.question,
|
||||
color=discord.Color.blue(),
|
||||
timestamp=datetime.now(timezone.utc),
|
||||
)
|
||||
embed.set_footer(
|
||||
text=f"Expires in 10 minutes • {pending.challenge.max_attempts} attempts allowed"
|
||||
)
|
||||
|
||||
# Create appropriate view based on challenge type
|
||||
if challenge_type == ChallengeType.BUTTON:
|
||||
view = VerificationView(self)
|
||||
elif challenge_type == ChallengeType.EMOJI:
|
||||
view = EmojiVerificationView(self, pending.challenge.options)
|
||||
else:
|
||||
# Captcha or Math - use modal
|
||||
view = AnswerView(self)
|
||||
|
||||
try:
|
||||
# Try to DM the user first
|
||||
dm_channel = await member.create_dm()
|
||||
msg = await dm_channel.send(embed=embed, view=view)
|
||||
pending.message_id = msg.id
|
||||
pending.channel_id = dm_channel.id
|
||||
except discord.Forbidden:
|
||||
# Fall back to channel mention
|
||||
msg = await channel.send(
|
||||
content=member.mention,
|
||||
embed=embed,
|
||||
view=view,
|
||||
)
|
||||
pending.message_id = msg.id
|
||||
pending.channel_id = channel.id
|
||||
|
||||
@commands.Cog.listener()
|
||||
async def on_member_join(self, member: discord.Member) -> None:
|
||||
"""Handle new member joins for verification."""
|
||||
if member.bot:
|
||||
return
|
||||
|
||||
config = await self.bot.guild_config.get_config(member.guild.id)
|
||||
if not config or not config.verification_enabled:
|
||||
return
|
||||
|
||||
# Determine verification channel
|
||||
channel_id = config.welcome_channel_id or config.log_channel_id
|
||||
if not channel_id:
|
||||
return
|
||||
|
||||
channel = member.guild.get_channel(channel_id)
|
||||
if not channel or not isinstance(channel, discord.TextChannel):
|
||||
return
|
||||
|
||||
# Get challenge type from config
|
||||
try:
|
||||
challenge_type = ChallengeType(config.verification_type)
|
||||
except ValueError:
|
||||
challenge_type = ChallengeType.BUTTON
|
||||
|
||||
await self.send_verification(member, channel, challenge_type)
|
||||
|
||||
@commands.group(name="verify", invoke_without_command=True)
|
||||
@commands.guild_only()
|
||||
async def verify_cmd(self, ctx: commands.Context) -> None:
|
||||
"""Request a verification challenge."""
|
||||
config = await self.bot.guild_config.get_config(ctx.guild.id)
|
||||
|
||||
if not config or not config.verification_enabled:
|
||||
await ctx.send("Verification is not enabled on this server.")
|
||||
return
|
||||
|
||||
# Check if already verified
|
||||
if config.verified_role_id:
|
||||
role = ctx.guild.get_role(config.verified_role_id)
|
||||
if role and role in ctx.author.roles:
|
||||
await ctx.send("You are already verified!")
|
||||
return
|
||||
|
||||
# Check for existing pending verification
|
||||
pending = self.service.get_pending(ctx.guild.id, ctx.author.id)
|
||||
if pending and not pending.challenge.is_expired:
|
||||
await ctx.send("You already have a pending verification. Please complete it first.")
|
||||
return
|
||||
|
||||
# Get challenge type
|
||||
try:
|
||||
challenge_type = ChallengeType(config.verification_type)
|
||||
except ValueError:
|
||||
challenge_type = ChallengeType.BUTTON
|
||||
|
||||
await self.send_verification(ctx.author, ctx.channel, challenge_type)
|
||||
await ctx.message.delete(delay=1)
|
||||
|
||||
@verify_cmd.command(name="setup")
|
||||
@commands.has_permissions(administrator=True)
|
||||
@commands.guild_only()
|
||||
async def verify_setup(self, ctx: commands.Context) -> None:
|
||||
"""View verification setup status."""
|
||||
config = await self.bot.guild_config.get_config(ctx.guild.id)
|
||||
|
||||
embed = discord.Embed(
|
||||
title="Verification Setup",
|
||||
color=discord.Color.blue(),
|
||||
)
|
||||
|
||||
embed.add_field(
|
||||
name="Enabled",
|
||||
value="✅ Yes" if config and config.verification_enabled else "❌ No",
|
||||
inline=True,
|
||||
)
|
||||
embed.add_field(
|
||||
name="Type",
|
||||
value=config.verification_type if config else "button",
|
||||
inline=True,
|
||||
)
|
||||
|
||||
if config and config.verified_role_id:
|
||||
role = ctx.guild.get_role(config.verified_role_id)
|
||||
embed.add_field(
|
||||
name="Verified Role",
|
||||
value=role.mention if role else "Not found",
|
||||
inline=True,
|
||||
)
|
||||
else:
|
||||
embed.add_field(name="Verified Role", value="Not set", inline=True)
|
||||
|
||||
pending_count = self.service.get_pending_count(ctx.guild.id)
|
||||
embed.add_field(name="Pending Verifications", value=str(pending_count), inline=True)
|
||||
|
||||
await ctx.send(embed=embed)
|
||||
|
||||
@verify_cmd.command(name="enable")
|
||||
@commands.has_permissions(administrator=True)
|
||||
@commands.guild_only()
|
||||
async def verify_enable(self, ctx: commands.Context) -> None:
|
||||
"""Enable verification for new members."""
|
||||
config = await self.bot.guild_config.get_config(ctx.guild.id)
|
||||
|
||||
if not config or not config.verified_role_id:
|
||||
await ctx.send("Please set a verified role first with `!verify role @role`")
|
||||
return
|
||||
|
||||
await self.bot.guild_config.update_settings(ctx.guild.id, verification_enabled=True)
|
||||
await ctx.send("✅ Verification enabled for new members.")
|
||||
|
||||
@verify_cmd.command(name="disable")
|
||||
@commands.has_permissions(administrator=True)
|
||||
@commands.guild_only()
|
||||
async def verify_disable(self, ctx: commands.Context) -> None:
|
||||
"""Disable verification."""
|
||||
await self.bot.guild_config.update_settings(ctx.guild.id, verification_enabled=False)
|
||||
await ctx.send("❌ Verification disabled.")
|
||||
|
||||
@verify_cmd.command(name="role")
|
||||
@commands.has_permissions(administrator=True)
|
||||
@commands.guild_only()
|
||||
async def verify_role(self, ctx: commands.Context, role: discord.Role) -> None:
|
||||
"""Set the role given upon verification."""
|
||||
await self.bot.guild_config.update_settings(ctx.guild.id, verified_role_id=role.id)
|
||||
await ctx.send(f"Verified role set to {role.mention}")
|
||||
|
||||
@verify_cmd.command(name="type")
|
||||
@commands.has_permissions(administrator=True)
|
||||
@commands.guild_only()
|
||||
async def verify_type(self, ctx: commands.Context, vtype: str) -> None:
|
||||
"""Set verification type (button, captcha, math, emoji)."""
|
||||
try:
|
||||
challenge_type = ChallengeType(vtype.lower())
|
||||
except ValueError:
|
||||
valid = ", ".join(t.value for t in ChallengeType if t != ChallengeType.QUESTIONS)
|
||||
await ctx.send(f"Invalid type. Valid options: {valid}")
|
||||
return
|
||||
|
||||
await self.bot.guild_config.update_settings(
|
||||
ctx.guild.id, verification_type=challenge_type.value
|
||||
)
|
||||
await ctx.send(f"Verification type set to **{challenge_type.value}**")
|
||||
|
||||
@verify_cmd.command(name="test")
|
||||
@commands.has_permissions(administrator=True)
|
||||
@commands.guild_only()
|
||||
async def verify_test(self, ctx: commands.Context, vtype: str = "button") -> None:
|
||||
"""Test verification (sends challenge to you)."""
|
||||
try:
|
||||
challenge_type = ChallengeType(vtype.lower())
|
||||
except ValueError:
|
||||
challenge_type = ChallengeType.BUTTON
|
||||
|
||||
await self.send_verification(ctx.author, ctx.channel, challenge_type)
|
||||
|
||||
@verify_cmd.command(name="reset")
|
||||
@commands.has_permissions(kick_members=True)
|
||||
@commands.guild_only()
|
||||
async def verify_reset(self, ctx: commands.Context, member: discord.Member) -> None:
|
||||
"""Reset verification for a member (remove role and cancel pending)."""
|
||||
# Cancel any pending verification
|
||||
self.service.cancel(ctx.guild.id, member.id)
|
||||
|
||||
# Remove verified role
|
||||
config = await self.bot.guild_config.get_config(ctx.guild.id)
|
||||
if config and config.verified_role_id:
|
||||
role = ctx.guild.get_role(config.verified_role_id)
|
||||
if role and role in member.roles:
|
||||
try:
|
||||
await member.remove_roles(role, reason=f"Verification reset by {ctx.author}")
|
||||
except discord.Forbidden:
|
||||
pass
|
||||
|
||||
await ctx.send(f"Reset verification for {member.mention}")
|
||||
|
||||
|
||||
async def setup(bot: GuardDen) -> None:
|
||||
"""Load the Verification cog."""
|
||||
await bot.add_cog(Verification(bot))
|
||||
50
src/guardden/config.py
Normal file
50
src/guardden/config.py
Normal file
@@ -0,0 +1,50 @@
|
||||
"""Configuration management for GuardDen."""
|
||||
|
||||
from pathlib import Path
|
||||
from typing import Literal
|
||||
|
||||
from pydantic import Field, SecretStr
|
||||
from pydantic_settings import BaseSettings, SettingsConfigDict
|
||||
|
||||
|
||||
class Settings(BaseSettings):
|
||||
"""Application settings loaded from environment variables."""
|
||||
|
||||
model_config = SettingsConfigDict(
|
||||
env_file=".env",
|
||||
env_file_encoding="utf-8",
|
||||
case_sensitive=False,
|
||||
env_prefix="GUARDDEN_",
|
||||
)
|
||||
|
||||
# Discord settings
|
||||
discord_token: SecretStr = Field(..., description="Discord bot token")
|
||||
discord_prefix: str = Field(default="!", description="Default command prefix")
|
||||
|
||||
# Database settings
|
||||
database_url: SecretStr = Field(
|
||||
default=SecretStr("postgresql://guardden:guardden@localhost:5432/guardden"),
|
||||
description="PostgreSQL connection URL",
|
||||
)
|
||||
database_pool_min: int = Field(default=5, description="Minimum database pool size")
|
||||
database_pool_max: int = Field(default=20, description="Maximum database pool size")
|
||||
|
||||
# AI settings (optional)
|
||||
ai_provider: Literal["anthropic", "openai", "none"] = Field(
|
||||
default="none", description="AI provider for content moderation"
|
||||
)
|
||||
anthropic_api_key: SecretStr | None = Field(default=None, description="Anthropic API key")
|
||||
openai_api_key: SecretStr | None = Field(default=None, description="OpenAI API key")
|
||||
|
||||
# Logging
|
||||
log_level: Literal["DEBUG", "INFO", "WARNING", "ERROR"] = Field(
|
||||
default="INFO", description="Logging level"
|
||||
)
|
||||
|
||||
# Paths
|
||||
data_dir: Path = Field(default=Path("data"), description="Data directory for persistent files")
|
||||
|
||||
|
||||
def get_settings() -> Settings:
|
||||
"""Get application settings instance."""
|
||||
return Settings()
|
||||
15
src/guardden/models/__init__.py
Normal file
15
src/guardden/models/__init__.py
Normal file
@@ -0,0 +1,15 @@
|
||||
"""Database models for GuardDen."""
|
||||
|
||||
from guardden.models.base import Base
|
||||
from guardden.models.guild import BannedWord, Guild, GuildSettings
|
||||
from guardden.models.moderation import ModerationLog, Strike, UserNote
|
||||
|
||||
__all__ = [
|
||||
"Base",
|
||||
"Guild",
|
||||
"GuildSettings",
|
||||
"BannedWord",
|
||||
"ModerationLog",
|
||||
"Strike",
|
||||
"UserNote",
|
||||
]
|
||||
32
src/guardden/models/base.py
Normal file
32
src/guardden/models/base.py
Normal file
@@ -0,0 +1,32 @@
|
||||
"""Base model and database utilities."""
|
||||
|
||||
from datetime import datetime
|
||||
|
||||
from sqlalchemy import BigInteger, DateTime, func
|
||||
from sqlalchemy.orm import DeclarativeBase, Mapped, mapped_column
|
||||
|
||||
|
||||
class Base(DeclarativeBase):
|
||||
"""Base class for all database models."""
|
||||
|
||||
pass
|
||||
|
||||
|
||||
class TimestampMixin:
|
||||
"""Mixin that adds created_at and updated_at timestamps."""
|
||||
|
||||
created_at: Mapped[datetime] = mapped_column(
|
||||
DateTime(timezone=True),
|
||||
server_default=func.now(),
|
||||
nullable=False,
|
||||
)
|
||||
updated_at: Mapped[datetime] = mapped_column(
|
||||
DateTime(timezone=True),
|
||||
server_default=func.now(),
|
||||
onupdate=func.now(),
|
||||
nullable=False,
|
||||
)
|
||||
|
||||
|
||||
# Type alias for Discord snowflake IDs (64-bit integers)
|
||||
SnowflakeID = BigInteger
|
||||
117
src/guardden/models/guild.py
Normal file
117
src/guardden/models/guild.py
Normal file
@@ -0,0 +1,117 @@
|
||||
"""Guild-related database models."""
|
||||
|
||||
from datetime import datetime
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
from sqlalchemy import Boolean, ForeignKey, Integer, String, Text
|
||||
from sqlalchemy.dialects.postgresql import JSONB
|
||||
from sqlalchemy.orm import Mapped, mapped_column, relationship
|
||||
|
||||
from guardden.models.base import Base, SnowflakeID, TimestampMixin
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from guardden.models.moderation import ModerationLog, Strike
|
||||
|
||||
|
||||
class Guild(Base, TimestampMixin):
|
||||
"""Represents a Discord guild (server) configuration."""
|
||||
|
||||
__tablename__ = "guilds"
|
||||
|
||||
id: Mapped[int] = mapped_column(SnowflakeID, primary_key=True)
|
||||
name: Mapped[str] = mapped_column(String(100), nullable=False)
|
||||
owner_id: Mapped[int] = mapped_column(SnowflakeID, nullable=False)
|
||||
premium: Mapped[bool] = mapped_column(Boolean, default=False, nullable=False)
|
||||
|
||||
# Relationships
|
||||
settings: Mapped["GuildSettings"] = relationship(
|
||||
back_populates="guild", uselist=False, cascade="all, delete-orphan"
|
||||
)
|
||||
banned_words: Mapped[list["BannedWord"]] = relationship(
|
||||
back_populates="guild", cascade="all, delete-orphan"
|
||||
)
|
||||
moderation_logs: Mapped[list["ModerationLog"]] = relationship(
|
||||
back_populates="guild", cascade="all, delete-orphan"
|
||||
)
|
||||
strikes: Mapped[list["Strike"]] = relationship(
|
||||
back_populates="guild", cascade="all, delete-orphan"
|
||||
)
|
||||
|
||||
|
||||
class GuildSettings(Base, TimestampMixin):
|
||||
"""Per-guild bot settings and configuration."""
|
||||
|
||||
__tablename__ = "guild_settings"
|
||||
|
||||
guild_id: Mapped[int] = mapped_column(
|
||||
SnowflakeID, ForeignKey("guilds.id", ondelete="CASCADE"), primary_key=True
|
||||
)
|
||||
|
||||
# General settings
|
||||
prefix: Mapped[str] = mapped_column(String(10), default="!", nullable=False)
|
||||
locale: Mapped[str] = mapped_column(String(10), default="en", nullable=False)
|
||||
|
||||
# Channel configuration (stored as snowflake IDs)
|
||||
log_channel_id: Mapped[int | None] = mapped_column(SnowflakeID, nullable=True)
|
||||
mod_log_channel_id: Mapped[int | None] = mapped_column(SnowflakeID, nullable=True)
|
||||
welcome_channel_id: Mapped[int | None] = mapped_column(SnowflakeID, nullable=True)
|
||||
|
||||
# Role configuration
|
||||
mute_role_id: Mapped[int | None] = mapped_column(SnowflakeID, nullable=True)
|
||||
verified_role_id: Mapped[int | None] = mapped_column(SnowflakeID, nullable=True)
|
||||
mod_role_ids: Mapped[dict] = mapped_column(JSONB, default=list, nullable=False)
|
||||
|
||||
# Moderation settings
|
||||
automod_enabled: Mapped[bool] = mapped_column(Boolean, default=True, nullable=False)
|
||||
anti_spam_enabled: Mapped[bool] = mapped_column(Boolean, default=True, nullable=False)
|
||||
link_filter_enabled: Mapped[bool] = mapped_column(Boolean, default=False, nullable=False)
|
||||
|
||||
# Strike thresholds (actions at each threshold)
|
||||
strike_actions: Mapped[dict] = mapped_column(
|
||||
JSONB,
|
||||
default=lambda: {
|
||||
"1": {"action": "warn"},
|
||||
"3": {"action": "timeout", "duration": 3600},
|
||||
"5": {"action": "kick"},
|
||||
"7": {"action": "ban"},
|
||||
},
|
||||
nullable=False,
|
||||
)
|
||||
|
||||
# AI moderation settings
|
||||
ai_moderation_enabled: Mapped[bool] = mapped_column(Boolean, default=False, nullable=False)
|
||||
ai_sensitivity: Mapped[int] = mapped_column(Integer, default=50, nullable=False) # 0-100 scale
|
||||
nsfw_detection_enabled: Mapped[bool] = mapped_column(Boolean, default=False, nullable=False)
|
||||
|
||||
# Verification settings
|
||||
verification_enabled: Mapped[bool] = mapped_column(Boolean, default=False, nullable=False)
|
||||
verification_type: Mapped[str] = mapped_column(
|
||||
String(20), default="button", nullable=False
|
||||
) # button, captcha, questions
|
||||
|
||||
# Relationship
|
||||
guild: Mapped["Guild"] = relationship(back_populates="settings")
|
||||
|
||||
|
||||
class BannedWord(Base, TimestampMixin):
|
||||
"""Banned words/phrases for a guild with regex support."""
|
||||
|
||||
__tablename__ = "banned_words"
|
||||
|
||||
id: Mapped[int] = mapped_column(Integer, primary_key=True, autoincrement=True)
|
||||
guild_id: Mapped[int] = mapped_column(
|
||||
SnowflakeID, ForeignKey("guilds.id", ondelete="CASCADE"), nullable=False
|
||||
)
|
||||
|
||||
pattern: Mapped[str] = mapped_column(Text, nullable=False)
|
||||
is_regex: Mapped[bool] = mapped_column(Boolean, default=False, nullable=False)
|
||||
action: Mapped[str] = mapped_column(
|
||||
String(20), default="delete", nullable=False
|
||||
) # delete, warn, strike
|
||||
reason: Mapped[str | None] = mapped_column(Text, nullable=True)
|
||||
|
||||
# Who added this and when
|
||||
added_by: Mapped[int] = mapped_column(SnowflakeID, nullable=False)
|
||||
|
||||
# Relationship
|
||||
guild: Mapped["Guild"] = relationship(back_populates="banned_words")
|
||||
101
src/guardden/models/moderation.py
Normal file
101
src/guardden/models/moderation.py
Normal file
@@ -0,0 +1,101 @@
|
||||
"""Moderation-related database models."""
|
||||
|
||||
from datetime import datetime
|
||||
from enum import Enum
|
||||
|
||||
from sqlalchemy import Boolean, DateTime, ForeignKey, Integer, String, Text
|
||||
from sqlalchemy.orm import Mapped, mapped_column, relationship
|
||||
|
||||
from guardden.models.base import Base, SnowflakeID, TimestampMixin
|
||||
from guardden.models.guild import Guild
|
||||
|
||||
|
||||
class ModAction(str, Enum):
|
||||
"""Types of moderation actions."""
|
||||
|
||||
WARN = "warn"
|
||||
TIMEOUT = "timeout"
|
||||
KICK = "kick"
|
||||
BAN = "ban"
|
||||
UNBAN = "unban"
|
||||
UNMUTE = "unmute"
|
||||
NOTE = "note"
|
||||
STRIKE = "strike"
|
||||
DELETE = "delete"
|
||||
|
||||
|
||||
class ModerationLog(Base, TimestampMixin):
|
||||
"""Log of all moderation actions taken."""
|
||||
|
||||
__tablename__ = "moderation_logs"
|
||||
|
||||
id: Mapped[int] = mapped_column(Integer, primary_key=True, autoincrement=True)
|
||||
guild_id: Mapped[int] = mapped_column(
|
||||
SnowflakeID, ForeignKey("guilds.id", ondelete="CASCADE"), nullable=False
|
||||
)
|
||||
|
||||
# Target and moderator
|
||||
target_id: Mapped[int] = mapped_column(SnowflakeID, nullable=False)
|
||||
target_name: Mapped[str] = mapped_column(String(100), nullable=False)
|
||||
moderator_id: Mapped[int] = mapped_column(SnowflakeID, nullable=False)
|
||||
moderator_name: Mapped[str] = mapped_column(String(100), nullable=False)
|
||||
|
||||
# Action details
|
||||
action: Mapped[str] = mapped_column(String(20), nullable=False)
|
||||
reason: Mapped[str | None] = mapped_column(Text, nullable=True)
|
||||
duration: Mapped[int | None] = mapped_column(Integer, nullable=True) # Duration in seconds
|
||||
expires_at: Mapped[datetime | None] = mapped_column(DateTime(timezone=True), nullable=True)
|
||||
|
||||
# Context
|
||||
channel_id: Mapped[int | None] = mapped_column(SnowflakeID, nullable=True)
|
||||
message_id: Mapped[int | None] = mapped_column(SnowflakeID, nullable=True)
|
||||
message_content: Mapped[str | None] = mapped_column(Text, nullable=True)
|
||||
|
||||
# Was this an automatic action?
|
||||
is_automatic: Mapped[bool] = mapped_column(Boolean, default=False, nullable=False)
|
||||
|
||||
# Relationship
|
||||
guild: Mapped["Guild"] = relationship(back_populates="moderation_logs")
|
||||
|
||||
|
||||
class Strike(Base, TimestampMixin):
|
||||
"""User strikes/warnings tracking."""
|
||||
|
||||
__tablename__ = "strikes"
|
||||
|
||||
id: Mapped[int] = mapped_column(Integer, primary_key=True, autoincrement=True)
|
||||
guild_id: Mapped[int] = mapped_column(
|
||||
SnowflakeID, ForeignKey("guilds.id", ondelete="CASCADE"), nullable=False
|
||||
)
|
||||
|
||||
user_id: Mapped[int] = mapped_column(SnowflakeID, nullable=False)
|
||||
user_name: Mapped[str] = mapped_column(String(100), nullable=False)
|
||||
moderator_id: Mapped[int] = mapped_column(SnowflakeID, nullable=False)
|
||||
|
||||
reason: Mapped[str] = mapped_column(Text, nullable=False)
|
||||
points: Mapped[int] = mapped_column(Integer, default=1, nullable=False)
|
||||
|
||||
# Strikes can expire
|
||||
expires_at: Mapped[datetime | None] = mapped_column(DateTime(timezone=True), nullable=True)
|
||||
is_active: Mapped[bool] = mapped_column(Boolean, default=True, nullable=False)
|
||||
|
||||
# Reference to the moderation log entry
|
||||
mod_log_id: Mapped[int | None] = mapped_column(
|
||||
Integer, ForeignKey("moderation_logs.id", ondelete="SET NULL"), nullable=True
|
||||
)
|
||||
|
||||
# Relationship
|
||||
guild: Mapped["Guild"] = relationship(back_populates="strikes")
|
||||
|
||||
|
||||
class UserNote(Base, TimestampMixin):
|
||||
"""Moderator notes on users."""
|
||||
|
||||
__tablename__ = "user_notes"
|
||||
|
||||
id: Mapped[int] = mapped_column(Integer, primary_key=True, autoincrement=True)
|
||||
guild_id: Mapped[int] = mapped_column(SnowflakeID, nullable=False)
|
||||
|
||||
user_id: Mapped[int] = mapped_column(SnowflakeID, nullable=False)
|
||||
moderator_id: Mapped[int] = mapped_column(SnowflakeID, nullable=False)
|
||||
content: Mapped[str] = mapped_column(Text, nullable=False)
|
||||
16
src/guardden/services/__init__.py
Normal file
16
src/guardden/services/__init__.py
Normal file
@@ -0,0 +1,16 @@
|
||||
"""Services for GuardDen."""
|
||||
|
||||
from guardden.services.automod import AutomodService
|
||||
from guardden.services.database import Database
|
||||
from guardden.services.ratelimit import RateLimiter, get_rate_limiter, ratelimit
|
||||
from guardden.services.verification import ChallengeType, VerificationService
|
||||
|
||||
__all__ = [
|
||||
"AutomodService",
|
||||
"ChallengeType",
|
||||
"Database",
|
||||
"RateLimiter",
|
||||
"VerificationService",
|
||||
"get_rate_limiter",
|
||||
"ratelimit",
|
||||
]
|
||||
6
src/guardden/services/ai/__init__.py
Normal file
6
src/guardden/services/ai/__init__.py
Normal file
@@ -0,0 +1,6 @@
|
||||
"""AI services for content moderation."""
|
||||
|
||||
from guardden.services.ai.base import AIProvider, ModerationResult
|
||||
from guardden.services.ai.factory import create_ai_provider
|
||||
|
||||
__all__ = ["AIProvider", "ModerationResult", "create_ai_provider"]
|
||||
261
src/guardden/services/ai/anthropic_provider.py
Normal file
261
src/guardden/services/ai/anthropic_provider.py
Normal file
@@ -0,0 +1,261 @@
|
||||
"""Anthropic Claude AI provider implementation."""
|
||||
|
||||
import logging
|
||||
from typing import Any
|
||||
|
||||
from guardden.services.ai.base import (
|
||||
AIProvider,
|
||||
ContentCategory,
|
||||
ImageAnalysisResult,
|
||||
ModerationResult,
|
||||
PhishingAnalysisResult,
|
||||
)
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# Content moderation system prompt
|
||||
MODERATION_SYSTEM_PROMPT = """You are a content moderation AI for a Discord server. Analyze the given message and determine if it violates community guidelines.
|
||||
|
||||
Categories to check:
|
||||
- harassment: Personal attacks, bullying, intimidation
|
||||
- hate_speech: Discrimination, slurs, dehumanization based on identity
|
||||
- sexual: Explicit sexual content, sexual solicitation
|
||||
- violence: Threats, graphic violence, encouraging harm
|
||||
- self_harm: Suicide, self-injury content or encouragement
|
||||
- spam: Repetitive, promotional, or low-quality content
|
||||
- scam: Phishing attempts, fraudulent offers, impersonation
|
||||
- misinformation: Dangerous false information
|
||||
|
||||
Respond in this exact JSON format:
|
||||
{
|
||||
"is_flagged": true/false,
|
||||
"confidence": 0.0-1.0,
|
||||
"categories": ["category1", "category2"],
|
||||
"explanation": "Brief explanation",
|
||||
"suggested_action": "none/warn/delete/timeout/ban"
|
||||
}
|
||||
|
||||
Be balanced - flag genuinely problematic content but allow normal conversation, jokes, and mild language. Consider context."""
|
||||
|
||||
IMAGE_ANALYSIS_PROMPT = """Analyze this image for content moderation purposes. Check for:
|
||||
- NSFW content (nudity, sexual content)
|
||||
- Violence or gore
|
||||
- Disturbing or shocking content
|
||||
- Any content inappropriate for a general audience
|
||||
|
||||
Respond in this exact JSON format:
|
||||
{
|
||||
"is_nsfw": true/false,
|
||||
"is_violent": true/false,
|
||||
"is_disturbing": true/false,
|
||||
"confidence": 0.0-1.0,
|
||||
"description": "Brief description of the image",
|
||||
"categories": ["category1", "category2"]
|
||||
}
|
||||
|
||||
Be accurate but not overly sensitive - artistic nudity or mild violence in appropriate contexts may be acceptable."""
|
||||
|
||||
PHISHING_ANALYSIS_PROMPT = """Analyze this URL and message context for phishing or scam indicators.
|
||||
|
||||
Check for:
|
||||
- Domain impersonation (typosquatting, lookalike domains)
|
||||
- Urgency tactics ("act now", "limited time")
|
||||
- Requests for credentials or personal info
|
||||
- Too-good-to-be-true offers
|
||||
- Suspicious redirects or URL shorteners
|
||||
- Mismatched or hidden URLs
|
||||
|
||||
Respond in this exact JSON format:
|
||||
{
|
||||
"is_phishing": true/false,
|
||||
"confidence": 0.0-1.0,
|
||||
"risk_factors": ["factor1", "factor2"],
|
||||
"explanation": "Brief explanation"
|
||||
}"""
|
||||
|
||||
|
||||
class AnthropicProvider(AIProvider):
|
||||
"""AI provider using Anthropic's Claude API."""
|
||||
|
||||
def __init__(self, api_key: str, model: str = "claude-3-haiku-20240307") -> None:
|
||||
"""
|
||||
Initialize Anthropic provider.
|
||||
|
||||
Args:
|
||||
api_key: Anthropic API key
|
||||
model: Model to use (default: claude-3-haiku for speed/cost)
|
||||
"""
|
||||
try:
|
||||
import anthropic
|
||||
except ImportError:
|
||||
raise ImportError("anthropic package required. Install with: pip install anthropic")
|
||||
|
||||
self.client = anthropic.AsyncAnthropic(api_key=api_key)
|
||||
self.model = model
|
||||
logger.info(f"Initialized Anthropic provider with model: {model}")
|
||||
|
||||
async def _call_api(self, system: str, user_content: Any, max_tokens: int = 500) -> str:
|
||||
"""Make an API call to Claude."""
|
||||
try:
|
||||
message = await self.client.messages.create(
|
||||
model=self.model,
|
||||
max_tokens=max_tokens,
|
||||
system=system,
|
||||
messages=[{"role": "user", "content": user_content}],
|
||||
)
|
||||
return message.content[0].text
|
||||
except Exception as e:
|
||||
logger.error(f"Anthropic API error: {e}")
|
||||
raise
|
||||
|
||||
def _parse_json_response(self, response: str) -> dict:
|
||||
"""Parse JSON from response, handling markdown code blocks."""
|
||||
import json
|
||||
|
||||
# Remove markdown code blocks if present
|
||||
text = response.strip()
|
||||
if text.startswith("```"):
|
||||
lines = text.split("\n")
|
||||
# Remove first and last lines (```json and ```)
|
||||
text = "\n".join(lines[1:-1] if lines[-1] == "```" else lines[1:])
|
||||
|
||||
return json.loads(text)
|
||||
|
||||
async def moderate_text(
|
||||
self,
|
||||
content: str,
|
||||
context: str | None = None,
|
||||
sensitivity: int = 50,
|
||||
) -> ModerationResult:
|
||||
"""Analyze text content for policy violations."""
|
||||
# Adjust prompt based on sensitivity
|
||||
sensitivity_note = ""
|
||||
if sensitivity < 30:
|
||||
sensitivity_note = "\n\nBe lenient - only flag clearly problematic content."
|
||||
elif sensitivity > 70:
|
||||
sensitivity_note = "\n\nBe strict - flag anything potentially problematic."
|
||||
|
||||
system = MODERATION_SYSTEM_PROMPT + sensitivity_note
|
||||
|
||||
user_message = f"Message to analyze:\n{content}"
|
||||
if context:
|
||||
user_message = f"Context: {context}\n\n{user_message}"
|
||||
|
||||
try:
|
||||
response = await self._call_api(system, user_message)
|
||||
data = self._parse_json_response(response)
|
||||
|
||||
categories = [
|
||||
ContentCategory(cat)
|
||||
for cat in data.get("categories", [])
|
||||
if cat in ContentCategory.__members__.values()
|
||||
]
|
||||
|
||||
return ModerationResult(
|
||||
is_flagged=data.get("is_flagged", False),
|
||||
confidence=float(data.get("confidence", 0.0)),
|
||||
categories=categories,
|
||||
explanation=data.get("explanation", ""),
|
||||
suggested_action=data.get("suggested_action", "none"),
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error moderating text: {e}")
|
||||
return ModerationResult(
|
||||
is_flagged=False,
|
||||
explanation=f"Error analyzing content: {str(e)}",
|
||||
)
|
||||
|
||||
async def analyze_image(
|
||||
self,
|
||||
image_url: str,
|
||||
sensitivity: int = 50,
|
||||
) -> ImageAnalysisResult:
|
||||
"""Analyze an image for NSFW or inappropriate content."""
|
||||
import base64
|
||||
|
||||
import aiohttp
|
||||
|
||||
sensitivity_note = ""
|
||||
if sensitivity < 30:
|
||||
sensitivity_note = "\n\nBe lenient - only flag explicit content."
|
||||
elif sensitivity > 70:
|
||||
sensitivity_note = "\n\nBe strict - flag suggestive content as well."
|
||||
|
||||
system = IMAGE_ANALYSIS_PROMPT + sensitivity_note
|
||||
|
||||
try:
|
||||
# Download image and convert to base64
|
||||
async with aiohttp.ClientSession() as session:
|
||||
async with session.get(image_url) as resp:
|
||||
if resp.status != 200:
|
||||
return ImageAnalysisResult(
|
||||
description=f"Failed to download image: HTTP {resp.status}"
|
||||
)
|
||||
|
||||
content_type = resp.content_type or "image/jpeg"
|
||||
image_data = await resp.read()
|
||||
|
||||
# Check file size (max 20MB for Claude)
|
||||
if len(image_data) > 20 * 1024 * 1024:
|
||||
return ImageAnalysisResult(description="Image too large to analyze")
|
||||
|
||||
base64_image = base64.standard_b64encode(image_data).decode("utf-8")
|
||||
|
||||
# Create multimodal message
|
||||
user_content = [
|
||||
{
|
||||
"type": "image",
|
||||
"source": {
|
||||
"type": "base64",
|
||||
"media_type": content_type,
|
||||
"data": base64_image,
|
||||
},
|
||||
},
|
||||
{"type": "text", "text": "Analyze this image for content moderation."},
|
||||
]
|
||||
|
||||
response = await self._call_api(system, user_content)
|
||||
data = self._parse_json_response(response)
|
||||
|
||||
return ImageAnalysisResult(
|
||||
is_nsfw=data.get("is_nsfw", False),
|
||||
is_violent=data.get("is_violent", False),
|
||||
is_disturbing=data.get("is_disturbing", False),
|
||||
confidence=float(data.get("confidence", 0.0)),
|
||||
description=data.get("description", ""),
|
||||
categories=data.get("categories", []),
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error analyzing image: {e}")
|
||||
return ImageAnalysisResult(description=f"Error analyzing image: {str(e)}")
|
||||
|
||||
async def analyze_phishing(
|
||||
self,
|
||||
url: str,
|
||||
message_content: str | None = None,
|
||||
) -> PhishingAnalysisResult:
|
||||
"""Analyze a URL for phishing/scam indicators."""
|
||||
user_message = f"URL to analyze: {url}"
|
||||
if message_content:
|
||||
user_message += f"\n\nFull message context:\n{message_content}"
|
||||
|
||||
try:
|
||||
response = await self._call_api(PHISHING_ANALYSIS_PROMPT, user_message)
|
||||
data = self._parse_json_response(response)
|
||||
|
||||
return PhishingAnalysisResult(
|
||||
is_phishing=data.get("is_phishing", False),
|
||||
confidence=float(data.get("confidence", 0.0)),
|
||||
risk_factors=data.get("risk_factors", []),
|
||||
explanation=data.get("explanation", ""),
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error analyzing phishing: {e}")
|
||||
return PhishingAnalysisResult(explanation=f"Error analyzing URL: {str(e)}")
|
||||
|
||||
async def close(self) -> None:
|
||||
"""Clean up resources."""
|
||||
await self.client.close()
|
||||
149
src/guardden/services/ai/base.py
Normal file
149
src/guardden/services/ai/base.py
Normal file
@@ -0,0 +1,149 @@
|
||||
"""Base classes for AI providers."""
|
||||
|
||||
from abc import ABC, abstractmethod
|
||||
from dataclasses import dataclass, field
|
||||
from enum import Enum
|
||||
from typing import Literal
|
||||
|
||||
|
||||
class ContentCategory(str, Enum):
|
||||
"""Categories of problematic content."""
|
||||
|
||||
SAFE = "safe"
|
||||
HARASSMENT = "harassment"
|
||||
HATE_SPEECH = "hate_speech"
|
||||
SEXUAL = "sexual"
|
||||
VIOLENCE = "violence"
|
||||
SELF_HARM = "self_harm"
|
||||
SPAM = "spam"
|
||||
SCAM = "scam"
|
||||
MISINFORMATION = "misinformation"
|
||||
|
||||
|
||||
@dataclass
|
||||
class ModerationResult:
|
||||
"""Result of AI content moderation."""
|
||||
|
||||
is_flagged: bool = False
|
||||
confidence: float = 0.0 # 0.0 to 1.0
|
||||
categories: list[ContentCategory] = field(default_factory=list)
|
||||
explanation: str = ""
|
||||
suggested_action: Literal["none", "warn", "delete", "timeout", "ban"] = "none"
|
||||
|
||||
@property
|
||||
def severity(self) -> int:
|
||||
"""Get severity score 0-100 based on confidence and categories."""
|
||||
if not self.is_flagged:
|
||||
return 0
|
||||
|
||||
# Base severity from confidence
|
||||
severity = int(self.confidence * 50)
|
||||
|
||||
# Add severity based on category
|
||||
high_severity = {
|
||||
ContentCategory.HATE_SPEECH,
|
||||
ContentCategory.SELF_HARM,
|
||||
ContentCategory.SCAM,
|
||||
}
|
||||
medium_severity = {
|
||||
ContentCategory.HARASSMENT,
|
||||
ContentCategory.VIOLENCE,
|
||||
ContentCategory.SEXUAL,
|
||||
}
|
||||
|
||||
for cat in self.categories:
|
||||
if cat in high_severity:
|
||||
severity += 30
|
||||
elif cat in medium_severity:
|
||||
severity += 20
|
||||
else:
|
||||
severity += 10
|
||||
|
||||
return min(severity, 100)
|
||||
|
||||
|
||||
@dataclass
|
||||
class ImageAnalysisResult:
|
||||
"""Result of AI image analysis."""
|
||||
|
||||
is_nsfw: bool = False
|
||||
is_violent: bool = False
|
||||
is_disturbing: bool = False
|
||||
confidence: float = 0.0
|
||||
description: str = ""
|
||||
categories: list[str] = field(default_factory=list)
|
||||
|
||||
|
||||
@dataclass
|
||||
class PhishingAnalysisResult:
|
||||
"""Result of AI phishing/scam analysis."""
|
||||
|
||||
is_phishing: bool = False
|
||||
confidence: float = 0.0
|
||||
risk_factors: list[str] = field(default_factory=list)
|
||||
explanation: str = ""
|
||||
|
||||
|
||||
class AIProvider(ABC):
|
||||
"""Abstract base class for AI providers."""
|
||||
|
||||
@abstractmethod
|
||||
async def moderate_text(
|
||||
self,
|
||||
content: str,
|
||||
context: str | None = None,
|
||||
sensitivity: int = 50,
|
||||
) -> ModerationResult:
|
||||
"""
|
||||
Analyze text content for policy violations.
|
||||
|
||||
Args:
|
||||
content: The text to analyze
|
||||
context: Optional context about the conversation/server
|
||||
sensitivity: 0-100, higher means more strict
|
||||
|
||||
Returns:
|
||||
ModerationResult with analysis
|
||||
"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
async def analyze_image(
|
||||
self,
|
||||
image_url: str,
|
||||
sensitivity: int = 50,
|
||||
) -> ImageAnalysisResult:
|
||||
"""
|
||||
Analyze an image for NSFW or inappropriate content.
|
||||
|
||||
Args:
|
||||
image_url: URL of the image to analyze
|
||||
sensitivity: 0-100, higher means more strict
|
||||
|
||||
Returns:
|
||||
ImageAnalysisResult with analysis
|
||||
"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
async def analyze_phishing(
|
||||
self,
|
||||
url: str,
|
||||
message_content: str | None = None,
|
||||
) -> PhishingAnalysisResult:
|
||||
"""
|
||||
Analyze a URL for phishing/scam indicators.
|
||||
|
||||
Args:
|
||||
url: The URL to analyze
|
||||
message_content: Optional full message for context
|
||||
|
||||
Returns:
|
||||
PhishingAnalysisResult with analysis
|
||||
"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
async def close(self) -> None:
|
||||
"""Clean up resources."""
|
||||
pass
|
||||
67
src/guardden/services/ai/factory.py
Normal file
67
src/guardden/services/ai/factory.py
Normal file
@@ -0,0 +1,67 @@
|
||||
"""Factory for creating AI providers."""
|
||||
|
||||
import logging
|
||||
from typing import Literal
|
||||
|
||||
from guardden.services.ai.base import AIProvider
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class NullProvider(AIProvider):
|
||||
"""Null provider that does nothing (for when AI is disabled)."""
|
||||
|
||||
async def moderate_text(self, content, context=None, sensitivity=50):
|
||||
from guardden.services.ai.base import ModerationResult
|
||||
|
||||
return ModerationResult()
|
||||
|
||||
async def analyze_image(self, image_url, sensitivity=50):
|
||||
from guardden.services.ai.base import ImageAnalysisResult
|
||||
|
||||
return ImageAnalysisResult()
|
||||
|
||||
async def analyze_phishing(self, url, message_content=None):
|
||||
from guardden.services.ai.base import PhishingAnalysisResult
|
||||
|
||||
return PhishingAnalysisResult()
|
||||
|
||||
async def close(self):
|
||||
pass
|
||||
|
||||
|
||||
def create_ai_provider(
|
||||
provider: Literal["anthropic", "openai", "none"],
|
||||
api_key: str | None = None,
|
||||
) -> AIProvider:
|
||||
"""
|
||||
Create an AI provider instance.
|
||||
|
||||
Args:
|
||||
provider: The provider type to create
|
||||
api_key: API key for the provider
|
||||
|
||||
Returns:
|
||||
AIProvider instance
|
||||
|
||||
Raises:
|
||||
ValueError: If provider is unknown or API key is missing
|
||||
"""
|
||||
if provider == "none":
|
||||
logger.info("AI moderation disabled")
|
||||
return NullProvider()
|
||||
|
||||
if not api_key:
|
||||
raise ValueError(f"API key required for {provider} provider")
|
||||
|
||||
if provider == "anthropic":
|
||||
from guardden.services.ai.anthropic_provider import AnthropicProvider
|
||||
|
||||
return AnthropicProvider(api_key)
|
||||
|
||||
if provider == "openai":
|
||||
from guardden.services.ai.openai_provider import OpenAIProvider
|
||||
|
||||
return OpenAIProvider(api_key)
|
||||
|
||||
raise ValueError(f"Unknown AI provider: {provider}")
|
||||
213
src/guardden/services/ai/openai_provider.py
Normal file
213
src/guardden/services/ai/openai_provider.py
Normal file
@@ -0,0 +1,213 @@
|
||||
"""OpenAI AI provider implementation."""
|
||||
|
||||
import logging
|
||||
from typing import Any
|
||||
|
||||
from guardden.services.ai.base import (
|
||||
AIProvider,
|
||||
ContentCategory,
|
||||
ImageAnalysisResult,
|
||||
ModerationResult,
|
||||
PhishingAnalysisResult,
|
||||
)
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class OpenAIProvider(AIProvider):
|
||||
"""AI provider using OpenAI's API."""
|
||||
|
||||
def __init__(self, api_key: str, model: str = "gpt-4o-mini") -> None:
|
||||
"""
|
||||
Initialize OpenAI provider.
|
||||
|
||||
Args:
|
||||
api_key: OpenAI API key
|
||||
model: Model to use (default: gpt-4o-mini for speed/cost)
|
||||
"""
|
||||
try:
|
||||
import openai
|
||||
except ImportError:
|
||||
raise ImportError("openai package required. Install with: pip install openai")
|
||||
|
||||
self.client = openai.AsyncOpenAI(api_key=api_key)
|
||||
self.model = model
|
||||
logger.info(f"Initialized OpenAI provider with model: {model}")
|
||||
|
||||
async def _call_api(
|
||||
self,
|
||||
system: str,
|
||||
user_content: Any,
|
||||
max_tokens: int = 500,
|
||||
) -> str:
|
||||
"""Make an API call to OpenAI."""
|
||||
try:
|
||||
response = await self.client.chat.completions.create(
|
||||
model=self.model,
|
||||
max_tokens=max_tokens,
|
||||
messages=[
|
||||
{"role": "system", "content": system},
|
||||
{"role": "user", "content": user_content},
|
||||
],
|
||||
response_format={"type": "json_object"},
|
||||
)
|
||||
return response.choices[0].message.content or ""
|
||||
except Exception as e:
|
||||
logger.error(f"OpenAI API error: {e}")
|
||||
raise
|
||||
|
||||
def _parse_json_response(self, response: str) -> dict:
|
||||
"""Parse JSON from response."""
|
||||
import json
|
||||
|
||||
return json.loads(response)
|
||||
|
||||
async def moderate_text(
|
||||
self,
|
||||
content: str,
|
||||
context: str | None = None,
|
||||
sensitivity: int = 50,
|
||||
) -> ModerationResult:
|
||||
"""Analyze text content for policy violations."""
|
||||
# First, use OpenAI's built-in moderation API for quick check
|
||||
try:
|
||||
mod_response = await self.client.moderations.create(input=content)
|
||||
results = mod_response.results[0]
|
||||
|
||||
# Map OpenAI categories to our categories
|
||||
category_mapping = {
|
||||
"harassment": ContentCategory.HARASSMENT,
|
||||
"harassment/threatening": ContentCategory.HARASSMENT,
|
||||
"hate": ContentCategory.HATE_SPEECH,
|
||||
"hate/threatening": ContentCategory.HATE_SPEECH,
|
||||
"self-harm": ContentCategory.SELF_HARM,
|
||||
"self-harm/intent": ContentCategory.SELF_HARM,
|
||||
"self-harm/instructions": ContentCategory.SELF_HARM,
|
||||
"sexual": ContentCategory.SEXUAL,
|
||||
"sexual/minors": ContentCategory.SEXUAL,
|
||||
"violence": ContentCategory.VIOLENCE,
|
||||
"violence/graphic": ContentCategory.VIOLENCE,
|
||||
}
|
||||
|
||||
flagged_categories = []
|
||||
max_score = 0.0
|
||||
|
||||
for category, score in results.category_scores.model_dump().items():
|
||||
if score > 0.5: # Threshold
|
||||
if category in category_mapping:
|
||||
flagged_categories.append(category_mapping[category])
|
||||
max_score = max(max_score, score)
|
||||
|
||||
# Adjust threshold based on sensitivity
|
||||
threshold = 0.3 + (0.4 * (100 - sensitivity) / 100) # 0.3 to 0.7
|
||||
|
||||
if results.flagged or max_score > threshold:
|
||||
return ModerationResult(
|
||||
is_flagged=True,
|
||||
confidence=max_score,
|
||||
categories=list(set(flagged_categories)),
|
||||
explanation="Content flagged by moderation API",
|
||||
suggested_action="delete" if max_score > 0.8 else "warn",
|
||||
)
|
||||
|
||||
return ModerationResult(is_flagged=False, confidence=1.0 - max_score)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error moderating text: {e}")
|
||||
return ModerationResult(
|
||||
is_flagged=False,
|
||||
explanation=f"Error analyzing content: {str(e)}",
|
||||
)
|
||||
|
||||
async def analyze_image(
|
||||
self,
|
||||
image_url: str,
|
||||
sensitivity: int = 50,
|
||||
) -> ImageAnalysisResult:
|
||||
"""Analyze an image for NSFW or inappropriate content."""
|
||||
system = """Analyze this image for content moderation. Respond in JSON format:
|
||||
{
|
||||
"is_nsfw": true/false,
|
||||
"is_violent": true/false,
|
||||
"is_disturbing": true/false,
|
||||
"confidence": 0.0-1.0,
|
||||
"description": "Brief description",
|
||||
"categories": ["category1"]
|
||||
}"""
|
||||
|
||||
sensitivity_note = ""
|
||||
if sensitivity < 30:
|
||||
sensitivity_note = " Be lenient - only flag explicit content."
|
||||
elif sensitivity > 70:
|
||||
sensitivity_note = " Be strict - flag suggestive content."
|
||||
|
||||
try:
|
||||
response = await self.client.chat.completions.create(
|
||||
model="gpt-4o-mini", # Use vision-capable model
|
||||
max_tokens=500,
|
||||
messages=[
|
||||
{"role": "system", "content": system + sensitivity_note},
|
||||
{
|
||||
"role": "user",
|
||||
"content": [
|
||||
{"type": "text", "text": "Analyze this image for moderation."},
|
||||
{"type": "image_url", "image_url": {"url": image_url}},
|
||||
],
|
||||
},
|
||||
],
|
||||
response_format={"type": "json_object"},
|
||||
)
|
||||
|
||||
data = self._parse_json_response(response.choices[0].message.content or "{}")
|
||||
|
||||
return ImageAnalysisResult(
|
||||
is_nsfw=data.get("is_nsfw", False),
|
||||
is_violent=data.get("is_violent", False),
|
||||
is_disturbing=data.get("is_disturbing", False),
|
||||
confidence=float(data.get("confidence", 0.0)),
|
||||
description=data.get("description", ""),
|
||||
categories=data.get("categories", []),
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error analyzing image: {e}")
|
||||
return ImageAnalysisResult(description=f"Error analyzing image: {str(e)}")
|
||||
|
||||
async def analyze_phishing(
|
||||
self,
|
||||
url: str,
|
||||
message_content: str | None = None,
|
||||
) -> PhishingAnalysisResult:
|
||||
"""Analyze a URL for phishing/scam indicators."""
|
||||
system = """Analyze the URL for phishing/scam indicators. Respond in JSON:
|
||||
{
|
||||
"is_phishing": true/false,
|
||||
"confidence": 0.0-1.0,
|
||||
"risk_factors": ["factor1"],
|
||||
"explanation": "Brief explanation"
|
||||
}
|
||||
|
||||
Check for: domain impersonation, urgency tactics, credential requests, too-good-to-be-true offers."""
|
||||
|
||||
user_message = f"URL: {url}"
|
||||
if message_content:
|
||||
user_message += f"\n\nMessage context: {message_content}"
|
||||
|
||||
try:
|
||||
response = await self._call_api(system, user_message)
|
||||
data = self._parse_json_response(response)
|
||||
|
||||
return PhishingAnalysisResult(
|
||||
is_phishing=data.get("is_phishing", False),
|
||||
confidence=float(data.get("confidence", 0.0)),
|
||||
risk_factors=data.get("risk_factors", []),
|
||||
explanation=data.get("explanation", ""),
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error analyzing phishing: {e}")
|
||||
return PhishingAnalysisResult(explanation=f"Error analyzing URL: {str(e)}")
|
||||
|
||||
async def close(self) -> None:
|
||||
"""Clean up resources."""
|
||||
await self.client.close()
|
||||
301
src/guardden/services/automod.py
Normal file
301
src/guardden/services/automod.py
Normal file
@@ -0,0 +1,301 @@
|
||||
"""Automod service for content filtering and spam detection."""
|
||||
|
||||
import logging
|
||||
import re
|
||||
from collections import defaultdict
|
||||
from dataclasses import dataclass, field
|
||||
from datetime import datetime, timedelta, timezone
|
||||
from typing import NamedTuple
|
||||
|
||||
import discord
|
||||
|
||||
from guardden.models import BannedWord
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
# Known scam/phishing patterns
|
||||
SCAM_PATTERNS = [
|
||||
# Discord scam patterns
|
||||
r"discord(?:[-.]?(?:gift|nitro|free|claim|steam))[\w.-]*\.(?!com|gg)[a-z]{2,}",
|
||||
r"(?:free|claim|get)[-.\s]?(?:discord[-.\s]?)?nitro",
|
||||
r"(?:steam|discord)[-.\s]?community[-.\s]?(?:giveaway|gift)",
|
||||
# Generic phishing
|
||||
r"(?:verify|confirm)[-.\s]?(?:your)?[-.\s]?account",
|
||||
r"(?:suspended|locked|limited)[-.\s]?account",
|
||||
r"click[-.\s]?(?:here|this)[-.\s]?(?:to[-.\s]?)?(?:verify|claim|get)",
|
||||
# Crypto scams
|
||||
r"(?:free|claim|airdrop)[-.\s]?(?:crypto|bitcoin|eth|nft)",
|
||||
r"(?:double|2x)[-.\s]?your[-.\s]?(?:crypto|bitcoin|eth)",
|
||||
]
|
||||
|
||||
# Suspicious TLDs often used in phishing
|
||||
SUSPICIOUS_TLDS = {
|
||||
".xyz",
|
||||
".top",
|
||||
".club",
|
||||
".work",
|
||||
".click",
|
||||
".link",
|
||||
".info",
|
||||
".ru",
|
||||
".cn",
|
||||
".tk",
|
||||
".ml",
|
||||
".ga",
|
||||
".cf",
|
||||
".gq",
|
||||
}
|
||||
|
||||
# URL pattern for extraction
|
||||
URL_PATTERN = re.compile(
|
||||
r"https?://(?:[-\w.]|(?:%[\da-fA-F]{2}))+[^\s]*|"
|
||||
r"(?:www\.)?[-\w]+\.(?:com|org|net|io|gg|co|me|tv|xyz|top|club|work|click|link|info|ru|cn)[^\s]*",
|
||||
re.IGNORECASE,
|
||||
)
|
||||
|
||||
|
||||
class SpamRecord(NamedTuple):
|
||||
"""Record of a message for spam tracking."""
|
||||
|
||||
content_hash: str
|
||||
timestamp: datetime
|
||||
|
||||
|
||||
@dataclass
|
||||
class UserSpamTracker:
|
||||
"""Tracks spam behavior for a single user."""
|
||||
|
||||
messages: list[SpamRecord] = field(default_factory=list)
|
||||
mention_count: int = 0
|
||||
last_mention_time: datetime | None = None
|
||||
duplicate_count: int = 0
|
||||
last_action_time: datetime | None = None
|
||||
|
||||
def cleanup(self, max_age: timedelta = timedelta(minutes=1)) -> None:
|
||||
"""Remove old messages from tracking."""
|
||||
cutoff = datetime.now(timezone.utc) - max_age
|
||||
self.messages = [m for m in self.messages if m.timestamp > cutoff]
|
||||
|
||||
|
||||
@dataclass
|
||||
class AutomodResult:
|
||||
"""Result of automod check."""
|
||||
|
||||
should_delete: bool = False
|
||||
should_warn: bool = False
|
||||
should_strike: bool = False
|
||||
should_timeout: bool = False
|
||||
timeout_duration: int = 0 # seconds
|
||||
reason: str = ""
|
||||
matched_filter: str = ""
|
||||
|
||||
|
||||
class AutomodService:
|
||||
"""Service for automatic content moderation."""
|
||||
|
||||
def __init__(self) -> None:
|
||||
# Compile scam patterns
|
||||
self._scam_patterns = [re.compile(p, re.IGNORECASE) for p in SCAM_PATTERNS]
|
||||
|
||||
# Per-guild, per-user spam tracking
|
||||
# Structure: {guild_id: {user_id: UserSpamTracker}}
|
||||
self._spam_trackers: dict[int, dict[int, UserSpamTracker]] = defaultdict(
|
||||
lambda: defaultdict(UserSpamTracker)
|
||||
)
|
||||
|
||||
# Spam thresholds
|
||||
self.message_rate_limit = 5 # messages per window
|
||||
self.message_rate_window = 5 # seconds
|
||||
self.duplicate_threshold = 3 # same message count
|
||||
self.mention_limit = 5 # mentions per message
|
||||
self.mention_rate_limit = 10 # mentions per window
|
||||
self.mention_rate_window = 60 # seconds
|
||||
|
||||
def _get_content_hash(self, content: str) -> str:
|
||||
"""Get a normalized hash of message content for duplicate detection."""
|
||||
# Normalize: lowercase, remove extra spaces, remove special chars
|
||||
normalized = re.sub(r"[^\w\s]", "", content.lower())
|
||||
normalized = re.sub(r"\s+", " ", normalized).strip()
|
||||
return normalized
|
||||
|
||||
def check_banned_words(
|
||||
self, content: str, banned_words: list[BannedWord]
|
||||
) -> AutomodResult | None:
|
||||
"""Check message against banned words list."""
|
||||
content_lower = content.lower()
|
||||
|
||||
for banned in banned_words:
|
||||
matched = False
|
||||
|
||||
if banned.is_regex:
|
||||
try:
|
||||
if re.search(banned.pattern, content, re.IGNORECASE):
|
||||
matched = True
|
||||
except re.error:
|
||||
logger.warning(f"Invalid regex pattern: {banned.pattern}")
|
||||
continue
|
||||
else:
|
||||
if banned.pattern.lower() in content_lower:
|
||||
matched = True
|
||||
|
||||
if matched:
|
||||
result = AutomodResult(
|
||||
should_delete=True,
|
||||
reason=banned.reason or f"Matched banned word filter",
|
||||
matched_filter=f"banned_word:{banned.id}",
|
||||
)
|
||||
|
||||
if banned.action == "warn":
|
||||
result.should_warn = True
|
||||
elif banned.action == "strike":
|
||||
result.should_strike = True
|
||||
|
||||
return result
|
||||
|
||||
return None
|
||||
|
||||
def check_scam_links(self, content: str) -> AutomodResult | None:
|
||||
"""Check message for scam/phishing patterns."""
|
||||
# Check for known scam patterns
|
||||
for pattern in self._scam_patterns:
|
||||
if pattern.search(content):
|
||||
return AutomodResult(
|
||||
should_delete=True,
|
||||
should_warn=True,
|
||||
reason="Message matched known scam/phishing pattern",
|
||||
matched_filter="scam_pattern",
|
||||
)
|
||||
|
||||
# Check URLs for suspicious TLDs
|
||||
urls = URL_PATTERN.findall(content)
|
||||
for url in urls:
|
||||
url_lower = url.lower()
|
||||
for tld in SUSPICIOUS_TLDS:
|
||||
if tld in url_lower:
|
||||
# Additional check: is it trying to impersonate a known domain?
|
||||
impersonation_keywords = [
|
||||
"discord",
|
||||
"steam",
|
||||
"nitro",
|
||||
"gift",
|
||||
"free",
|
||||
"login",
|
||||
"verify",
|
||||
]
|
||||
if any(kw in url_lower for kw in impersonation_keywords):
|
||||
return AutomodResult(
|
||||
should_delete=True,
|
||||
should_warn=True,
|
||||
reason=f"Suspicious link detected: {url[:50]}",
|
||||
matched_filter="suspicious_link",
|
||||
)
|
||||
|
||||
return None
|
||||
|
||||
def check_spam(
|
||||
self, message: discord.Message, anti_spam_enabled: bool = True
|
||||
) -> AutomodResult | None:
|
||||
"""Check message for spam behavior."""
|
||||
if not anti_spam_enabled:
|
||||
return None
|
||||
|
||||
guild_id = message.guild.id
|
||||
user_id = message.author.id
|
||||
tracker = self._spam_trackers[guild_id][user_id]
|
||||
now = datetime.now(timezone.utc)
|
||||
|
||||
# Cleanup old records
|
||||
tracker.cleanup()
|
||||
|
||||
# Check message rate
|
||||
content_hash = self._get_content_hash(message.content)
|
||||
tracker.messages.append(SpamRecord(content_hash, now))
|
||||
|
||||
# Rate limit check
|
||||
recent_window = now - timedelta(seconds=self.message_rate_window)
|
||||
recent_messages = [m for m in tracker.messages if m.timestamp > recent_window]
|
||||
|
||||
if len(recent_messages) > self.message_rate_limit:
|
||||
return AutomodResult(
|
||||
should_delete=True,
|
||||
should_timeout=True,
|
||||
timeout_duration=60, # 1 minute timeout
|
||||
reason=f"Sending messages too fast ({len(recent_messages)} in {self.message_rate_window}s)",
|
||||
matched_filter="rate_limit",
|
||||
)
|
||||
|
||||
# Duplicate message check
|
||||
duplicate_count = sum(1 for m in tracker.messages if m.content_hash == content_hash)
|
||||
if duplicate_count >= self.duplicate_threshold:
|
||||
return AutomodResult(
|
||||
should_delete=True,
|
||||
should_warn=True,
|
||||
reason=f"Duplicate message detected ({duplicate_count} times)",
|
||||
matched_filter="duplicate",
|
||||
)
|
||||
|
||||
# Mass mention check
|
||||
mention_count = len(message.mentions) + len(message.role_mentions)
|
||||
if message.mention_everyone:
|
||||
mention_count += 100 # Treat @everyone as many mentions
|
||||
|
||||
if mention_count > self.mention_limit:
|
||||
return AutomodResult(
|
||||
should_delete=True,
|
||||
should_timeout=True,
|
||||
timeout_duration=300, # 5 minute timeout
|
||||
reason=f"Mass mentions detected ({mention_count} mentions)",
|
||||
matched_filter="mass_mention",
|
||||
)
|
||||
|
||||
return None
|
||||
|
||||
def check_invite_links(self, content: str, allow_invites: bool = True) -> AutomodResult | None:
|
||||
"""Check for Discord invite links."""
|
||||
if allow_invites:
|
||||
return None
|
||||
|
||||
invite_pattern = re.compile(
|
||||
r"(?:https?://)?(?:www\.)?(?:discord\.(?:gg|io|me|li)|discordapp\.com/invite)/[\w-]+",
|
||||
re.IGNORECASE,
|
||||
)
|
||||
|
||||
if invite_pattern.search(content):
|
||||
return AutomodResult(
|
||||
should_delete=True,
|
||||
reason="Discord invite links are not allowed",
|
||||
matched_filter="invite_link",
|
||||
)
|
||||
|
||||
return None
|
||||
|
||||
def check_all_caps(
|
||||
self, content: str, threshold: float = 0.7, min_length: int = 10
|
||||
) -> AutomodResult | None:
|
||||
"""Check for excessive caps usage."""
|
||||
# Only check messages with enough letters
|
||||
letters = [c for c in content if c.isalpha()]
|
||||
if len(letters) < min_length:
|
||||
return None
|
||||
|
||||
caps_count = sum(1 for c in letters if c.isupper())
|
||||
caps_ratio = caps_count / len(letters)
|
||||
|
||||
if caps_ratio > threshold:
|
||||
return AutomodResult(
|
||||
should_delete=True,
|
||||
reason="Excessive caps usage",
|
||||
matched_filter="caps",
|
||||
)
|
||||
|
||||
return None
|
||||
|
||||
def reset_user_tracker(self, guild_id: int, user_id: int) -> None:
|
||||
"""Reset spam tracking for a user."""
|
||||
if guild_id in self._spam_trackers:
|
||||
self._spam_trackers[guild_id].pop(user_id, None)
|
||||
|
||||
def cleanup_guild(self, guild_id: int) -> None:
|
||||
"""Remove all tracking data for a guild."""
|
||||
self._spam_trackers.pop(guild_id, None)
|
||||
99
src/guardden/services/database.py
Normal file
99
src/guardden/services/database.py
Normal file
@@ -0,0 +1,99 @@
|
||||
"""Database connection and session management."""
|
||||
|
||||
import logging
|
||||
from contextlib import asynccontextmanager
|
||||
from typing import AsyncGenerator
|
||||
|
||||
import asyncpg
|
||||
from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker, create_async_engine
|
||||
|
||||
from guardden.config import Settings
|
||||
from guardden.models import Base
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class Database:
|
||||
"""Manages database connections and sessions."""
|
||||
|
||||
def __init__(self, settings: Settings) -> None:
|
||||
self.settings = settings
|
||||
self._engine = None
|
||||
self._session_factory = None
|
||||
self._pool: asyncpg.Pool | None = None
|
||||
|
||||
async def connect(self) -> None:
|
||||
"""Initialize database connection pool."""
|
||||
db_url = self.settings.database_url.get_secret_value()
|
||||
|
||||
# Create SQLAlchemy async engine
|
||||
# Convert postgresql:// to postgresql+asyncpg://
|
||||
if db_url.startswith("postgresql://"):
|
||||
sqlalchemy_url = db_url.replace("postgresql://", "postgresql+asyncpg://", 1)
|
||||
else:
|
||||
sqlalchemy_url = db_url
|
||||
|
||||
self._engine = create_async_engine(
|
||||
sqlalchemy_url,
|
||||
pool_size=self.settings.database_pool_min,
|
||||
max_overflow=self.settings.database_pool_max - self.settings.database_pool_min,
|
||||
echo=self.settings.log_level == "DEBUG",
|
||||
)
|
||||
|
||||
self._session_factory = async_sessionmaker(
|
||||
self._engine,
|
||||
class_=AsyncSession,
|
||||
expire_on_commit=False,
|
||||
)
|
||||
|
||||
# Also create a raw asyncpg pool for performance-critical operations
|
||||
self._pool = await asyncpg.create_pool(
|
||||
db_url,
|
||||
min_size=self.settings.database_pool_min,
|
||||
max_size=self.settings.database_pool_max,
|
||||
)
|
||||
|
||||
logger.info("Database connection established")
|
||||
|
||||
async def disconnect(self) -> None:
|
||||
"""Close all database connections."""
|
||||
if self._pool:
|
||||
await self._pool.close()
|
||||
self._pool = None
|
||||
|
||||
if self._engine:
|
||||
await self._engine.dispose()
|
||||
self._engine = None
|
||||
|
||||
logger.info("Database connections closed")
|
||||
|
||||
async def create_tables(self) -> None:
|
||||
"""Create all database tables."""
|
||||
if not self._engine:
|
||||
raise RuntimeError("Database not connected")
|
||||
|
||||
async with self._engine.begin() as conn:
|
||||
await conn.run_sync(Base.metadata.create_all)
|
||||
|
||||
logger.info("Database tables created")
|
||||
|
||||
@asynccontextmanager
|
||||
async def session(self) -> AsyncGenerator[AsyncSession, None]:
|
||||
"""Get a database session context manager."""
|
||||
if not self._session_factory:
|
||||
raise RuntimeError("Database not connected")
|
||||
|
||||
async with self._session_factory() as session:
|
||||
try:
|
||||
yield session
|
||||
await session.commit()
|
||||
except Exception:
|
||||
await session.rollback()
|
||||
raise
|
||||
|
||||
@property
|
||||
def pool(self) -> asyncpg.Pool:
|
||||
"""Get the raw asyncpg connection pool."""
|
||||
if not self._pool:
|
||||
raise RuntimeError("Database not connected")
|
||||
return self._pool
|
||||
145
src/guardden/services/guild_config.py
Normal file
145
src/guardden/services/guild_config.py
Normal file
@@ -0,0 +1,145 @@
|
||||
"""Guild configuration service."""
|
||||
|
||||
import logging
|
||||
from functools import lru_cache
|
||||
|
||||
import discord
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.orm import selectinload
|
||||
|
||||
from guardden.models import BannedWord, Guild, GuildSettings
|
||||
from guardden.services.database import Database
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class GuildConfigService:
|
||||
"""Manages guild configurations with caching."""
|
||||
|
||||
def __init__(self, database: Database) -> None:
|
||||
self.database = database
|
||||
self._cache: dict[int, GuildSettings] = {}
|
||||
|
||||
async def get_config(self, guild_id: int) -> GuildSettings | None:
|
||||
"""Get guild configuration, using cache if available."""
|
||||
if guild_id in self._cache:
|
||||
return self._cache[guild_id]
|
||||
|
||||
async with self.database.session() as session:
|
||||
result = await session.execute(
|
||||
select(GuildSettings).where(GuildSettings.guild_id == guild_id)
|
||||
)
|
||||
settings = result.scalar_one_or_none()
|
||||
|
||||
if settings:
|
||||
self._cache[guild_id] = settings
|
||||
|
||||
return settings
|
||||
|
||||
async def get_guild(self, guild_id: int) -> Guild | None:
|
||||
"""Get full guild data including settings and banned words."""
|
||||
async with self.database.session() as session:
|
||||
result = await session.execute(
|
||||
select(Guild)
|
||||
.options(selectinload(Guild.settings), selectinload(Guild.banned_words))
|
||||
.where(Guild.id == guild_id)
|
||||
)
|
||||
return result.scalar_one_or_none()
|
||||
|
||||
async def create_guild(self, guild: discord.Guild) -> Guild:
|
||||
"""Create a new guild entry with default settings."""
|
||||
async with self.database.session() as session:
|
||||
# Check if guild already exists
|
||||
existing = await session.get(Guild, guild.id)
|
||||
if existing:
|
||||
return existing
|
||||
|
||||
# Create new guild
|
||||
db_guild = Guild(
|
||||
id=guild.id,
|
||||
name=guild.name,
|
||||
owner_id=guild.owner_id,
|
||||
)
|
||||
session.add(db_guild)
|
||||
await session.flush()
|
||||
|
||||
# Create default settings
|
||||
settings = GuildSettings(guild_id=guild.id)
|
||||
session.add(settings)
|
||||
|
||||
await session.commit()
|
||||
|
||||
logger.info(f"Created guild config for {guild.name} (ID: {guild.id})")
|
||||
return db_guild
|
||||
|
||||
async def update_settings(self, guild_id: int, **kwargs) -> GuildSettings | None:
|
||||
"""Update guild settings."""
|
||||
async with self.database.session() as session:
|
||||
result = await session.execute(
|
||||
select(GuildSettings).where(GuildSettings.guild_id == guild_id)
|
||||
)
|
||||
settings = result.scalar_one_or_none()
|
||||
|
||||
if not settings:
|
||||
return None
|
||||
|
||||
for key, value in kwargs.items():
|
||||
if hasattr(settings, key):
|
||||
setattr(settings, key, value)
|
||||
|
||||
await session.commit()
|
||||
|
||||
# Invalidate cache
|
||||
self._cache.pop(guild_id, None)
|
||||
|
||||
return settings
|
||||
|
||||
def invalidate_cache(self, guild_id: int) -> None:
|
||||
"""Remove a guild from the cache."""
|
||||
self._cache.pop(guild_id, None)
|
||||
|
||||
async def get_banned_words(self, guild_id: int) -> list[BannedWord]:
|
||||
"""Get all banned words for a guild."""
|
||||
async with self.database.session() as session:
|
||||
result = await session.execute(
|
||||
select(BannedWord).where(BannedWord.guild_id == guild_id)
|
||||
)
|
||||
return list(result.scalars().all())
|
||||
|
||||
async def add_banned_word(
|
||||
self,
|
||||
guild_id: int,
|
||||
pattern: str,
|
||||
added_by: int,
|
||||
is_regex: bool = False,
|
||||
action: str = "delete",
|
||||
reason: str | None = None,
|
||||
) -> BannedWord:
|
||||
"""Add a banned word to a guild."""
|
||||
async with self.database.session() as session:
|
||||
banned_word = BannedWord(
|
||||
guild_id=guild_id,
|
||||
pattern=pattern,
|
||||
is_regex=is_regex,
|
||||
action=action,
|
||||
reason=reason,
|
||||
added_by=added_by,
|
||||
)
|
||||
session.add(banned_word)
|
||||
await session.commit()
|
||||
return banned_word
|
||||
|
||||
async def remove_banned_word(self, guild_id: int, word_id: int) -> bool:
|
||||
"""Remove a banned word from a guild."""
|
||||
async with self.database.session() as session:
|
||||
result = await session.execute(
|
||||
select(BannedWord).where(BannedWord.id == word_id, BannedWord.guild_id == guild_id)
|
||||
)
|
||||
word = result.scalar_one_or_none()
|
||||
|
||||
if word:
|
||||
session.delete(word)
|
||||
await session.commit()
|
||||
return True
|
||||
|
||||
return False
|
||||
300
src/guardden/services/ratelimit.py
Normal file
300
src/guardden/services/ratelimit.py
Normal file
@@ -0,0 +1,300 @@
|
||||
"""Rate limiting service for command and action throttling."""
|
||||
|
||||
import logging
|
||||
from collections import defaultdict
|
||||
from dataclasses import dataclass, field
|
||||
from datetime import datetime, timedelta, timezone
|
||||
from enum import Enum
|
||||
from typing import Callable
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class RateLimitScope(str, Enum):
|
||||
"""Scope of rate limiting."""
|
||||
|
||||
USER = "user" # Per user globally
|
||||
MEMBER = "member" # Per user per guild
|
||||
CHANNEL = "channel" # Per channel
|
||||
GUILD = "guild" # Per guild
|
||||
|
||||
|
||||
@dataclass
|
||||
class RateLimitBucket:
|
||||
"""Tracks rate limit state for a single bucket."""
|
||||
|
||||
max_requests: int
|
||||
window_seconds: float
|
||||
requests: list[datetime] = field(default_factory=list)
|
||||
|
||||
def cleanup(self) -> None:
|
||||
"""Remove expired requests from tracking."""
|
||||
cutoff = datetime.now(timezone.utc) - timedelta(seconds=self.window_seconds)
|
||||
self.requests = [r for r in self.requests if r > cutoff]
|
||||
|
||||
def is_limited(self) -> bool:
|
||||
"""Check if this bucket is rate limited."""
|
||||
self.cleanup()
|
||||
return len(self.requests) >= self.max_requests
|
||||
|
||||
def record(self) -> None:
|
||||
"""Record a request."""
|
||||
self.requests.append(datetime.now(timezone.utc))
|
||||
|
||||
def remaining(self) -> int:
|
||||
"""Get remaining requests in current window."""
|
||||
self.cleanup()
|
||||
return max(0, self.max_requests - len(self.requests))
|
||||
|
||||
def reset_after(self) -> float:
|
||||
"""Get seconds until rate limit resets."""
|
||||
if not self.requests:
|
||||
return 0.0
|
||||
self.cleanup()
|
||||
if not self.requests:
|
||||
return 0.0
|
||||
oldest = min(self.requests)
|
||||
reset_time = oldest + timedelta(seconds=self.window_seconds)
|
||||
remaining = (reset_time - datetime.now(timezone.utc)).total_seconds()
|
||||
return max(0.0, remaining)
|
||||
|
||||
|
||||
@dataclass
|
||||
class RateLimitConfig:
|
||||
"""Configuration for a rate limit."""
|
||||
|
||||
max_requests: int
|
||||
window_seconds: float
|
||||
scope: RateLimitScope = RateLimitScope.MEMBER
|
||||
|
||||
def create_bucket(self) -> RateLimitBucket:
|
||||
return RateLimitBucket(
|
||||
max_requests=self.max_requests,
|
||||
window_seconds=self.window_seconds,
|
||||
)
|
||||
|
||||
|
||||
@dataclass
|
||||
class RateLimitResult:
|
||||
"""Result of a rate limit check."""
|
||||
|
||||
is_limited: bool
|
||||
remaining: int
|
||||
reset_after: float
|
||||
bucket_key: str
|
||||
|
||||
|
||||
class RateLimiter:
|
||||
"""General-purpose rate limiter."""
|
||||
|
||||
# Default rate limits for various actions
|
||||
DEFAULT_LIMITS = {
|
||||
"command": RateLimitConfig(5, 10, RateLimitScope.MEMBER), # 5 commands per 10s
|
||||
"moderation": RateLimitConfig(10, 60, RateLimitScope.MEMBER), # 10 mod actions per minute
|
||||
"verification": RateLimitConfig(3, 300, RateLimitScope.MEMBER), # 3 verifications per 5 min
|
||||
"message": RateLimitConfig(10, 10, RateLimitScope.MEMBER), # 10 messages per 10s
|
||||
"api_call": RateLimitConfig(
|
||||
30, 60, RateLimitScope.GUILD
|
||||
), # 30 API calls per minute per guild
|
||||
}
|
||||
|
||||
def __init__(self) -> None:
|
||||
# Buckets: {action: {bucket_key: RateLimitBucket}}
|
||||
self._buckets: dict[str, dict[str, RateLimitBucket]] = defaultdict(dict)
|
||||
self._configs: dict[str, RateLimitConfig] = dict(self.DEFAULT_LIMITS)
|
||||
|
||||
def configure(self, action: str, config: RateLimitConfig) -> None:
|
||||
"""Configure rate limit for an action."""
|
||||
self._configs[action] = config
|
||||
# Clear existing buckets for this action
|
||||
self._buckets[action].clear()
|
||||
|
||||
def _get_bucket_key(
|
||||
self,
|
||||
scope: RateLimitScope,
|
||||
user_id: int | None = None,
|
||||
guild_id: int | None = None,
|
||||
channel_id: int | None = None,
|
||||
) -> str:
|
||||
"""Generate a bucket key based on scope."""
|
||||
if scope == RateLimitScope.USER:
|
||||
return f"user:{user_id}"
|
||||
elif scope == RateLimitScope.MEMBER:
|
||||
return f"member:{guild_id}:{user_id}"
|
||||
elif scope == RateLimitScope.CHANNEL:
|
||||
return f"channel:{channel_id}"
|
||||
elif scope == RateLimitScope.GUILD:
|
||||
return f"guild:{guild_id}"
|
||||
return f"unknown:{user_id}:{guild_id}"
|
||||
|
||||
def check(
|
||||
self,
|
||||
action: str,
|
||||
user_id: int | None = None,
|
||||
guild_id: int | None = None,
|
||||
channel_id: int | None = None,
|
||||
) -> RateLimitResult:
|
||||
"""
|
||||
Check if an action is rate limited.
|
||||
|
||||
Does not record the request - use `acquire()` for that.
|
||||
"""
|
||||
config = self._configs.get(action)
|
||||
if not config:
|
||||
return RateLimitResult(
|
||||
is_limited=False,
|
||||
remaining=999,
|
||||
reset_after=0,
|
||||
bucket_key="",
|
||||
)
|
||||
|
||||
bucket_key = self._get_bucket_key(config.scope, user_id, guild_id, channel_id)
|
||||
bucket = self._buckets[action].get(bucket_key)
|
||||
|
||||
if not bucket:
|
||||
return RateLimitResult(
|
||||
is_limited=False,
|
||||
remaining=config.max_requests,
|
||||
reset_after=0,
|
||||
bucket_key=bucket_key,
|
||||
)
|
||||
|
||||
return RateLimitResult(
|
||||
is_limited=bucket.is_limited(),
|
||||
remaining=bucket.remaining(),
|
||||
reset_after=bucket.reset_after(),
|
||||
bucket_key=bucket_key,
|
||||
)
|
||||
|
||||
def acquire(
|
||||
self,
|
||||
action: str,
|
||||
user_id: int | None = None,
|
||||
guild_id: int | None = None,
|
||||
channel_id: int | None = None,
|
||||
) -> RateLimitResult:
|
||||
"""
|
||||
Attempt to acquire a rate limit slot.
|
||||
|
||||
Records the request if not limited.
|
||||
"""
|
||||
config = self._configs.get(action)
|
||||
if not config:
|
||||
return RateLimitResult(
|
||||
is_limited=False,
|
||||
remaining=999,
|
||||
reset_after=0,
|
||||
bucket_key="",
|
||||
)
|
||||
|
||||
bucket_key = self._get_bucket_key(config.scope, user_id, guild_id, channel_id)
|
||||
|
||||
if bucket_key not in self._buckets[action]:
|
||||
self._buckets[action][bucket_key] = config.create_bucket()
|
||||
|
||||
bucket = self._buckets[action][bucket_key]
|
||||
|
||||
if bucket.is_limited():
|
||||
return RateLimitResult(
|
||||
is_limited=True,
|
||||
remaining=0,
|
||||
reset_after=bucket.reset_after(),
|
||||
bucket_key=bucket_key,
|
||||
)
|
||||
|
||||
bucket.record()
|
||||
|
||||
return RateLimitResult(
|
||||
is_limited=False,
|
||||
remaining=bucket.remaining(),
|
||||
reset_after=bucket.reset_after(),
|
||||
bucket_key=bucket_key,
|
||||
)
|
||||
|
||||
def reset(
|
||||
self,
|
||||
action: str,
|
||||
user_id: int | None = None,
|
||||
guild_id: int | None = None,
|
||||
channel_id: int | None = None,
|
||||
) -> bool:
|
||||
"""Reset rate limit for a specific bucket."""
|
||||
config = self._configs.get(action)
|
||||
if not config:
|
||||
return False
|
||||
|
||||
bucket_key = self._get_bucket_key(config.scope, user_id, guild_id, channel_id)
|
||||
return self._buckets[action].pop(bucket_key, None) is not None
|
||||
|
||||
def cleanup(self) -> int:
|
||||
"""Clean up empty and expired buckets. Returns count removed."""
|
||||
removed = 0
|
||||
for action in list(self._buckets.keys()):
|
||||
for key in list(self._buckets[action].keys()):
|
||||
bucket = self._buckets[action][key]
|
||||
bucket.cleanup()
|
||||
if not bucket.requests:
|
||||
del self._buckets[action][key]
|
||||
removed += 1
|
||||
return removed
|
||||
|
||||
|
||||
# Global rate limiter instance
|
||||
_rate_limiter: RateLimiter | None = None
|
||||
|
||||
|
||||
def get_rate_limiter() -> RateLimiter:
|
||||
"""Get or create the global rate limiter instance."""
|
||||
global _rate_limiter
|
||||
if _rate_limiter is None:
|
||||
_rate_limiter = RateLimiter()
|
||||
return _rate_limiter
|
||||
|
||||
|
||||
def ratelimit(
|
||||
action: str = "command",
|
||||
max_requests: int | None = None,
|
||||
window_seconds: float | None = None,
|
||||
) -> Callable:
|
||||
"""
|
||||
Decorator for rate limiting commands.
|
||||
|
||||
Usage:
|
||||
@ratelimit("moderation", max_requests=5, window_seconds=60)
|
||||
async def my_command(self, ctx):
|
||||
...
|
||||
"""
|
||||
|
||||
def decorator(func: Callable) -> Callable:
|
||||
async def wrapper(self, ctx, *args, **kwargs):
|
||||
limiter = get_rate_limiter()
|
||||
|
||||
# Configure if custom limits provided
|
||||
if max_requests is not None and window_seconds is not None:
|
||||
limiter.configure(
|
||||
action,
|
||||
RateLimitConfig(max_requests, window_seconds, RateLimitScope.MEMBER),
|
||||
)
|
||||
|
||||
result = limiter.acquire(
|
||||
action,
|
||||
user_id=ctx.author.id,
|
||||
guild_id=ctx.guild.id if ctx.guild else None,
|
||||
channel_id=ctx.channel.id,
|
||||
)
|
||||
|
||||
if result.is_limited:
|
||||
await ctx.send(
|
||||
f"You're being rate limited. Try again in {result.reset_after:.1f} seconds.",
|
||||
delete_after=5,
|
||||
)
|
||||
return
|
||||
|
||||
return await func(self, ctx, *args, **kwargs)
|
||||
|
||||
# Preserve function metadata
|
||||
wrapper.__name__ = func.__name__
|
||||
wrapper.__doc__ = func.__doc__
|
||||
return wrapper
|
||||
|
||||
return decorator
|
||||
300
src/guardden/services/verification.py
Normal file
300
src/guardden/services/verification.py
Normal file
@@ -0,0 +1,300 @@
|
||||
"""Verification service for new member challenges."""
|
||||
|
||||
import asyncio
|
||||
import logging
|
||||
import random
|
||||
import string
|
||||
from abc import ABC, abstractmethod
|
||||
from dataclasses import dataclass, field
|
||||
from datetime import datetime, timedelta, timezone
|
||||
from enum import Enum
|
||||
from typing import Any
|
||||
|
||||
import discord
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class ChallengeType(str, Enum):
|
||||
"""Types of verification challenges."""
|
||||
|
||||
BUTTON = "button" # Simple button click
|
||||
CAPTCHA = "captcha" # Text-based captcha
|
||||
MATH = "math" # Simple math problem
|
||||
EMOJI = "emoji" # Select correct emoji
|
||||
QUESTIONS = "questions" # Custom questions
|
||||
|
||||
|
||||
@dataclass
|
||||
class Challenge:
|
||||
"""Represents a verification challenge."""
|
||||
|
||||
challenge_type: ChallengeType
|
||||
question: str
|
||||
answer: str
|
||||
options: list[str] = field(default_factory=list) # For multiple choice
|
||||
expires_at: datetime = field(
|
||||
default_factory=lambda: datetime.now(timezone.utc) + timedelta(minutes=10)
|
||||
)
|
||||
attempts: int = 0
|
||||
max_attempts: int = 3
|
||||
|
||||
@property
|
||||
def is_expired(self) -> bool:
|
||||
return datetime.now(timezone.utc) > self.expires_at
|
||||
|
||||
def check_answer(self, response: str) -> bool:
|
||||
"""Check if the response is correct."""
|
||||
self.attempts += 1
|
||||
return response.strip().lower() == self.answer.lower()
|
||||
|
||||
|
||||
@dataclass
|
||||
class PendingVerification:
|
||||
"""Tracks a pending verification for a user."""
|
||||
|
||||
user_id: int
|
||||
guild_id: int
|
||||
challenge: Challenge
|
||||
message_id: int | None = None
|
||||
channel_id: int | None = None
|
||||
created_at: datetime = field(default_factory=lambda: datetime.now(timezone.utc))
|
||||
|
||||
|
||||
class ChallengeGenerator(ABC):
|
||||
"""Abstract base class for challenge generators."""
|
||||
|
||||
@abstractmethod
|
||||
def generate(self) -> Challenge:
|
||||
"""Generate a new challenge."""
|
||||
pass
|
||||
|
||||
|
||||
class ButtonChallengeGenerator(ChallengeGenerator):
|
||||
"""Generates simple button click challenges."""
|
||||
|
||||
def generate(self) -> Challenge:
|
||||
return Challenge(
|
||||
challenge_type=ChallengeType.BUTTON,
|
||||
question="Click the button below to verify you're human.",
|
||||
answer="verified",
|
||||
)
|
||||
|
||||
|
||||
class CaptchaChallengeGenerator(ChallengeGenerator):
|
||||
"""Generates text-based captcha challenges."""
|
||||
|
||||
def __init__(self, length: int = 6) -> None:
|
||||
self.length = length
|
||||
|
||||
def generate(self) -> Challenge:
|
||||
# Generate random alphanumeric code (avoiding confusing chars)
|
||||
chars = "ABCDEFGHJKLMNPQRSTUVWXYZ23456789"
|
||||
code = "".join(random.choices(chars, k=self.length))
|
||||
|
||||
# Create visual representation with some obfuscation
|
||||
visual = self._create_visual(code)
|
||||
|
||||
return Challenge(
|
||||
challenge_type=ChallengeType.CAPTCHA,
|
||||
question=f"Enter the code shown below:\n```\n{visual}\n```",
|
||||
answer=code,
|
||||
)
|
||||
|
||||
def _create_visual(self, code: str) -> str:
|
||||
"""Create a simple text-based visual captcha."""
|
||||
lines = []
|
||||
# Add some noise characters
|
||||
noise_chars = ".-*~^"
|
||||
|
||||
for _ in range(2):
|
||||
lines.append("".join(random.choices(noise_chars, k=len(code) * 2)))
|
||||
|
||||
# Add the code with spacing
|
||||
spaced = " ".join(code)
|
||||
lines.append(spaced)
|
||||
|
||||
for _ in range(2):
|
||||
lines.append("".join(random.choices(noise_chars, k=len(code) * 2)))
|
||||
|
||||
return "\n".join(lines)
|
||||
|
||||
|
||||
class MathChallengeGenerator(ChallengeGenerator):
|
||||
"""Generates simple math problem challenges."""
|
||||
|
||||
def generate(self) -> Challenge:
|
||||
# Generate simple addition/subtraction/multiplication
|
||||
operation = random.choice(["+", "-", "*"])
|
||||
|
||||
if operation == "*":
|
||||
a = random.randint(2, 10)
|
||||
b = random.randint(2, 10)
|
||||
else:
|
||||
a = random.randint(10, 50)
|
||||
b = random.randint(1, 20)
|
||||
|
||||
if operation == "+":
|
||||
answer = a + b
|
||||
elif operation == "-":
|
||||
# Ensure positive result
|
||||
if b > a:
|
||||
a, b = b, a
|
||||
answer = a - b
|
||||
else:
|
||||
answer = a * b
|
||||
|
||||
return Challenge(
|
||||
challenge_type=ChallengeType.MATH,
|
||||
question=f"Solve this math problem: **{a} {operation} {b} = ?**",
|
||||
answer=str(answer),
|
||||
)
|
||||
|
||||
|
||||
class EmojiChallengeGenerator(ChallengeGenerator):
|
||||
"""Generates emoji selection challenges."""
|
||||
|
||||
EMOJI_SETS = [
|
||||
("animals", ["🐶", "🐱", "🐭", "🐹", "🐰", "🦊", "🐻", "🐼"]),
|
||||
("fruits", ["🍎", "🍐", "🍊", "🍋", "🍌", "🍉", "🍇", "🍓"]),
|
||||
("weather", ["☀️", "🌙", "⭐", "🌧️", "❄️", "🌈", "⚡", "🌪️"]),
|
||||
("sports", ["⚽", "🏀", "🏈", "⚾", "🎾", "🏐", "🏉", "🎱"]),
|
||||
]
|
||||
|
||||
def generate(self) -> Challenge:
|
||||
category, emojis = random.choice(self.EMOJI_SETS)
|
||||
target = random.choice(emojis)
|
||||
|
||||
# Create options with the target and some others
|
||||
options = [target]
|
||||
other_emojis = [e for e in emojis if e != target]
|
||||
options.extend(random.sample(other_emojis, min(3, len(other_emojis))))
|
||||
random.shuffle(options)
|
||||
|
||||
return Challenge(
|
||||
challenge_type=ChallengeType.EMOJI,
|
||||
question=f"Select the {self._emoji_name(target)} emoji:",
|
||||
answer=target,
|
||||
options=options,
|
||||
)
|
||||
|
||||
def _emoji_name(self, emoji: str) -> str:
|
||||
"""Get a description of the emoji."""
|
||||
names = {
|
||||
"🐶": "dog",
|
||||
"🐱": "cat",
|
||||
"🐭": "mouse",
|
||||
"🐹": "hamster",
|
||||
"🐰": "rabbit",
|
||||
"🦊": "fox",
|
||||
"🐻": "bear",
|
||||
"🐼": "panda",
|
||||
"🍎": "apple",
|
||||
"🍐": "pear",
|
||||
"🍊": "orange",
|
||||
"🍋": "lemon",
|
||||
"🍌": "banana",
|
||||
"🍉": "watermelon",
|
||||
"🍇": "grapes",
|
||||
"🍓": "strawberry",
|
||||
"☀️": "sun",
|
||||
"🌙": "moon",
|
||||
"⭐": "star",
|
||||
"🌧️": "rain",
|
||||
"❄️": "snowflake",
|
||||
"🌈": "rainbow",
|
||||
"⚡": "lightning",
|
||||
"🌪️": "tornado",
|
||||
"⚽": "soccer ball",
|
||||
"🏀": "basketball",
|
||||
"🏈": "football",
|
||||
"⚾": "baseball",
|
||||
"🎾": "tennis",
|
||||
"🏐": "volleyball",
|
||||
"🏉": "rugby",
|
||||
"🎱": "pool ball",
|
||||
}
|
||||
return names.get(emoji, "correct")
|
||||
|
||||
|
||||
class VerificationService:
|
||||
"""Service for managing member verification."""
|
||||
|
||||
def __init__(self) -> None:
|
||||
# Pending verifications: {(guild_id, user_id): PendingVerification}
|
||||
self._pending: dict[tuple[int, int], PendingVerification] = {}
|
||||
|
||||
# Challenge generators
|
||||
self._generators: dict[ChallengeType, ChallengeGenerator] = {
|
||||
ChallengeType.BUTTON: ButtonChallengeGenerator(),
|
||||
ChallengeType.CAPTCHA: CaptchaChallengeGenerator(),
|
||||
ChallengeType.MATH: MathChallengeGenerator(),
|
||||
ChallengeType.EMOJI: EmojiChallengeGenerator(),
|
||||
}
|
||||
|
||||
def create_challenge(
|
||||
self,
|
||||
user_id: int,
|
||||
guild_id: int,
|
||||
challenge_type: ChallengeType = ChallengeType.BUTTON,
|
||||
) -> PendingVerification:
|
||||
"""Create a new verification challenge for a user."""
|
||||
generator = self._generators.get(challenge_type)
|
||||
if not generator:
|
||||
generator = self._generators[ChallengeType.BUTTON]
|
||||
|
||||
challenge = generator.generate()
|
||||
pending = PendingVerification(
|
||||
user_id=user_id,
|
||||
guild_id=guild_id,
|
||||
challenge=challenge,
|
||||
)
|
||||
|
||||
self._pending[(guild_id, user_id)] = pending
|
||||
return pending
|
||||
|
||||
def get_pending(self, guild_id: int, user_id: int) -> PendingVerification | None:
|
||||
"""Get a pending verification for a user."""
|
||||
return self._pending.get((guild_id, user_id))
|
||||
|
||||
def verify(self, guild_id: int, user_id: int, response: str) -> tuple[bool, str]:
|
||||
"""
|
||||
Attempt to verify a user's response.
|
||||
|
||||
Returns:
|
||||
Tuple of (success, message)
|
||||
"""
|
||||
pending = self._pending.get((guild_id, user_id))
|
||||
|
||||
if not pending:
|
||||
return False, "No pending verification found."
|
||||
|
||||
if pending.challenge.is_expired:
|
||||
self._pending.pop((guild_id, user_id), None)
|
||||
return False, "Verification expired. Please request a new one."
|
||||
|
||||
if pending.challenge.attempts >= pending.challenge.max_attempts:
|
||||
self._pending.pop((guild_id, user_id), None)
|
||||
return False, "Too many failed attempts. Please request a new verification."
|
||||
|
||||
if pending.challenge.check_answer(response):
|
||||
self._pending.pop((guild_id, user_id), None)
|
||||
return True, "Verification successful!"
|
||||
|
||||
remaining = pending.challenge.max_attempts - pending.challenge.attempts
|
||||
return False, f"Incorrect. {remaining} attempt(s) remaining."
|
||||
|
||||
def cancel(self, guild_id: int, user_id: int) -> bool:
|
||||
"""Cancel a pending verification."""
|
||||
return self._pending.pop((guild_id, user_id), None) is not None
|
||||
|
||||
def cleanup_expired(self) -> int:
|
||||
"""Remove expired verifications. Returns count of removed."""
|
||||
expired = [key for key, pending in self._pending.items() if pending.challenge.is_expired]
|
||||
for key in expired:
|
||||
self._pending.pop(key, None)
|
||||
return len(expired)
|
||||
|
||||
def get_pending_count(self, guild_id: int) -> int:
|
||||
"""Get count of pending verifications for a guild."""
|
||||
return sum(1 for (gid, _) in self._pending if gid == guild_id)
|
||||
5
src/guardden/utils/__init__.py
Normal file
5
src/guardden/utils/__init__.py
Normal file
@@ -0,0 +1,5 @@
|
||||
"""Utility functions for GuardDen."""
|
||||
|
||||
from guardden.utils.logging import setup_logging
|
||||
|
||||
__all__ = ["setup_logging"]
|
||||
27
src/guardden/utils/logging.py
Normal file
27
src/guardden/utils/logging.py
Normal file
@@ -0,0 +1,27 @@
|
||||
"""Logging configuration for GuardDen."""
|
||||
|
||||
import logging
|
||||
import sys
|
||||
from typing import Literal
|
||||
|
||||
|
||||
def setup_logging(level: Literal["DEBUG", "INFO", "WARNING", "ERROR"] = "INFO") -> None:
|
||||
"""Configure logging for the application."""
|
||||
log_format = "%(asctime)s | %(levelname)-8s | %(name)s | %(message)s"
|
||||
date_format = "%Y-%m-%d %H:%M:%S"
|
||||
|
||||
# Configure root logger
|
||||
logging.basicConfig(
|
||||
level=getattr(logging, level),
|
||||
format=log_format,
|
||||
datefmt=date_format,
|
||||
handlers=[logging.StreamHandler(sys.stdout)],
|
||||
)
|
||||
|
||||
# Reduce noise from third-party libraries
|
||||
logging.getLogger("discord").setLevel(logging.WARNING)
|
||||
logging.getLogger("discord.http").setLevel(logging.WARNING)
|
||||
logging.getLogger("asyncio").setLevel(logging.WARNING)
|
||||
logging.getLogger("sqlalchemy.engine").setLevel(
|
||||
logging.DEBUG if level == "DEBUG" else logging.WARNING
|
||||
)
|
||||
1
tests/__init__.py
Normal file
1
tests/__init__.py
Normal file
@@ -0,0 +1 @@
|
||||
"""Tests for GuardDen."""
|
||||
15
tests/conftest.py
Normal file
15
tests/conftest.py
Normal file
@@ -0,0 +1,15 @@
|
||||
"""Pytest fixtures for GuardDen tests."""
|
||||
|
||||
import pytest
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def sample_guild_id() -> int:
|
||||
"""Return a sample Discord guild ID."""
|
||||
return 123456789012345678
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def sample_user_id() -> int:
|
||||
"""Return a sample Discord user ID."""
|
||||
return 987654321098765432
|
||||
119
tests/test_ai.py
Normal file
119
tests/test_ai.py
Normal file
@@ -0,0 +1,119 @@
|
||||
"""Tests for AI services."""
|
||||
|
||||
import pytest
|
||||
|
||||
from guardden.services.ai.base import ContentCategory, ModerationResult
|
||||
from guardden.services.ai.factory import NullProvider, create_ai_provider
|
||||
|
||||
|
||||
class TestModerationResult:
|
||||
"""Tests for ModerationResult dataclass."""
|
||||
|
||||
def test_severity_not_flagged(self) -> None:
|
||||
"""Test severity is 0 when not flagged."""
|
||||
result = ModerationResult(is_flagged=False, confidence=0.9)
|
||||
assert result.severity == 0
|
||||
|
||||
def test_severity_with_confidence(self) -> None:
|
||||
"""Test severity includes confidence."""
|
||||
result = ModerationResult(
|
||||
is_flagged=True,
|
||||
confidence=0.8,
|
||||
categories=[],
|
||||
)
|
||||
# 0.8 * 50 = 40
|
||||
assert result.severity == 40
|
||||
|
||||
def test_severity_high_category(self) -> None:
|
||||
"""Test severity with high-severity category."""
|
||||
result = ModerationResult(
|
||||
is_flagged=True,
|
||||
confidence=0.5,
|
||||
categories=[ContentCategory.HATE_SPEECH],
|
||||
)
|
||||
# 0.5 * 50 + 30 = 55
|
||||
assert result.severity == 55
|
||||
|
||||
def test_severity_medium_category(self) -> None:
|
||||
"""Test severity with medium-severity category."""
|
||||
result = ModerationResult(
|
||||
is_flagged=True,
|
||||
confidence=0.5,
|
||||
categories=[ContentCategory.HARASSMENT],
|
||||
)
|
||||
# 0.5 * 50 + 20 = 45
|
||||
assert result.severity == 45
|
||||
|
||||
def test_severity_multiple_categories(self) -> None:
|
||||
"""Test severity with multiple categories."""
|
||||
result = ModerationResult(
|
||||
is_flagged=True,
|
||||
confidence=0.5,
|
||||
categories=[ContentCategory.HATE_SPEECH, ContentCategory.VIOLENCE],
|
||||
)
|
||||
# 0.5 * 50 + 30 + 20 = 75
|
||||
assert result.severity == 75
|
||||
|
||||
def test_severity_capped_at_100(self) -> None:
|
||||
"""Test severity is capped at 100."""
|
||||
result = ModerationResult(
|
||||
is_flagged=True,
|
||||
confidence=1.0,
|
||||
categories=[
|
||||
ContentCategory.HATE_SPEECH,
|
||||
ContentCategory.SELF_HARM,
|
||||
ContentCategory.SCAM,
|
||||
],
|
||||
)
|
||||
# Would be 50 + 30 + 30 + 30 = 140, capped to 100
|
||||
assert result.severity == 100
|
||||
|
||||
|
||||
class TestNullProvider:
|
||||
"""Tests for NullProvider."""
|
||||
|
||||
@pytest.fixture
|
||||
def provider(self) -> NullProvider:
|
||||
return NullProvider()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_moderate_text_returns_empty(self, provider: NullProvider) -> None:
|
||||
"""Test moderate_text returns unflagged result."""
|
||||
result = await provider.moderate_text("test content")
|
||||
assert result.is_flagged is False
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_analyze_image_returns_empty(self, provider: NullProvider) -> None:
|
||||
"""Test analyze_image returns empty result."""
|
||||
result = await provider.analyze_image("http://example.com/image.jpg")
|
||||
assert result.is_nsfw is False
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_analyze_phishing_returns_empty(self, provider: NullProvider) -> None:
|
||||
"""Test analyze_phishing returns empty result."""
|
||||
result = await provider.analyze_phishing("http://example.com")
|
||||
assert result.is_phishing is False
|
||||
|
||||
|
||||
class TestFactory:
|
||||
"""Tests for AI provider factory."""
|
||||
|
||||
def test_create_null_provider(self) -> None:
|
||||
"""Test creating null provider."""
|
||||
provider = create_ai_provider("none")
|
||||
assert isinstance(provider, NullProvider)
|
||||
|
||||
def test_create_anthropic_without_key(self) -> None:
|
||||
"""Test creating anthropic provider without key raises error."""
|
||||
with pytest.raises(ValueError, match="API key required"):
|
||||
create_ai_provider("anthropic", None)
|
||||
|
||||
def test_create_openai_without_key(self) -> None:
|
||||
"""Test creating openai provider without key raises error."""
|
||||
with pytest.raises(ValueError, match="API key required"):
|
||||
create_ai_provider("openai", None)
|
||||
|
||||
def test_create_unknown_provider(self) -> None:
|
||||
"""Test creating unknown provider raises error."""
|
||||
with pytest.raises(ValueError, match="Unknown AI provider"):
|
||||
create_ai_provider("unknown", "key") # type: ignore
|
||||
153
tests/test_automod.py
Normal file
153
tests/test_automod.py
Normal file
@@ -0,0 +1,153 @@
|
||||
"""Tests for the automod service."""
|
||||
|
||||
import pytest
|
||||
|
||||
from guardden.models import BannedWord
|
||||
from guardden.services.automod import AutomodService
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def automod() -> AutomodService:
|
||||
"""Create an automod service instance."""
|
||||
return AutomodService()
|
||||
|
||||
|
||||
class TestBannedWords:
|
||||
"""Tests for banned word filtering."""
|
||||
|
||||
def test_simple_match(self, automod: AutomodService) -> None:
|
||||
"""Test simple text matching."""
|
||||
banned = [_make_banned_word("badword")]
|
||||
result = automod.check_banned_words("This contains badword in it", banned)
|
||||
assert result is not None
|
||||
assert result.should_delete
|
||||
|
||||
def test_case_insensitive(self, automod: AutomodService) -> None:
|
||||
"""Test case insensitive matching."""
|
||||
banned = [_make_banned_word("BadWord")]
|
||||
result = automod.check_banned_words("this contains BADWORD here", banned)
|
||||
assert result is not None
|
||||
|
||||
def test_no_match(self, automod: AutomodService) -> None:
|
||||
"""Test no match returns None."""
|
||||
banned = [_make_banned_word("badword")]
|
||||
result = automod.check_banned_words("This is a clean message", banned)
|
||||
assert result is None
|
||||
|
||||
def test_regex_pattern(self, automod: AutomodService) -> None:
|
||||
"""Test regex pattern matching."""
|
||||
banned = [_make_banned_word(r"bad\w+", is_regex=True)]
|
||||
result = automod.check_banned_words("This is badword and badstuff", banned)
|
||||
assert result is not None
|
||||
|
||||
def test_action_warn(self, automod: AutomodService) -> None:
|
||||
"""Test warn action is set."""
|
||||
banned = [_make_banned_word("badword", action="warn")]
|
||||
result = automod.check_banned_words("badword", banned)
|
||||
assert result is not None
|
||||
assert result.should_warn
|
||||
|
||||
def test_action_strike(self, automod: AutomodService) -> None:
|
||||
"""Test strike action is set."""
|
||||
banned = [_make_banned_word("badword", action="strike")]
|
||||
result = automod.check_banned_words("badword", banned)
|
||||
assert result is not None
|
||||
assert result.should_strike
|
||||
|
||||
|
||||
class TestScamDetection:
|
||||
"""Tests for scam/phishing detection."""
|
||||
|
||||
def test_discord_nitro_scam(self, automod: AutomodService) -> None:
|
||||
"""Test detection of fake Discord Nitro links."""
|
||||
result = automod.check_scam_links("Free nitro at discord-nitro.gift")
|
||||
assert result is not None
|
||||
assert result.should_delete
|
||||
|
||||
def test_steam_scam(self, automod: AutomodService) -> None:
|
||||
"""Test detection of Steam scam patterns."""
|
||||
result = automod.check_scam_links("Check out this steam-community-giveaway.xyz")
|
||||
assert result is not None
|
||||
|
||||
def test_legitimate_discord_link(self, automod: AutomodService) -> None:
|
||||
"""Test that legitimate Discord links pass."""
|
||||
result = automod.check_scam_links("Join us at discord.gg/example")
|
||||
assert result is None
|
||||
|
||||
def test_suspicious_tld_with_keyword(self, automod: AutomodService) -> None:
|
||||
"""Test suspicious TLD with impersonation keyword."""
|
||||
result = automod.check_scam_links("Visit discord-verify.xyz to claim")
|
||||
assert result is not None
|
||||
|
||||
def test_normal_url(self, automod: AutomodService) -> None:
|
||||
"""Test normal URLs pass."""
|
||||
result = automod.check_scam_links("Check out https://github.com/example")
|
||||
assert result is None
|
||||
|
||||
|
||||
class TestInviteLinks:
|
||||
"""Tests for Discord invite link detection."""
|
||||
|
||||
def test_discord_gg_invite(self, automod: AutomodService) -> None:
|
||||
"""Test discord.gg invite detection."""
|
||||
result = automod.check_invite_links("Join discord.gg/example", allow_invites=False)
|
||||
assert result is not None
|
||||
assert result.should_delete
|
||||
|
||||
def test_discordapp_invite(self, automod: AutomodService) -> None:
|
||||
"""Test discordapp.com invite detection."""
|
||||
result = automod.check_invite_links(
|
||||
"Join https://discordapp.com/invite/abc123", allow_invites=False
|
||||
)
|
||||
assert result is not None
|
||||
|
||||
def test_allowed_invites(self, automod: AutomodService) -> None:
|
||||
"""Test invites pass when allowed."""
|
||||
result = automod.check_invite_links("Join discord.gg/example", allow_invites=True)
|
||||
assert result is None
|
||||
|
||||
|
||||
class TestCapsDetection:
|
||||
"""Tests for excessive caps detection."""
|
||||
|
||||
def test_excessive_caps(self, automod: AutomodService) -> None:
|
||||
"""Test detection of all caps message."""
|
||||
result = automod.check_all_caps("THIS IS ALL CAPS MESSAGE HERE")
|
||||
assert result is not None
|
||||
|
||||
def test_normal_caps(self, automod: AutomodService) -> None:
|
||||
"""Test normal message passes."""
|
||||
result = automod.check_all_caps("This is a Normal Message with Some Caps")
|
||||
assert result is None
|
||||
|
||||
def test_short_message_ignored(self, automod: AutomodService) -> None:
|
||||
"""Test short messages are ignored."""
|
||||
result = automod.check_all_caps("HI THERE")
|
||||
assert result is None
|
||||
|
||||
|
||||
class MockBannedWord:
|
||||
"""Mock BannedWord for testing without database."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
pattern: str,
|
||||
is_regex: bool = False,
|
||||
action: str = "delete",
|
||||
) -> None:
|
||||
self.id = 1
|
||||
self.guild_id = 123
|
||||
self.pattern = pattern
|
||||
self.is_regex = is_regex
|
||||
self.action = action
|
||||
self.reason = None
|
||||
self.added_by = 456
|
||||
|
||||
|
||||
def _make_banned_word(
|
||||
pattern: str,
|
||||
is_regex: bool = False,
|
||||
action: str = "delete",
|
||||
) -> MockBannedWord:
|
||||
"""Create a mock BannedWord object for testing."""
|
||||
return MockBannedWord(pattern, is_regex, action)
|
||||
130
tests/test_ratelimit.py
Normal file
130
tests/test_ratelimit.py
Normal file
@@ -0,0 +1,130 @@
|
||||
"""Tests for rate limiting service."""
|
||||
|
||||
import pytest
|
||||
|
||||
from guardden.services.ratelimit import (
|
||||
RateLimitBucket,
|
||||
RateLimitConfig,
|
||||
RateLimiter,
|
||||
RateLimitScope,
|
||||
)
|
||||
|
||||
|
||||
class TestRateLimitBucket:
|
||||
"""Tests for RateLimitBucket."""
|
||||
|
||||
def test_not_limited_initially(self) -> None:
|
||||
"""Test bucket is not limited when empty."""
|
||||
bucket = RateLimitBucket(max_requests=3, window_seconds=60)
|
||||
assert bucket.is_limited() is False
|
||||
assert bucket.remaining() == 3
|
||||
|
||||
def test_limited_after_max_requests(self) -> None:
|
||||
"""Test bucket is limited after max requests."""
|
||||
bucket = RateLimitBucket(max_requests=3, window_seconds=60)
|
||||
|
||||
for _ in range(3):
|
||||
bucket.record()
|
||||
|
||||
assert bucket.is_limited() is True
|
||||
assert bucket.remaining() == 0
|
||||
|
||||
def test_remaining_decreases(self) -> None:
|
||||
"""Test remaining count decreases."""
|
||||
bucket = RateLimitBucket(max_requests=5, window_seconds=60)
|
||||
|
||||
bucket.record()
|
||||
assert bucket.remaining() == 4
|
||||
|
||||
bucket.record()
|
||||
assert bucket.remaining() == 3
|
||||
|
||||
|
||||
class TestRateLimiter:
|
||||
"""Tests for RateLimiter."""
|
||||
|
||||
@pytest.fixture
|
||||
def limiter(self) -> RateLimiter:
|
||||
return RateLimiter()
|
||||
|
||||
def test_check_not_limited(self, limiter: RateLimiter) -> None:
|
||||
"""Test check returns not limited for new bucket."""
|
||||
result = limiter.check("command", user_id=123, guild_id=456)
|
||||
assert result.is_limited is False
|
||||
|
||||
def test_acquire_records_request(self, limiter: RateLimiter) -> None:
|
||||
"""Test acquire records the request."""
|
||||
# Configure a simple limit
|
||||
limiter.configure(
|
||||
"test",
|
||||
RateLimitConfig(max_requests=2, window_seconds=60, scope=RateLimitScope.USER),
|
||||
)
|
||||
|
||||
result1 = limiter.acquire("test", user_id=123)
|
||||
assert result1.is_limited is False
|
||||
assert result1.remaining == 1
|
||||
|
||||
result2 = limiter.acquire("test", user_id=123)
|
||||
assert result2.is_limited is False
|
||||
assert result2.remaining == 0
|
||||
|
||||
result3 = limiter.acquire("test", user_id=123)
|
||||
assert result3.is_limited is True
|
||||
|
||||
def test_different_scopes(self, limiter: RateLimiter) -> None:
|
||||
"""Test different scopes create different buckets."""
|
||||
# User scope - shared across guilds
|
||||
limiter.configure(
|
||||
"user_action",
|
||||
RateLimitConfig(max_requests=1, window_seconds=60, scope=RateLimitScope.USER),
|
||||
)
|
||||
|
||||
limiter.acquire("user_action", user_id=123, guild_id=1)
|
||||
result = limiter.acquire("user_action", user_id=123, guild_id=2)
|
||||
assert result.is_limited is True # Same user, different guild
|
||||
|
||||
# Member scope - per guild
|
||||
limiter.configure(
|
||||
"member_action",
|
||||
RateLimitConfig(max_requests=1, window_seconds=60, scope=RateLimitScope.MEMBER),
|
||||
)
|
||||
|
||||
limiter.acquire("member_action", user_id=456, guild_id=1)
|
||||
result = limiter.acquire("member_action", user_id=456, guild_id=2)
|
||||
assert result.is_limited is False # Same user, different guild = different bucket
|
||||
|
||||
def test_reset(self, limiter: RateLimiter) -> None:
|
||||
"""Test resetting a bucket."""
|
||||
limiter.configure(
|
||||
"test",
|
||||
RateLimitConfig(max_requests=1, window_seconds=60, scope=RateLimitScope.USER),
|
||||
)
|
||||
|
||||
limiter.acquire("test", user_id=123)
|
||||
assert limiter.acquire("test", user_id=123).is_limited is True
|
||||
|
||||
limiter.reset("test", user_id=123)
|
||||
assert limiter.acquire("test", user_id=123).is_limited is False
|
||||
|
||||
def test_unknown_action(self, limiter: RateLimiter) -> None:
|
||||
"""Test unknown action returns not limited."""
|
||||
result = limiter.acquire("unknown_action", user_id=123)
|
||||
assert result.is_limited is False
|
||||
assert result.remaining == 999
|
||||
|
||||
def test_guild_scope(self, limiter: RateLimiter) -> None:
|
||||
"""Test guild-scoped rate limiting."""
|
||||
limiter.configure(
|
||||
"guild_action",
|
||||
RateLimitConfig(max_requests=2, window_seconds=60, scope=RateLimitScope.GUILD),
|
||||
)
|
||||
|
||||
# Different users in same guild share the limit
|
||||
limiter.acquire("guild_action", user_id=1, guild_id=100)
|
||||
limiter.acquire("guild_action", user_id=2, guild_id=100)
|
||||
result = limiter.acquire("guild_action", user_id=3, guild_id=100)
|
||||
assert result.is_limited is True
|
||||
|
||||
# Different guild is not limited
|
||||
result = limiter.acquire("guild_action", user_id=1, guild_id=200)
|
||||
assert result.is_limited is False
|
||||
48
tests/test_utils.py
Normal file
48
tests/test_utils.py
Normal file
@@ -0,0 +1,48 @@
|
||||
"""Tests for utility functions."""
|
||||
|
||||
from datetime import timedelta
|
||||
|
||||
import pytest
|
||||
|
||||
from guardden.cogs.moderation import parse_duration
|
||||
|
||||
|
||||
class TestParseDuration:
|
||||
"""Tests for the parse_duration function."""
|
||||
|
||||
def test_parse_seconds(self) -> None:
|
||||
"""Test parsing seconds."""
|
||||
assert parse_duration("30s") == timedelta(seconds=30)
|
||||
assert parse_duration("1s") == timedelta(seconds=1)
|
||||
|
||||
def test_parse_minutes(self) -> None:
|
||||
"""Test parsing minutes."""
|
||||
assert parse_duration("5m") == timedelta(minutes=5)
|
||||
assert parse_duration("30m") == timedelta(minutes=30)
|
||||
|
||||
def test_parse_hours(self) -> None:
|
||||
"""Test parsing hours."""
|
||||
assert parse_duration("1h") == timedelta(hours=1)
|
||||
assert parse_duration("24h") == timedelta(hours=24)
|
||||
|
||||
def test_parse_days(self) -> None:
|
||||
"""Test parsing days."""
|
||||
assert parse_duration("7d") == timedelta(days=7)
|
||||
assert parse_duration("1d") == timedelta(days=1)
|
||||
|
||||
def test_parse_weeks(self) -> None:
|
||||
"""Test parsing weeks."""
|
||||
assert parse_duration("2w") == timedelta(weeks=2)
|
||||
assert parse_duration("1w") == timedelta(weeks=1)
|
||||
|
||||
def test_invalid_format(self) -> None:
|
||||
"""Test invalid duration formats."""
|
||||
assert parse_duration("invalid") is None
|
||||
assert parse_duration("") is None
|
||||
assert parse_duration("10") is None
|
||||
assert parse_duration("abc") is None
|
||||
|
||||
def test_case_insensitive(self) -> None:
|
||||
"""Test that parsing is case insensitive."""
|
||||
assert parse_duration("1H") == timedelta(hours=1)
|
||||
assert parse_duration("30M") == timedelta(minutes=30)
|
||||
142
tests/test_verification.py
Normal file
142
tests/test_verification.py
Normal file
@@ -0,0 +1,142 @@
|
||||
"""Tests for verification service."""
|
||||
|
||||
import pytest
|
||||
|
||||
from guardden.services.verification import (
|
||||
ButtonChallengeGenerator,
|
||||
CaptchaChallengeGenerator,
|
||||
ChallengeType,
|
||||
EmojiChallengeGenerator,
|
||||
MathChallengeGenerator,
|
||||
VerificationService,
|
||||
)
|
||||
|
||||
|
||||
class TestChallengeGenerators:
|
||||
"""Tests for challenge generators."""
|
||||
|
||||
def test_button_challenge(self) -> None:
|
||||
"""Test button challenge generation."""
|
||||
gen = ButtonChallengeGenerator()
|
||||
challenge = gen.generate()
|
||||
|
||||
assert challenge.challenge_type == ChallengeType.BUTTON
|
||||
assert challenge.answer == "verified"
|
||||
|
||||
def test_captcha_challenge(self) -> None:
|
||||
"""Test captcha challenge generation."""
|
||||
gen = CaptchaChallengeGenerator(length=6)
|
||||
challenge = gen.generate()
|
||||
|
||||
assert challenge.challenge_type == ChallengeType.CAPTCHA
|
||||
assert len(challenge.answer) == 6
|
||||
assert challenge.answer.isalnum()
|
||||
|
||||
def test_math_challenge(self) -> None:
|
||||
"""Test math challenge generation."""
|
||||
gen = MathChallengeGenerator()
|
||||
challenge = gen.generate()
|
||||
|
||||
assert challenge.challenge_type == ChallengeType.MATH
|
||||
# Answer should be a number
|
||||
assert challenge.answer.lstrip("-").isdigit()
|
||||
|
||||
def test_emoji_challenge(self) -> None:
|
||||
"""Test emoji challenge generation."""
|
||||
gen = EmojiChallengeGenerator()
|
||||
challenge = gen.generate()
|
||||
|
||||
assert challenge.challenge_type == ChallengeType.EMOJI
|
||||
assert len(challenge.options) > 0
|
||||
assert challenge.answer in challenge.options
|
||||
|
||||
|
||||
class TestVerificationService:
|
||||
"""Tests for verification service."""
|
||||
|
||||
@pytest.fixture
|
||||
def service(self) -> VerificationService:
|
||||
return VerificationService()
|
||||
|
||||
def test_create_challenge(self, service: VerificationService) -> None:
|
||||
"""Test creating a challenge."""
|
||||
pending = service.create_challenge(
|
||||
user_id=123,
|
||||
guild_id=456,
|
||||
challenge_type=ChallengeType.BUTTON,
|
||||
)
|
||||
|
||||
assert pending.user_id == 123
|
||||
assert pending.guild_id == 456
|
||||
assert pending.challenge.challenge_type == ChallengeType.BUTTON
|
||||
|
||||
def test_get_pending(self, service: VerificationService) -> None:
|
||||
"""Test retrieving pending verification."""
|
||||
service.create_challenge(123, 456, ChallengeType.BUTTON)
|
||||
|
||||
pending = service.get_pending(456, 123)
|
||||
assert pending is not None
|
||||
assert pending.user_id == 123
|
||||
|
||||
# Non-existent should return None
|
||||
assert service.get_pending(456, 999) is None
|
||||
|
||||
def test_verify_correct(self, service: VerificationService) -> None:
|
||||
"""Test successful verification."""
|
||||
service.create_challenge(123, 456, ChallengeType.BUTTON)
|
||||
|
||||
success, message = service.verify(456, 123, "verified")
|
||||
assert success is True
|
||||
assert "successful" in message.lower()
|
||||
|
||||
# Should be removed after success
|
||||
assert service.get_pending(456, 123) is None
|
||||
|
||||
def test_verify_incorrect(self, service: VerificationService) -> None:
|
||||
"""Test failed verification."""
|
||||
service.create_challenge(123, 456, ChallengeType.BUTTON)
|
||||
|
||||
success, message = service.verify(456, 123, "wrong")
|
||||
assert success is False
|
||||
assert "incorrect" in message.lower()
|
||||
|
||||
# Should still exist
|
||||
assert service.get_pending(456, 123) is not None
|
||||
|
||||
def test_verify_max_attempts(self, service: VerificationService) -> None:
|
||||
"""Test max attempts exceeded."""
|
||||
service.create_challenge(123, 456, ChallengeType.BUTTON)
|
||||
|
||||
# Use up all attempts
|
||||
for _ in range(3):
|
||||
service.verify(456, 123, "wrong")
|
||||
|
||||
success, message = service.verify(456, 123, "verified")
|
||||
assert success is False
|
||||
assert "too many" in message.lower()
|
||||
|
||||
def test_verify_no_pending(self, service: VerificationService) -> None:
|
||||
"""Test verification with no pending challenge."""
|
||||
success, message = service.verify(456, 123, "verified")
|
||||
assert success is False
|
||||
assert "no pending" in message.lower()
|
||||
|
||||
def test_cancel(self, service: VerificationService) -> None:
|
||||
"""Test canceling verification."""
|
||||
service.create_challenge(123, 456, ChallengeType.BUTTON)
|
||||
|
||||
assert service.cancel(456, 123) is True
|
||||
assert service.get_pending(456, 123) is None
|
||||
|
||||
# Cancel non-existent returns False
|
||||
assert service.cancel(456, 123) is False
|
||||
|
||||
def test_pending_count(self, service: VerificationService) -> None:
|
||||
"""Test pending count per guild."""
|
||||
service.create_challenge(1, 456, ChallengeType.BUTTON)
|
||||
service.create_challenge(2, 456, ChallengeType.BUTTON)
|
||||
service.create_challenge(3, 789, ChallengeType.BUTTON)
|
||||
|
||||
assert service.get_pending_count(456) == 2
|
||||
assert service.get_pending_count(789) == 1
|
||||
assert service.get_pending_count(999) == 0
|
||||
Reference in New Issue
Block a user