From 4e16777f25434d5f62ebfa65002187cf3879e469 Mon Sep 17 00:00:00 2001 From: latte Date: Fri, 16 Jan 2026 19:27:48 +0100 Subject: [PATCH] Implement GuardDen Discord moderation bot Features: - Core moderation: warn, kick, ban, timeout, strike system - Automod: banned words filter, scam detection, anti-spam, link filtering - AI moderation: Claude/OpenAI integration, NSFW detection, phishing analysis - Verification system: button, captcha, math, emoji challenges - Rate limiting system with configurable scopes - Event logging: joins, leaves, message edits/deletes, voice activity - Per-guild configuration with caching - Docker deployment support Bug fixes applied: - Fixed await on session.delete() in guild_config.py - Fixed memory leak in AI moderation message tracking (use deque) - Added error handling to bot shutdown - Added error handling to timeout command - Removed unused Literal import - Added prefix validation - Added image analysis limit (3 per message) - Fixed test mock for SQLAlchemy model --- .env.example | 21 + .gitignore | 56 +++ CLAUDE.md | 103 ++++ Dockerfile | 27 + README.md | 319 +++++++++++- alembic.ini | 43 ++ docker-compose.yml | 44 ++ migrations/env.py | 65 +++ migrations/script.py.mako | 26 + pyproject.toml | 98 ++++ src/guardden/__init__.py | 3 + src/guardden/__main__.py | 40 ++ src/guardden/bot.py | 131 +++++ src/guardden/cogs/__init__.py | 1 + src/guardden/cogs/admin.py | 255 ++++++++++ src/guardden/cogs/ai_moderation.py | 366 ++++++++++++++ src/guardden/cogs/automod.py | 267 ++++++++++ src/guardden/cogs/events.py | 237 +++++++++ src/guardden/cogs/moderation.py | 466 ++++++++++++++++++ src/guardden/cogs/verification.py | 423 ++++++++++++++++ src/guardden/config.py | 50 ++ src/guardden/models/__init__.py | 15 + src/guardden/models/base.py | 32 ++ src/guardden/models/guild.py | 117 +++++ src/guardden/models/moderation.py | 101 ++++ src/guardden/services/__init__.py | 16 + src/guardden/services/ai/__init__.py | 6 + .../services/ai/anthropic_provider.py | 261 ++++++++++ src/guardden/services/ai/base.py | 149 ++++++ src/guardden/services/ai/factory.py | 67 +++ src/guardden/services/ai/openai_provider.py | 213 ++++++++ src/guardden/services/automod.py | 301 +++++++++++ src/guardden/services/database.py | 99 ++++ src/guardden/services/guild_config.py | 145 ++++++ src/guardden/services/ratelimit.py | 300 +++++++++++ src/guardden/services/verification.py | 300 +++++++++++ src/guardden/utils/__init__.py | 5 + src/guardden/utils/logging.py | 27 + tests/__init__.py | 1 + tests/conftest.py | 15 + tests/test_ai.py | 119 +++++ tests/test_automod.py | 153 ++++++ tests/test_ratelimit.py | 130 +++++ tests/test_utils.py | 48 ++ tests/test_verification.py | 142 ++++++ 45 files changed, 5802 insertions(+), 1 deletion(-) create mode 100644 .env.example create mode 100644 .gitignore create mode 100644 CLAUDE.md create mode 100644 Dockerfile create mode 100644 alembic.ini create mode 100644 docker-compose.yml create mode 100644 migrations/env.py create mode 100644 migrations/script.py.mako create mode 100644 pyproject.toml create mode 100644 src/guardden/__init__.py create mode 100644 src/guardden/__main__.py create mode 100644 src/guardden/bot.py create mode 100644 src/guardden/cogs/__init__.py create mode 100644 src/guardden/cogs/admin.py create mode 100644 src/guardden/cogs/ai_moderation.py create mode 100644 src/guardden/cogs/automod.py create mode 100644 src/guardden/cogs/events.py create mode 100644 src/guardden/cogs/moderation.py create mode 100644 src/guardden/cogs/verification.py create mode 100644 src/guardden/config.py create mode 100644 src/guardden/models/__init__.py create mode 100644 src/guardden/models/base.py create mode 100644 src/guardden/models/guild.py create mode 100644 src/guardden/models/moderation.py create mode 100644 src/guardden/services/__init__.py create mode 100644 src/guardden/services/ai/__init__.py create mode 100644 src/guardden/services/ai/anthropic_provider.py create mode 100644 src/guardden/services/ai/base.py create mode 100644 src/guardden/services/ai/factory.py create mode 100644 src/guardden/services/ai/openai_provider.py create mode 100644 src/guardden/services/automod.py create mode 100644 src/guardden/services/database.py create mode 100644 src/guardden/services/guild_config.py create mode 100644 src/guardden/services/ratelimit.py create mode 100644 src/guardden/services/verification.py create mode 100644 src/guardden/utils/__init__.py create mode 100644 src/guardden/utils/logging.py create mode 100644 tests/__init__.py create mode 100644 tests/conftest.py create mode 100644 tests/test_ai.py create mode 100644 tests/test_automod.py create mode 100644 tests/test_ratelimit.py create mode 100644 tests/test_utils.py create mode 100644 tests/test_verification.py diff --git a/.env.example b/.env.example new file mode 100644 index 0000000..ad17a03 --- /dev/null +++ b/.env.example @@ -0,0 +1,21 @@ +# Discord Bot Configuration +GUARDDEN_DISCORD_TOKEN=your_discord_bot_token_here +GUARDDEN_DISCORD_PREFIX=! + +# Database Configuration (for local development without Docker) +GUARDDEN_DATABASE_URL=postgresql://guardden:guardden@localhost:5432/guardden + +# Logging +GUARDDEN_LOG_LEVEL=INFO + +# AI Configuration (optional) +# Options: none, anthropic, openai +GUARDDEN_AI_PROVIDER=none + +# Anthropic API key (required if AI_PROVIDER=anthropic) +# Get your key at: https://console.anthropic.com/ +GUARDDEN_ANTHROPIC_API_KEY= + +# OpenAI API key (required if AI_PROVIDER=openai) +# Get your key at: https://platform.openai.com/api-keys +GUARDDEN_OPENAI_API_KEY= diff --git a/.gitignore b/.gitignore new file mode 100644 index 0000000..31f9ad3 --- /dev/null +++ b/.gitignore @@ -0,0 +1,56 @@ +# Python +__pycache__/ +*.py[cod] +*$py.class +*.so +.Python +build/ +develop-eggs/ +dist/ +downloads/ +eggs/ +.eggs/ +lib/ +lib64/ +parts/ +sdist/ +var/ +wheels/ +*.egg-info/ +.installed.cfg +*.egg + +# Virtual environments +venv/ +ENV/ +env/ +.venv/ + +# IDE +.idea/ +.vscode/ +*.swp +*.swo +*~ + +# Environment +.env +.env.local + +# Data +data/ +*.db +*.sqlite3 + +# Logs +*.log +logs/ + +# Testing +.coverage +htmlcov/ +.pytest_cache/ +.mypy_cache/ + +# Docker +docker-compose.override.yml diff --git a/CLAUDE.md b/CLAUDE.md new file mode 100644 index 0000000..285a090 --- /dev/null +++ b/CLAUDE.md @@ -0,0 +1,103 @@ +# CLAUDE.md + +This file provides guidance to Claude Code (claude.ai/code) when working with code in this repository. + +## Project Overview + +GuardDen is a Discord moderation bot built with discord.py, PostgreSQL, and optional AI integration (Claude/OpenAI). Self-hosted with Docker support. + +## Commands + +```bash +# Install dependencies +pip install -e ".[dev,ai]" + +# Run the bot +python -m guardden + +# Run tests +pytest + +# Run single test +pytest tests/test_verification.py::TestVerificationService::test_verify_correct + +# Lint and format +ruff check src tests +ruff format src tests + +# Type checking +mypy src + +# Docker deployment +docker compose up -d +``` + +## Architecture + +- `src/guardden/bot.py` - Main bot class (`GuardDen`) extending `commands.Bot`, manages lifecycle and services +- `src/guardden/config.py` - Pydantic settings loaded from environment variables (prefix: `GUARDDEN_`) +- `src/guardden/models/` - SQLAlchemy 2.0 async models for PostgreSQL +- `src/guardden/services/` - Business logic (database, guild config, automod, AI, verification, rate limiting) +- `src/guardden/cogs/` - Discord command groups (events, moderation, admin, automod, ai_moderation, verification) + +## Key Patterns + +- All database operations use async SQLAlchemy with `asyncpg` +- Guild configurations are cached in `GuildConfigService._cache` +- Discord snowflake IDs stored as `BigInteger` in PostgreSQL +- Moderation actions logged to `ModerationLog` table with automatic strike escalation +- Environment variables: `GUARDDEN_DISCORD_TOKEN`, `GUARDDEN_DATABASE_URL` + +## Automod System + +- `AutomodService` in `services/automod.py` handles rule-based content filtering +- Checks run in order: banned words → scam links → spam → invite links +- Spam tracking uses per-guild, per-user trackers with automatic cleanup +- Scam detection uses compiled regex patterns in `SCAM_PATTERNS` list +- Results return `AutomodResult` dataclass with actions to take + +## AI Moderation System + +- `services/ai/` contains provider abstraction and implementations +- `AIProvider` base class defines interface: `moderate_text()`, `analyze_image()`, `analyze_phishing()` +- `AnthropicProvider` and `OpenAIProvider` implement the interface +- `NullProvider` used when AI is disabled (returns empty results) +- Factory pattern via `create_ai_provider(provider, api_key)` +- `ModerationResult` includes severity scoring based on confidence + category weights +- Sensitivity setting (0-100) adjusts thresholds per guild + +## Verification System + +- `VerificationService` in `services/verification.py` manages challenges +- Challenge types: button, captcha, math, emoji (via `ChallengeGenerator` classes) +- `PendingVerification` tracks user challenges with expiry and attempt limits +- Discord UI components in `cogs/verification.py`: `VerifyButton`, `EmojiButton`, `CaptchaModal` +- Background task cleans up expired verifications every 5 minutes + +## Rate Limiting System + +- `RateLimiter` in `services/ratelimit.py` provides general-purpose rate limiting +- Scopes: USER (global), MEMBER (per-guild), CHANNEL, GUILD +- `@ratelimit()` decorator for easy command rate limiting +- `get_rate_limiter()` returns singleton instance +- Default limits configured for commands, moderation, verification, messages + +## Adding New Cogs + +1. Create file in `src/guardden/cogs/` +2. Implement `setup(bot)` async function +3. Add cog path to `_load_cogs()` in `bot.py` + +## Adding New AI Provider + +1. Create `src/guardden/services/ai/newprovider.py` +2. Implement `AIProvider` abstract class +3. Add to factory in `services/ai/factory.py` +4. Add config option in `config.py` + +## Adding New Challenge Type + +1. Create new `ChallengeGenerator` subclass in `services/verification.py` +2. Add to `ChallengeType` enum +3. Register in `VerificationService._generators` +4. Create corresponding UI components in `cogs/verification.py` if needed diff --git a/Dockerfile b/Dockerfile new file mode 100644 index 0000000..33b7172 --- /dev/null +++ b/Dockerfile @@ -0,0 +1,27 @@ +FROM python:3.11-slim + +WORKDIR /app + +# Install system dependencies +RUN apt-get update && apt-get install -y --no-install-recommends \ + gcc \ + libpq-dev \ + && rm -rf /var/lib/apt/lists/* + +# Copy dependency files +COPY pyproject.toml ./ + +# Install Python dependencies +RUN pip install --no-cache-dir -e . + +# Copy application code +COPY src/ ./src/ +COPY migrations/ ./migrations/ +COPY alembic.ini ./ + +# Create non-root user +RUN useradd -m -u 1000 guardden && chown -R guardden:guardden /app +USER guardden + +# Run the bot +CMD ["python", "-m", "guardden"] diff --git a/README.md b/README.md index 7edbbe4..ac880cd 100644 --- a/README.md +++ b/README.md @@ -1,3 +1,320 @@ # GuardDen -GuardDen is a comprehensive Discord moderation bot designed to protect your community while maintaining a warm, welcoming environment. Built with privacy and self-hosting in mind, GuardDen combines AI-powered content filtering with traditional moderation tools to create a safe space for your members. \ No newline at end of file +GuardDen is a comprehensive Discord moderation bot designed to protect your community while maintaining a warm, welcoming environment. Built with privacy and self-hosting in mind, GuardDen combines AI-powered content filtering with traditional moderation tools to create a safe space for your members. + +## Features + +### Core Moderation +- **Warn, Kick, Ban, Timeout** - Standard moderation commands with logging +- **Strike System** - Configurable point-based system with automatic escalation +- **Moderation History** - Track all actions taken against users +- **Bulk Message Deletion** - Purge up to 100 messages at once + +### Automod +- **Banned Words Filter** - Block words/phrases with regex support +- **Scam Detection** - Automatic detection of phishing/scam links +- **Anti-Spam** - Rate limiting, duplicate detection, mass mention protection +- **Link Filtering** - Block Discord invites and suspicious URLs + +### AI Moderation +- **Text Analysis** - AI-powered content moderation using Claude or GPT +- **NSFW Image Detection** - Automatic flagging of inappropriate images +- **Phishing Analysis** - AI-enhanced detection of scam URLs +- **Configurable Sensitivity** - Adjust strictness per server (0-100) + +### Verification System +- **Multiple Challenge Types** - Button, captcha, math problems, emoji selection +- **Automatic New Member Verification** - Challenge users on join +- **Configurable Verified Role** - Auto-assign role on successful verification +- **Rate Limited** - Prevents verification spam + +### Logging +- Member joins/leaves +- Message edits and deletions +- Voice channel activity +- Ban/unban events +- All moderation actions + +## Quick Start + +### Prerequisites +- Python 3.11+ +- PostgreSQL 15+ +- Discord Bot Token ([Discord Developer Portal](https://discord.com/developers/applications)) +- (Optional) Anthropic or OpenAI API key for AI features + +### Docker Deployment (Recommended) + +1. Clone the repository: + ```bash + git clone https://github.com/yourusername/guardden.git + cd guardden + ``` + +2. Create your environment file: + ```bash + cp .env.example .env + # Edit .env and add your Discord token + ``` + +3. Start with Docker Compose: + ```bash + docker compose up -d + ``` + +### Local Development + +1. Create a virtual environment: + ```bash + python -m venv venv + source venv/bin/activate # On Windows: venv\Scripts\activate + ``` + +2. Install dependencies: + ```bash + pip install -e ".[dev,ai]" + ``` + +3. Set up environment variables: + ```bash + cp .env.example .env + # Edit .env with your configuration + ``` + +4. Start PostgreSQL (or use Docker): + ```bash + docker compose up db -d + ``` + +5. Run the bot: + ```bash + python -m guardden + ``` + +## Configuration + +### Environment Variables + +| Variable | Description | Default | +|----------|-------------|---------| +| `GUARDDEN_DISCORD_TOKEN` | Your Discord bot token | Required | +| `GUARDDEN_DISCORD_PREFIX` | Default command prefix | `!` | +| `GUARDDEN_DATABASE_URL` | PostgreSQL connection URL | `postgresql://guardden:guardden@localhost:5432/guardden` | +| `GUARDDEN_LOG_LEVEL` | Logging level | `INFO` | +| `GUARDDEN_AI_PROVIDER` | AI provider (anthropic/openai/none) | `none` | +| `GUARDDEN_ANTHROPIC_API_KEY` | Anthropic API key (if using Claude) | - | +| `GUARDDEN_OPENAI_API_KEY` | OpenAI API key (if using GPT) | - | + +### Per-Guild Settings + +Each server can configure: +- Command prefix +- Log channels (general and moderation) +- Welcome channel +- Mute role and verified role +- Automod toggles (spam, links, banned words) +- Strike action thresholds +- AI moderation settings (enabled, sensitivity, NSFW detection) +- Verification settings (type, enabled) + +## Commands + +### Moderation + +| Command | Permission | Description | +|---------|------------|-------------| +| `!warn [reason]` | Kick Members | Warn a user | +| `!strike [points] [reason]` | Kick Members | Add strikes to a user | +| `!strikes ` | Kick Members | View user's strikes | +| `!timeout [reason]` | Moderate Members | Timeout a user (e.g., 1h, 30m, 7d) | +| `!untimeout ` | Moderate Members | Remove timeout | +| `!kick [reason]` | Kick Members | Kick a user | +| `!ban [reason]` | Ban Members | Ban a user | +| `!unban [reason]` | Ban Members | Unban a user by ID | +| `!purge ` | Manage Messages | Delete multiple messages (max 100) | +| `!modlogs ` | Kick Members | View moderation history | + +### Configuration (Admin only) + +| Command | Description | +|---------|-------------| +| `!config` | View current configuration | +| `!config prefix ` | Set command prefix | +| `!config logchannel [#channel]` | Set general log channel | +| `!config modlogchannel [#channel]` | Set moderation log channel | +| `!config welcomechannel [#channel]` | Set welcome channel | +| `!config muterole [@role]` | Set mute role | +| `!config automod ` | Toggle automod | +| `!config antispam ` | Toggle anti-spam | +| `!config linkfilter ` | Toggle link filtering | + +### Banned Words + +| Command | Description | +|---------|-------------| +| `!bannedwords` | List all banned words | +| `!bannedwords add [action] [is_regex]` | Add a banned word | +| `!bannedwords remove ` | Remove a banned word by ID | + +### Automod + +| Command | Description | +|---------|-------------| +| `!automod` | View automod status | +| `!automod test ` | Test text against filters | + +### AI Moderation (Admin only) + +| Command | Description | +|---------|-------------| +| `!ai` | View AI moderation settings | +| `!ai enable` | Enable AI moderation | +| `!ai disable` | Disable AI moderation | +| `!ai sensitivity <0-100>` | Set AI sensitivity level | +| `!ai nsfw ` | Toggle NSFW image detection | +| `!ai analyze ` | Test AI analysis on text | + +### Verification (Admin only) + +| Command | Description | +|---------|-------------| +| `!verify` | Request verification (for users) | +| `!verify setup` | View verification setup status | +| `!verify enable` | Enable verification for new members | +| `!verify disable` | Disable verification | +| `!verify role @role` | Set the verified role | +| `!verify type ` | Set verification type (button/captcha/math/emoji) | +| `!verify test [type]` | Test a verification challenge | +| `!verify reset @user` | Reset verification for a user | + +## Project Structure + +``` +guardden/ +├── src/guardden/ +│ ├── bot.py # Main bot class +│ ├── config.py # Settings management +│ ├── cogs/ # Discord command groups +│ │ ├── admin.py # Configuration commands +│ │ ├── ai_moderation.py # AI-powered moderation +│ │ ├── automod.py # Automatic moderation +│ │ ├── events.py # Event logging +│ │ ├── moderation.py # Moderation commands +│ │ └── verification.py # Member verification +│ ├── models/ # Database models +│ │ ├── guild.py # Guild settings, banned words +│ │ └── moderation.py # Logs, strikes, notes +│ └── services/ # Business logic +│ ├── ai/ # AI provider implementations +│ ├── automod.py # Content filtering +│ ├── database.py # DB connections +│ ├── guild_config.py # Config caching +│ ├── ratelimit.py # Rate limiting +│ └── verification.py # Verification challenges +├── tests/ # Test suite +├── migrations/ # Database migrations +├── docker-compose.yml # Docker deployment +└── pyproject.toml # Dependencies +``` + +## Verification System + +GuardDen includes a verification system to protect your server from bots and raids. + +### Challenge Types + +| Type | Description | +|------|-------------| +| `button` | Simple button click (default, easiest) | +| `captcha` | Text-based captcha code entry | +| `math` | Solve a simple math problem | +| `emoji` | Select the correct emoji from options | + +### Setup + +1. Create a verified role in your server +2. Configure the role permissions (verified members get full access) +3. Set up verification: + ``` + !verify role @Verified + !verify type captcha + !verify enable + ``` + +### How It Works + +1. New member joins the server +2. Bot sends verification challenge via DM (or channel if DMs disabled) +3. Member completes the challenge +4. Bot assigns the verified role +5. Member gains access to the server + +## AI Moderation + +GuardDen supports AI-powered content moderation using either Anthropic's Claude or OpenAI's GPT models. + +### Setup + +1. Set the AI provider in your environment: + ```bash + GUARDDEN_AI_PROVIDER=anthropic # or "openai" + GUARDDEN_ANTHROPIC_API_KEY=sk-ant-... # if using Claude + GUARDDEN_OPENAI_API_KEY=sk-... # if using OpenAI + ``` + +2. Enable AI moderation per server: + ``` + !ai enable + !ai sensitivity 50 # 0=lenient, 100=strict + !ai nsfw true # Enable NSFW image detection + ``` + +### Content Categories + +The AI analyzes content for: +- **Harassment** - Personal attacks, bullying +- **Hate Speech** - Discrimination, slurs +- **Sexual Content** - Explicit material +- **Violence** - Threats, graphic content +- **Self-Harm** - Suicide/self-injury content +- **Scams** - Phishing, fraud attempts +- **Spam** - Promotional, low-quality content + +### How It Works + +1. Messages are analyzed by the AI provider +2. Results include confidence scores and severity ratings +3. Actions are taken based on guild sensitivity settings +4. All AI actions are logged to the mod log channel + +## Development + +### Running Tests + +```bash +pytest +pytest -v # Verbose output +pytest tests/test_automod.py # Specific file +pytest -k "test_scam" # Filter by name +``` + +### Code Quality + +```bash +ruff check src tests # Linting +ruff format src tests # Formatting +mypy src # Type checking +``` + +## License + +MIT License - see LICENSE file for details. + +## Roadmap + +- [x] AI-powered content moderation (Claude/OpenAI integration) +- [x] NSFW image detection +- [x] Verification/captcha system +- [x] Rate limiting +- [ ] Voice channel moderation +- [ ] Web dashboard diff --git a/alembic.ini b/alembic.ini new file mode 100644 index 0000000..9d3bae4 --- /dev/null +++ b/alembic.ini @@ -0,0 +1,43 @@ +[alembic] +script_location = migrations +prepend_sys_path = . +version_path_separator = os + +# Use async driver +sqlalchemy.url = postgresql+asyncpg://guardden:guardden@localhost:5432/guardden + +[post_write_hooks] + +[loggers] +keys = root,sqlalchemy,alembic + +[handlers] +keys = console + +[formatters] +keys = generic + +[logger_root] +level = WARN +handlers = console +qualname = + +[logger_sqlalchemy] +level = WARN +handlers = +qualname = sqlalchemy.engine + +[logger_alembic] +level = INFO +handlers = +qualname = alembic + +[handler_console] +class = StreamHandler +args = (sys.stderr,) +level = NOTSET +formatter = generic + +[formatter_generic] +format = %(levelname)-5.5s [%(name)s] %(message)s +datefmt = %H:%M:%S diff --git a/docker-compose.yml b/docker-compose.yml new file mode 100644 index 0000000..90a96a1 --- /dev/null +++ b/docker-compose.yml @@ -0,0 +1,44 @@ +services: + bot: + build: . + container_name: guardden-bot + restart: unless-stopped + depends_on: + db: + condition: service_healthy + environment: + - GUARDDEN_DISCORD_TOKEN=${GUARDDEN_DISCORD_TOKEN} + - GUARDDEN_DATABASE_URL=postgresql://guardden:guardden@db:5432/guardden + - GUARDDEN_LOG_LEVEL=${GUARDDEN_LOG_LEVEL:-INFO} + - GUARDDEN_AI_PROVIDER=${GUARDDEN_AI_PROVIDER:-none} + - GUARDDEN_ANTHROPIC_API_KEY=${GUARDDEN_ANTHROPIC_API_KEY:-} + - GUARDDEN_OPENAI_API_KEY=${GUARDDEN_OPENAI_API_KEY:-} + volumes: + - ./data:/app/data + networks: + - guardden + + db: + image: postgres:15-alpine + container_name: guardden-db + restart: unless-stopped + environment: + - POSTGRES_USER=guardden + - POSTGRES_PASSWORD=guardden + - POSTGRES_DB=guardden + volumes: + - postgres_data:/var/lib/postgresql/data + healthcheck: + test: ["CMD-SHELL", "pg_isready -U guardden -d guardden"] + interval: 5s + timeout: 5s + retries: 5 + networks: + - guardden + +networks: + guardden: + driver: bridge + +volumes: + postgres_data: diff --git a/migrations/env.py b/migrations/env.py new file mode 100644 index 0000000..7e6f93b --- /dev/null +++ b/migrations/env.py @@ -0,0 +1,65 @@ +"""Alembic environment configuration.""" + +import asyncio +from logging.config import fileConfig + +from alembic import context +from sqlalchemy import pool +from sqlalchemy.engine import Connection +from sqlalchemy.ext.asyncio import async_engine_from_config + +from guardden.models import Base + +config = context.config + +if config.config_file_name is not None: + fileConfig(config.config_file_name) + +target_metadata = Base.metadata + + +def run_migrations_offline() -> None: + """Run migrations in 'offline' mode.""" + url = config.get_main_option("sqlalchemy.url") + context.configure( + url=url, + target_metadata=target_metadata, + literal_binds=True, + dialect_opts={"paramstyle": "named"}, + ) + + with context.begin_transaction(): + context.run_migrations() + + +def do_run_migrations(connection: Connection) -> None: + """Run migrations with the given connection.""" + context.configure(connection=connection, target_metadata=target_metadata) + + with context.begin_transaction(): + context.run_migrations() + + +async def run_async_migrations() -> None: + """Run migrations in 'online' mode with async engine.""" + connectable = async_engine_from_config( + config.get_section(config.config_ini_section, {}), + prefix="sqlalchemy.", + poolclass=pool.NullPool, + ) + + async with connectable.connect() as connection: + await connection.run_sync(do_run_migrations) + + await connectable.dispose() + + +def run_migrations_online() -> None: + """Run migrations in 'online' mode.""" + asyncio.run(run_async_migrations()) + + +if context.is_offline_mode(): + run_migrations_offline() +else: + run_migrations_online() diff --git a/migrations/script.py.mako b/migrations/script.py.mako new file mode 100644 index 0000000..fbc4b07 --- /dev/null +++ b/migrations/script.py.mako @@ -0,0 +1,26 @@ +"""${message} + +Revision ID: ${up_revision} +Revises: ${down_revision | comma,n} +Create Date: ${create_date} + +""" +from typing import Sequence, Union + +from alembic import op +import sqlalchemy as sa +${imports if imports else ""} + +# revision identifiers, used by Alembic. +revision: str = ${repr(up_revision)} +down_revision: Union[str, None] = ${repr(down_revision)} +branch_labels: Union[str, Sequence[str], None] = ${repr(branch_labels)} +depends_on: Union[str, Sequence[str], None] = ${repr(depends_on)} + + +def upgrade() -> None: + ${upgrades if upgrades else "pass"} + + +def downgrade() -> None: + ${downgrades if downgrades else "pass"} diff --git a/pyproject.toml b/pyproject.toml new file mode 100644 index 0000000..57b370e --- /dev/null +++ b/pyproject.toml @@ -0,0 +1,98 @@ +[build-system] +requires = ["setuptools>=61.0", "wheel"] +build-backend = "setuptools.build_meta" + +[project] +name = "guardden" +version = "0.1.0" +description = "A comprehensive Discord moderation bot with AI-powered content filtering" +readme = "README.md" +license = {text = "MIT"} +requires-python = ">=3.11" +authors = [ + {name = "GuardDen Team"} +] +classifiers = [ + "Development Status :: 3 - Alpha", + "Intended Audience :: Developers", + "License :: OSI Approved :: MIT License", + "Programming Language :: Python :: 3.11", + "Programming Language :: Python :: 3.12", +] + +dependencies = [ + "discord.py>=2.3.0", + "asyncpg>=0.29.0", + "pydantic>=2.5.0", + "pydantic-settings>=2.1.0", + "aiohttp>=3.9.0", + "python-dotenv>=1.0.0", + "alembic>=1.13.0", + "sqlalchemy>=2.0.0", +] + +[project.optional-dependencies] +dev = [ + "pytest>=7.4.0", + "pytest-asyncio>=0.23.0", + "pytest-cov>=4.1.0", + "ruff>=0.1.0", + "mypy>=1.7.0", + "pre-commit>=3.6.0", +] +ai = [ + "anthropic>=0.18.0", + "openai>=1.10.0", + "pillow>=10.2.0", +] +voice = [ + "speechrecognition>=3.10.0", + "pydub>=0.25.0", +] + +[project.scripts] +guardden = "guardden.__main__:main" + +[tool.setuptools.packages.find] +where = ["src"] + +[tool.ruff] +target-version = "py311" +line-length = 100 +src = ["src", "tests"] + +[tool.ruff.lint] +select = [ + "E", # pycodestyle errors + "W", # pycodestyle warnings + "F", # Pyflakes + "I", # isort + "B", # flake8-bugbear + "C4", # flake8-comprehensions + "UP", # pyupgrade + "ARG", # flake8-unused-arguments + "SIM", # flake8-simplify +] +ignore = [ + "E501", # line too long (handled by formatter) + "B008", # do not perform function calls in argument defaults +] + +[tool.ruff.lint.isort] +known-first-party = ["guardden"] + +[tool.pytest.ini_options] +asyncio_mode = "auto" +testpaths = ["tests"] +addopts = "-v --tb=short" + +[tool.mypy] +python_version = "3.11" +strict = true +warn_return_any = true +warn_unused_ignores = true +plugins = ["pydantic.mypy"] + +[[tool.mypy.overrides]] +module = ["discord.*", "asyncpg.*"] +ignore_missing_imports = true diff --git a/src/guardden/__init__.py b/src/guardden/__init__.py new file mode 100644 index 0000000..610ca5c --- /dev/null +++ b/src/guardden/__init__.py @@ -0,0 +1,3 @@ +"""GuardDen - A comprehensive Discord moderation bot.""" + +__version__ = "0.1.0" diff --git a/src/guardden/__main__.py b/src/guardden/__main__.py new file mode 100644 index 0000000..9accb63 --- /dev/null +++ b/src/guardden/__main__.py @@ -0,0 +1,40 @@ +"""Entry point for GuardDen bot.""" + +import asyncio +import logging +import sys + +from guardden.bot import GuardDen +from guardden.config import get_settings +from guardden.utils import setup_logging + + +def main() -> None: + """Run the GuardDen bot.""" + try: + settings = get_settings() + except Exception as e: + print(f"Failed to load configuration: {e}", file=sys.stderr) + print("Make sure GUARDDEN_DISCORD_TOKEN is set.", file=sys.stderr) + sys.exit(1) + + setup_logging(settings.log_level) + logger = logging.getLogger(__name__) + + bot = GuardDen(settings) + + async def runner() -> None: + async with bot: + await bot.start(settings.discord_token.get_secret_value()) + + try: + asyncio.run(runner()) + except KeyboardInterrupt: + logger.info("Received keyboard interrupt, shutting down...") + except Exception as e: + logger.exception(f"Fatal error: {e}") + sys.exit(1) + + +if __name__ == "__main__": + main() diff --git a/src/guardden/bot.py b/src/guardden/bot.py new file mode 100644 index 0000000..e249ba2 --- /dev/null +++ b/src/guardden/bot.py @@ -0,0 +1,131 @@ +"""Main bot class for GuardDen.""" + +import logging +from typing import TYPE_CHECKING + +import discord +from discord.ext import commands + +from guardden.config import Settings +from guardden.services.ai import AIProvider, create_ai_provider +from guardden.services.database import Database + +if TYPE_CHECKING: + from guardden.services.guild_config import GuildConfigService + +logger = logging.getLogger(__name__) + + +class GuardDen(commands.Bot): + """The main GuardDen Discord bot.""" + + def __init__(self, settings: Settings) -> None: + self.settings = settings + + intents = discord.Intents.default() + intents.message_content = True + intents.members = True + intents.voice_states = True + + super().__init__( + command_prefix=self._get_prefix, + intents=intents, + help_command=commands.DefaultHelpCommand(), + ) + + # Services + self.database = Database(settings) + self.guild_config: "GuildConfigService | None" = None + self.ai_provider: AIProvider | None = None + + async def _get_prefix(self, bot: "GuardDen", message: discord.Message) -> list[str]: + """Get the command prefix for a guild.""" + if not message.guild: + return [self.settings.discord_prefix] + + if self.guild_config: + config = await self.guild_config.get_config(message.guild.id) + if config: + return [config.prefix] + + return [self.settings.discord_prefix] + + async def setup_hook(self) -> None: + """Called when the bot is starting up.""" + logger.info("Starting GuardDen setup...") + + # Connect to database + await self.database.connect() + await self.database.create_tables() + + # Initialize services + from guardden.services.guild_config import GuildConfigService + + self.guild_config = GuildConfigService(self.database) + + # Initialize AI provider + api_key = None + if self.settings.ai_provider == "anthropic" and self.settings.anthropic_api_key: + api_key = self.settings.anthropic_api_key.get_secret_value() + elif self.settings.ai_provider == "openai" and self.settings.openai_api_key: + api_key = self.settings.openai_api_key.get_secret_value() + + self.ai_provider = create_ai_provider(self.settings.ai_provider, api_key) + + # Load cogs + await self._load_cogs() + + logger.info("GuardDen setup complete") + + async def _load_cogs(self) -> None: + """Load all cog extensions.""" + cogs = [ + "guardden.cogs.events", + "guardden.cogs.moderation", + "guardden.cogs.admin", + "guardden.cogs.automod", + "guardden.cogs.ai_moderation", + "guardden.cogs.verification", + ] + + for cog in cogs: + try: + await self.load_extension(cog) + logger.info(f"Loaded cog: {cog}") + except Exception as e: + logger.error(f"Failed to load cog {cog}: {e}") + + async def on_ready(self) -> None: + """Called when the bot is fully connected and ready.""" + if self.user: + logger.info(f"Logged in as {self.user} (ID: {self.user.id})") + logger.info(f"Connected to {len(self.guilds)} guild(s)") + + # Set presence + activity = discord.Activity( + type=discord.ActivityType.watching, + name="over your community", + ) + await self.change_presence(activity=activity) + + async def close(self) -> None: + """Clean up when shutting down.""" + logger.info("Shutting down GuardDen...") + if self.ai_provider: + try: + await self.ai_provider.close() + except Exception as e: + logger.error(f"Error closing AI provider: {e}") + await self.database.disconnect() + await super().close() + + async def on_guild_join(self, guild: discord.Guild) -> None: + """Called when the bot joins a new guild.""" + logger.info(f"Joined guild: {guild.name} (ID: {guild.id})") + + if self.guild_config: + await self.guild_config.create_guild(guild) + + async def on_guild_remove(self, guild: discord.Guild) -> None: + """Called when the bot is removed from a guild.""" + logger.info(f"Removed from guild: {guild.name} (ID: {guild.id})") diff --git a/src/guardden/cogs/__init__.py b/src/guardden/cogs/__init__.py new file mode 100644 index 0000000..53dae3e --- /dev/null +++ b/src/guardden/cogs/__init__.py @@ -0,0 +1 @@ +"""Discord cogs for GuardDen.""" diff --git a/src/guardden/cogs/admin.py b/src/guardden/cogs/admin.py new file mode 100644 index 0000000..743f626 --- /dev/null +++ b/src/guardden/cogs/admin.py @@ -0,0 +1,255 @@ +"""Admin commands for bot configuration.""" + +import logging +from typing import Literal + +import discord +from discord.ext import commands + +from guardden.bot import GuardDen + +logger = logging.getLogger(__name__) + + +class Admin(commands.Cog): + """Administrative commands for bot configuration.""" + + def __init__(self, bot: GuardDen) -> None: + self.bot = bot + + async def cog_check(self, ctx: commands.Context) -> bool: + """Ensure only administrators can use these commands.""" + if not ctx.guild: + return False + return ctx.author.guild_permissions.administrator + + @commands.group(name="config", invoke_without_command=True) + @commands.guild_only() + async def config(self, ctx: commands.Context) -> None: + """View or modify bot configuration.""" + config = await self.bot.guild_config.get_config(ctx.guild.id) + + if not config: + await ctx.send("No configuration found. Run a config command to initialize.") + return + + embed = discord.Embed( + title=f"Configuration for {ctx.guild.name}", + color=discord.Color.blue(), + ) + + # General settings + embed.add_field(name="Prefix", value=f"`{config.prefix}`", inline=True) + embed.add_field(name="Locale", value=config.locale, inline=True) + embed.add_field(name="\u200b", value="\u200b", inline=True) + + # Channels + log_ch = ctx.guild.get_channel(config.log_channel_id) if config.log_channel_id else None + mod_log_ch = ( + ctx.guild.get_channel(config.mod_log_channel_id) if config.mod_log_channel_id else None + ) + welcome_ch = ( + ctx.guild.get_channel(config.welcome_channel_id) if config.welcome_channel_id else None + ) + + embed.add_field( + name="Log Channel", value=log_ch.mention if log_ch else "Not set", inline=True + ) + embed.add_field( + name="Mod Log Channel", + value=mod_log_ch.mention if mod_log_ch else "Not set", + inline=True, + ) + embed.add_field( + name="Welcome Channel", + value=welcome_ch.mention if welcome_ch else "Not set", + inline=True, + ) + + # Features + features = [] + if config.automod_enabled: + features.append("AutoMod") + if config.anti_spam_enabled: + features.append("Anti-Spam") + if config.link_filter_enabled: + features.append("Link Filter") + if config.ai_moderation_enabled: + features.append("AI Moderation") + if config.verification_enabled: + features.append("Verification") + + embed.add_field( + name="Enabled Features", + value=", ".join(features) if features else "None", + inline=False, + ) + + await ctx.send(embed=embed) + + @config.command(name="prefix") + @commands.guild_only() + async def config_prefix(self, ctx: commands.Context, prefix: str) -> None: + """Set the command prefix for this server.""" + if not prefix or not prefix.strip(): + await ctx.send("Prefix cannot be empty or whitespace only.") + return + + if len(prefix) > 10: + await ctx.send("Prefix must be 10 characters or less.") + return + + await self.bot.guild_config.update_settings(ctx.guild.id, prefix=prefix) + await ctx.send(f"Command prefix set to `{prefix}`") + + @config.command(name="logchannel") + @commands.guild_only() + async def config_log_channel( + self, ctx: commands.Context, channel: discord.TextChannel | None = None + ) -> None: + """Set the channel for general event logs.""" + channel_id = channel.id if channel else None + await self.bot.guild_config.update_settings(ctx.guild.id, log_channel_id=channel_id) + + if channel: + await ctx.send(f"Log channel set to {channel.mention}") + else: + await ctx.send("Log channel has been disabled.") + + @config.command(name="modlogchannel") + @commands.guild_only() + async def config_mod_log_channel( + self, ctx: commands.Context, channel: discord.TextChannel | None = None + ) -> None: + """Set the channel for moderation action logs.""" + channel_id = channel.id if channel else None + await self.bot.guild_config.update_settings(ctx.guild.id, mod_log_channel_id=channel_id) + + if channel: + await ctx.send(f"Moderation log channel set to {channel.mention}") + else: + await ctx.send("Moderation log channel has been disabled.") + + @config.command(name="welcomechannel") + @commands.guild_only() + async def config_welcome_channel( + self, ctx: commands.Context, channel: discord.TextChannel | None = None + ) -> None: + """Set the welcome channel for new members.""" + channel_id = channel.id if channel else None + await self.bot.guild_config.update_settings(ctx.guild.id, welcome_channel_id=channel_id) + + if channel: + await ctx.send(f"Welcome channel set to {channel.mention}") + else: + await ctx.send("Welcome channel has been disabled.") + + @config.command(name="muterole") + @commands.guild_only() + async def config_mute_role( + self, ctx: commands.Context, role: discord.Role | None = None + ) -> None: + """Set the role to assign when muting members.""" + role_id = role.id if role else None + await self.bot.guild_config.update_settings(ctx.guild.id, mute_role_id=role_id) + + if role: + await ctx.send(f"Mute role set to {role.mention}") + else: + await ctx.send("Mute role has been cleared.") + + @config.command(name="automod") + @commands.guild_only() + async def config_automod(self, ctx: commands.Context, enabled: bool) -> None: + """Enable or disable automod features.""" + await self.bot.guild_config.update_settings(ctx.guild.id, automod_enabled=enabled) + status = "enabled" if enabled else "disabled" + await ctx.send(f"AutoMod has been {status}.") + + @config.command(name="antispam") + @commands.guild_only() + async def config_antispam(self, ctx: commands.Context, enabled: bool) -> None: + """Enable or disable anti-spam protection.""" + await self.bot.guild_config.update_settings(ctx.guild.id, anti_spam_enabled=enabled) + status = "enabled" if enabled else "disabled" + await ctx.send(f"Anti-spam has been {status}.") + + @config.command(name="linkfilter") + @commands.guild_only() + async def config_linkfilter(self, ctx: commands.Context, enabled: bool) -> None: + """Enable or disable link filtering.""" + await self.bot.guild_config.update_settings(ctx.guild.id, link_filter_enabled=enabled) + status = "enabled" if enabled else "disabled" + await ctx.send(f"Link filter has been {status}.") + + @commands.group(name="bannedwords", aliases=["bw"], invoke_without_command=True) + @commands.guild_only() + async def banned_words(self, ctx: commands.Context) -> None: + """Manage banned words list.""" + words = await self.bot.guild_config.get_banned_words(ctx.guild.id) + + if not words: + await ctx.send("No banned words configured.") + return + + embed = discord.Embed( + title="Banned Words", + color=discord.Color.red(), + ) + + for word in words[:25]: # Discord embed limit + word_type = "Regex" if word.is_regex else "Text" + embed.add_field( + name=f"#{word.id}: {word.pattern[:30]}", + value=f"Type: {word_type} | Action: {word.action}", + inline=True, + ) + + if len(words) > 25: + embed.set_footer(text=f"Showing 25 of {len(words)} banned words") + + await ctx.send(embed=embed) + + @banned_words.command(name="add") + @commands.guild_only() + async def banned_words_add( + self, + ctx: commands.Context, + pattern: str, + action: Literal["delete", "warn", "strike"] = "delete", + is_regex: bool = False, + ) -> None: + """Add a banned word or pattern.""" + word = await self.bot.guild_config.add_banned_word( + guild_id=ctx.guild.id, + pattern=pattern, + added_by=ctx.author.id, + is_regex=is_regex, + action=action, + ) + + word_type = "regex pattern" if is_regex else "word" + await ctx.send(f"Added banned {word_type}: `{pattern}` (ID: {word.id}, Action: {action})") + + @banned_words.command(name="remove", aliases=["delete"]) + @commands.guild_only() + async def banned_words_remove(self, ctx: commands.Context, word_id: int) -> None: + """Remove a banned word by ID.""" + success = await self.bot.guild_config.remove_banned_word(ctx.guild.id, word_id) + + if success: + await ctx.send(f"Removed banned word #{word_id}") + else: + await ctx.send(f"Banned word #{word_id} not found.") + + @commands.command(name="sync") + @commands.is_owner() + async def sync_commands(self, ctx: commands.Context) -> None: + """Sync slash commands (bot owner only).""" + await self.bot.tree.sync() + await ctx.send("Slash commands synced.") + + +async def setup(bot: GuardDen) -> None: + """Load the Admin cog.""" + await bot.add_cog(Admin(bot)) diff --git a/src/guardden/cogs/ai_moderation.py b/src/guardden/cogs/ai_moderation.py new file mode 100644 index 0000000..54a5fe5 --- /dev/null +++ b/src/guardden/cogs/ai_moderation.py @@ -0,0 +1,366 @@ +"""AI-powered moderation cog.""" + +import logging +import re +from collections import deque +from datetime import datetime, timedelta, timezone + +import discord +from discord.ext import commands + +from guardden.bot import GuardDen +from guardden.services.ai.base import ContentCategory, ModerationResult + +logger = logging.getLogger(__name__) + +# URL pattern for extraction +URL_PATTERN = re.compile( + r"https?://(?:[-\w.]|(?:%[\da-fA-F]{2}))+[^\s]*", + re.IGNORECASE, +) + + +class AIModeration(commands.Cog): + """AI-powered content moderation.""" + + def __init__(self, bot: GuardDen) -> None: + self.bot = bot + # Track recently analyzed messages to avoid duplicates (deque auto-removes oldest) + self._analyzed_messages: deque[int] = deque(maxlen=1000) + + def _should_analyze(self, message: discord.Message) -> bool: + """Determine if a message should be analyzed by AI.""" + # Skip if already analyzed + if message.id in self._analyzed_messages: + return False + + # Skip short messages + if len(message.content) < 20 and not message.attachments: + return False + + # Skip messages from bots + if message.author.bot: + return False + + return True + + def _track_message(self, message_id: int) -> None: + """Track that a message has been analyzed.""" + self._analyzed_messages.append(message_id) + + async def _handle_ai_result( + self, + message: discord.Message, + result: ModerationResult, + analysis_type: str, + ) -> None: + """Handle the result of AI analysis.""" + if not result.is_flagged: + return + + config = await self.bot.guild_config.get_config(message.guild.id) + if not config: + return + + # Check if severity meets threshold based on sensitivity + # Higher sensitivity = lower threshold needed to trigger + threshold = 100 - config.ai_sensitivity # e.g., sensitivity 70 = threshold 30 + if result.severity < threshold: + logger.debug( + f"AI flagged content but below threshold: " + f"severity={result.severity}, threshold={threshold}" + ) + return + + # Determine action based on suggested action and severity + should_delete = result.suggested_action in ("delete", "timeout", "ban") + should_timeout = result.suggested_action in ("timeout", "ban") and result.severity > 70 + + # Delete message if needed + if should_delete: + try: + await message.delete() + except discord.Forbidden: + logger.warning(f"Cannot delete message: missing permissions") + except discord.NotFound: + pass + + # Timeout user for severe violations + if should_timeout and isinstance(message.author, discord.Member): + timeout_duration = 300 if result.severity < 90 else 3600 # 5 min or 1 hour + try: + await message.author.timeout( + timedelta(seconds=timeout_duration), + reason=f"AI Moderation: {result.explanation[:100]}", + ) + except discord.Forbidden: + pass + + # Log to mod channel + await self._log_ai_action(message, result, analysis_type) + + # Notify user + try: + embed = discord.Embed( + title=f"Message Flagged in {message.guild.name}", + description=result.explanation, + color=discord.Color.red(), + timestamp=datetime.now(timezone.utc), + ) + embed.add_field( + name="Categories", + value=", ".join(cat.value for cat in result.categories) or "Unknown", + ) + if should_timeout: + embed.add_field(name="Action", value="You have been timed out") + await message.author.send(embed=embed) + except discord.Forbidden: + pass + + async def _log_ai_action( + self, + message: discord.Message, + result: ModerationResult, + analysis_type: str, + ) -> None: + """Log an AI moderation action.""" + config = await self.bot.guild_config.get_config(message.guild.id) + if not config or not config.mod_log_channel_id: + return + + channel = message.guild.get_channel(config.mod_log_channel_id) + if not channel or not isinstance(channel, discord.TextChannel): + return + + embed = discord.Embed( + title=f"AI Moderation - {analysis_type}", + color=discord.Color.red(), + timestamp=datetime.now(timezone.utc), + ) + embed.set_author( + name=str(message.author), + icon_url=message.author.display_avatar.url, + ) + + embed.add_field(name="Confidence", value=f"{result.confidence:.0%}", inline=True) + embed.add_field(name="Severity", value=f"{result.severity}/100", inline=True) + embed.add_field(name="Action", value=result.suggested_action, inline=True) + + categories = ", ".join(cat.value for cat in result.categories) + embed.add_field(name="Categories", value=categories or "None", inline=False) + embed.add_field(name="Explanation", value=result.explanation[:500], inline=False) + + if message.content: + content = ( + message.content[:500] + "..." if len(message.content) > 500 else message.content + ) + embed.add_field(name="Content", value=f"```{content}```", inline=False) + + embed.set_footer(text=f"User ID: {message.author.id} | Channel: #{message.channel.name}") + + await channel.send(embed=embed) + + @commands.Cog.listener() + async def on_message(self, message: discord.Message) -> None: + """Analyze messages with AI moderation.""" + if not message.guild: + return + + # Check if AI moderation is enabled for this guild + config = await self.bot.guild_config.get_config(message.guild.id) + if not config or not config.ai_moderation_enabled: + return + + # Skip users with manage_messages permission + if isinstance(message.author, discord.Member): + if message.author.guild_permissions.manage_messages: + return + + if not self._should_analyze(message): + return + + self._track_message(message.id) + + # Analyze text content + if message.content and len(message.content) >= 20: + result = await self.bot.ai_provider.moderate_text( + content=message.content, + context=f"Discord server: {message.guild.name}, channel: {message.channel.name}", + sensitivity=config.ai_sensitivity, + ) + + if result.is_flagged: + await self._handle_ai_result(message, result, "Text Analysis") + return # Don't continue if already flagged + + # Analyze images if NSFW detection is enabled (limit to 3 per message) + if config.nsfw_detection_enabled and message.attachments: + images_analyzed = 0 + for attachment in message.attachments: + if images_analyzed >= 3: + break + if attachment.content_type and attachment.content_type.startswith("image/"): + images_analyzed += 1 + image_result = await self.bot.ai_provider.analyze_image( + image_url=attachment.url, + sensitivity=config.ai_sensitivity, + ) + + if ( + image_result.is_nsfw + or image_result.is_violent + or image_result.is_disturbing + ): + # Convert to ModerationResult format + categories = [] + if image_result.is_nsfw: + categories.append(ContentCategory.SEXUAL) + if image_result.is_violent: + categories.append(ContentCategory.VIOLENCE) + + result = ModerationResult( + is_flagged=True, + confidence=image_result.confidence, + categories=categories, + explanation=image_result.description, + suggested_action="delete", + ) + await self._handle_ai_result(message, result, "Image Analysis") + return + + # Analyze URLs for phishing + urls = URL_PATTERN.findall(message.content) + for url in urls[:3]: # Limit to first 3 URLs + phishing_result = await self.bot.ai_provider.analyze_phishing( + url=url, + message_content=message.content, + ) + + if phishing_result.is_phishing and phishing_result.confidence > 0.7: + result = ModerationResult( + is_flagged=True, + confidence=phishing_result.confidence, + categories=[ContentCategory.SCAM], + explanation=phishing_result.explanation, + suggested_action="delete", + ) + await self._handle_ai_result(message, result, "Phishing Detection") + return + + @commands.group(name="ai", invoke_without_command=True) + @commands.has_permissions(administrator=True) + @commands.guild_only() + async def ai_cmd(self, ctx: commands.Context) -> None: + """View AI moderation settings.""" + config = await self.bot.guild_config.get_config(ctx.guild.id) + + embed = discord.Embed( + title="AI Moderation Settings", + color=discord.Color.blue(), + ) + + embed.add_field( + name="AI Moderation", + value="✅ Enabled" if config and config.ai_moderation_enabled else "❌ Disabled", + inline=True, + ) + embed.add_field( + name="NSFW Detection", + value="✅ Enabled" if config and config.nsfw_detection_enabled else "❌ Disabled", + inline=True, + ) + embed.add_field( + name="Sensitivity", + value=f"{config.ai_sensitivity}/100" if config else "50/100", + inline=True, + ) + embed.add_field( + name="AI Provider", + value=self.bot.settings.ai_provider.capitalize(), + inline=True, + ) + + await ctx.send(embed=embed) + + @ai_cmd.command(name="enable") + @commands.has_permissions(administrator=True) + @commands.guild_only() + async def ai_enable(self, ctx: commands.Context) -> None: + """Enable AI moderation.""" + if self.bot.settings.ai_provider == "none": + await ctx.send( + "AI moderation is not configured. Set `GUARDDEN_AI_PROVIDER` and API key." + ) + return + + await self.bot.guild_config.update_settings(ctx.guild.id, ai_moderation_enabled=True) + await ctx.send("✅ AI moderation enabled.") + + @ai_cmd.command(name="disable") + @commands.has_permissions(administrator=True) + @commands.guild_only() + async def ai_disable(self, ctx: commands.Context) -> None: + """Disable AI moderation.""" + await self.bot.guild_config.update_settings(ctx.guild.id, ai_moderation_enabled=False) + await ctx.send("❌ AI moderation disabled.") + + @ai_cmd.command(name="sensitivity") + @commands.has_permissions(administrator=True) + @commands.guild_only() + async def ai_sensitivity(self, ctx: commands.Context, level: int) -> None: + """Set AI sensitivity level (0-100). Higher = more strict.""" + if not 0 <= level <= 100: + await ctx.send("Sensitivity must be between 0 and 100.") + return + + await self.bot.guild_config.update_settings(ctx.guild.id, ai_sensitivity=level) + await ctx.send(f"AI sensitivity set to {level}/100.") + + @ai_cmd.command(name="nsfw") + @commands.has_permissions(administrator=True) + @commands.guild_only() + async def ai_nsfw(self, ctx: commands.Context, enabled: bool) -> None: + """Enable or disable NSFW image detection.""" + await self.bot.guild_config.update_settings(ctx.guild.id, nsfw_detection_enabled=enabled) + status = "enabled" if enabled else "disabled" + await ctx.send(f"NSFW detection {status}.") + + @ai_cmd.command(name="analyze") + @commands.has_permissions(administrator=True) + @commands.guild_only() + async def ai_analyze(self, ctx: commands.Context, *, text: str) -> None: + """Test AI analysis on text (does not take action).""" + if self.bot.settings.ai_provider == "none": + await ctx.send("AI moderation is not configured.") + return + + async with ctx.typing(): + result = await self.bot.ai_provider.moderate_text( + content=text, + context=f"Test analysis in {ctx.guild.name}", + sensitivity=50, + ) + + embed = discord.Embed( + title="AI Analysis Result", + color=discord.Color.red() if result.is_flagged else discord.Color.green(), + ) + + embed.add_field(name="Flagged", value="Yes" if result.is_flagged else "No", inline=True) + embed.add_field(name="Confidence", value=f"{result.confidence:.0%}", inline=True) + embed.add_field(name="Severity", value=f"{result.severity}/100", inline=True) + embed.add_field(name="Suggested Action", value=result.suggested_action, inline=True) + + if result.categories: + categories = ", ".join(cat.value for cat in result.categories) + embed.add_field(name="Categories", value=categories, inline=False) + + if result.explanation: + embed.add_field(name="Explanation", value=result.explanation[:1000], inline=False) + + await ctx.send(embed=embed) + + +async def setup(bot: GuardDen) -> None: + """Load the AI Moderation cog.""" + await bot.add_cog(AIModeration(bot)) diff --git a/src/guardden/cogs/automod.py b/src/guardden/cogs/automod.py new file mode 100644 index 0000000..a98a742 --- /dev/null +++ b/src/guardden/cogs/automod.py @@ -0,0 +1,267 @@ +"""Automod cog for automatic content moderation.""" + +import logging +from datetime import datetime, timedelta, timezone + +import discord +from discord.ext import commands + +from guardden.bot import GuardDen +from guardden.services.automod import AutomodResult, AutomodService + +logger = logging.getLogger(__name__) + + +class Automod(commands.Cog): + """Automatic content moderation.""" + + def __init__(self, bot: GuardDen) -> None: + self.bot = bot + self.automod = AutomodService() + + async def _handle_violation( + self, + message: discord.Message, + result: AutomodResult, + ) -> None: + """Handle an automod violation.""" + # Delete the message + if result.should_delete: + try: + await message.delete() + except discord.Forbidden: + logger.warning(f"Cannot delete message in {message.guild}: missing permissions") + except discord.NotFound: + pass # Already deleted + + # Apply timeout + if result.should_timeout and result.timeout_duration > 0: + try: + await message.author.timeout( + timedelta(seconds=result.timeout_duration), + reason=f"Automod: {result.reason}", + ) + except discord.Forbidden: + logger.warning(f"Cannot timeout {message.author}: missing permissions") + + # Log the action + await self._log_automod_action(message, result) + + # Notify the user via DM + try: + embed = discord.Embed( + title=f"Message Removed in {message.guild.name}", + description=result.reason, + color=discord.Color.orange(), + timestamp=datetime.now(timezone.utc), + ) + if result.should_timeout: + embed.add_field( + name="Timeout", + value=f"You have been timed out for {result.timeout_duration} seconds.", + ) + await message.author.send(embed=embed) + except discord.Forbidden: + pass # User has DMs disabled + + async def _log_automod_action( + self, + message: discord.Message, + result: AutomodResult, + ) -> None: + """Log an automod action to the mod log channel.""" + config = await self.bot.guild_config.get_config(message.guild.id) + if not config or not config.mod_log_channel_id: + return + + channel = message.guild.get_channel(config.mod_log_channel_id) + if not channel or not isinstance(channel, discord.TextChannel): + return + + embed = discord.Embed( + title="Automod Action", + color=discord.Color.orange(), + timestamp=datetime.now(timezone.utc), + ) + embed.set_author( + name=str(message.author), + icon_url=message.author.display_avatar.url, + ) + embed.add_field(name="Filter", value=result.matched_filter, inline=True) + embed.add_field(name="Channel", value=message.channel.mention, inline=True) + embed.add_field(name="Reason", value=result.reason, inline=False) + + if message.content: + content = ( + message.content[:500] + "..." if len(message.content) > 500 else message.content + ) + embed.add_field(name="Message Content", value=f"```{content}```", inline=False) + + actions = [] + if result.should_delete: + actions.append("Message deleted") + if result.should_warn: + actions.append("User warned") + if result.should_strike: + actions.append("Strike added") + if result.should_timeout: + actions.append(f"Timeout ({result.timeout_duration}s)") + + embed.add_field(name="Actions Taken", value=", ".join(actions) or "None", inline=False) + embed.set_footer(text=f"User ID: {message.author.id}") + + await channel.send(embed=embed) + + @commands.Cog.listener() + async def on_message(self, message: discord.Message) -> None: + """Check all messages for automod violations.""" + # Ignore DMs, bots, and empty messages + if not message.guild or message.author.bot or not message.content: + return + + # Ignore users with manage_messages permission + if isinstance(message.author, discord.Member): + if message.author.guild_permissions.manage_messages: + return + + # Get guild config + config = await self.bot.guild_config.get_config(message.guild.id) + if not config or not config.automod_enabled: + return + + result: AutomodResult | None = None + + # Check banned words + banned_words = await self.bot.guild_config.get_banned_words(message.guild.id) + if banned_words: + result = self.automod.check_banned_words(message.content, banned_words) + + # Check scam links (if link filter enabled) + if not result and config.link_filter_enabled: + result = self.automod.check_scam_links(message.content) + + # Check spam + if not result and config.anti_spam_enabled: + result = self.automod.check_spam(message, anti_spam_enabled=True) + + # Check invite links (if link filter enabled) + if not result and config.link_filter_enabled: + result = self.automod.check_invite_links(message.content, allow_invites=False) + + # Handle violation if found + if result: + logger.info( + f"Automod triggered in {message.guild.name}: " + f"{result.matched_filter} by {message.author}" + ) + await self._handle_violation(message, result) + + @commands.Cog.listener() + async def on_message_edit(self, before: discord.Message, after: discord.Message) -> None: + """Check edited messages for automod violations.""" + # Only check if content changed + if before.content == after.content: + return + + # Reuse on_message logic + await self.on_message(after) + + @commands.group(name="automod", invoke_without_command=True) + @commands.has_permissions(administrator=True) + @commands.guild_only() + async def automod_cmd(self, ctx: commands.Context) -> None: + """View automod status and configuration.""" + config = await self.bot.guild_config.get_config(ctx.guild.id) + + embed = discord.Embed( + title="Automod Configuration", + color=discord.Color.blue(), + ) + + embed.add_field( + name="Automod Enabled", + value="✅ Yes" if config and config.automod_enabled else "❌ No", + inline=True, + ) + embed.add_field( + name="Anti-Spam", + value="✅ Yes" if config and config.anti_spam_enabled else "❌ No", + inline=True, + ) + embed.add_field( + name="Link Filter", + value="✅ Yes" if config and config.link_filter_enabled else "❌ No", + inline=True, + ) + + # Show thresholds + embed.add_field( + name="Rate Limit", + value=f"{self.automod.message_rate_limit} msgs / {self.automod.message_rate_window}s", + inline=True, + ) + embed.add_field( + name="Duplicate Threshold", + value=f"{self.automod.duplicate_threshold} same messages", + inline=True, + ) + embed.add_field( + name="Mention Limit", + value=f"{self.automod.mention_limit} per message", + inline=True, + ) + + banned_words = await self.bot.guild_config.get_banned_words(ctx.guild.id) + embed.add_field( + name="Banned Words", + value=f"{len(banned_words)} configured", + inline=True, + ) + + await ctx.send(embed=embed) + + @automod_cmd.command(name="test") + @commands.has_permissions(administrator=True) + @commands.guild_only() + async def automod_test(self, ctx: commands.Context, *, text: str) -> None: + """Test a message against automod filters (does not take action).""" + config = await self.bot.guild_config.get_config(ctx.guild.id) + results = [] + + # Check banned words + banned_words = await self.bot.guild_config.get_banned_words(ctx.guild.id) + result = self.automod.check_banned_words(text, banned_words) + if result: + results.append(f"**Banned Words**: {result.reason}") + + # Check scam links + result = self.automod.check_scam_links(text) + if result: + results.append(f"**Scam Detection**: {result.reason}") + + # Check invite links + result = self.automod.check_invite_links(text, allow_invites=False) + if result: + results.append(f"**Invite Links**: {result.reason}") + + # Check caps + result = self.automod.check_all_caps(text) + if result: + results.append(f"**Excessive Caps**: {result.reason}") + + embed = discord.Embed( + title="Automod Test Results", + color=discord.Color.red() if results else discord.Color.green(), + ) + + if results: + embed.description = "\n".join(results) + else: + embed.description = "✅ No violations detected" + + await ctx.send(embed=embed) + + +async def setup(bot: GuardDen) -> None: + """Load the Automod cog.""" + await bot.add_cog(Automod(bot)) diff --git a/src/guardden/cogs/events.py b/src/guardden/cogs/events.py new file mode 100644 index 0000000..0d88aaf --- /dev/null +++ b/src/guardden/cogs/events.py @@ -0,0 +1,237 @@ +"""Event handlers for logging and monitoring.""" + +import logging +from datetime import datetime, timezone + +import discord +from discord.ext import commands + +from guardden.bot import GuardDen + +logger = logging.getLogger(__name__) + + +class Events(commands.Cog): + """Handles Discord events for logging and monitoring.""" + + def __init__(self, bot: GuardDen) -> None: + self.bot = bot + + @commands.Cog.listener() + async def on_member_join(self, member: discord.Member) -> None: + """Called when a member joins a guild.""" + logger.debug(f"Member joined: {member} in {member.guild}") + + config = await self.bot.guild_config.get_config(member.guild.id) + if not config or not config.log_channel_id: + return + + channel = member.guild.get_channel(config.log_channel_id) + if not channel or not isinstance(channel, discord.TextChannel): + return + + embed = discord.Embed( + title="Member Joined", + description=f"{member.mention} ({member})", + color=discord.Color.green(), + timestamp=datetime.now(timezone.utc), + ) + embed.set_thumbnail(url=member.display_avatar.url) + embed.add_field( + name="Account Created", value=discord.utils.format_dt(member.created_at, "R") + ) + embed.add_field(name="Member ID", value=str(member.id)) + + await channel.send(embed=embed) + + @commands.Cog.listener() + async def on_member_remove(self, member: discord.Member) -> None: + """Called when a member leaves a guild.""" + logger.debug(f"Member left: {member} from {member.guild}") + + config = await self.bot.guild_config.get_config(member.guild.id) + if not config or not config.log_channel_id: + return + + channel = member.guild.get_channel(config.log_channel_id) + if not channel or not isinstance(channel, discord.TextChannel): + return + + embed = discord.Embed( + title="Member Left", + description=f"{member} ({member.id})", + color=discord.Color.orange(), + timestamp=datetime.now(timezone.utc), + ) + embed.set_thumbnail(url=member.display_avatar.url) + + if member.joined_at: + embed.add_field(name="Joined", value=discord.utils.format_dt(member.joined_at, "R")) + + roles = [r.mention for r in member.roles if r != member.guild.default_role] + if roles: + embed.add_field(name="Roles", value=", ".join(roles[:10]), inline=False) + + await channel.send(embed=embed) + + @commands.Cog.listener() + async def on_message_delete(self, message: discord.Message) -> None: + """Called when a message is deleted.""" + if message.author.bot or not message.guild: + return + + config = await self.bot.guild_config.get_config(message.guild.id) + if not config or not config.log_channel_id: + return + + channel = message.guild.get_channel(config.log_channel_id) + if not channel or not isinstance(channel, discord.TextChannel): + return + + embed = discord.Embed( + title="Message Deleted", + description=f"In {message.channel.mention}", + color=discord.Color.red(), + timestamp=datetime.now(timezone.utc), + ) + embed.set_author(name=str(message.author), icon_url=message.author.display_avatar.url) + + if message.content: + content = message.content[:1024] if len(message.content) > 1024 else message.content + embed.add_field(name="Content", value=content, inline=False) + + if message.attachments: + attachments = "\n".join(a.filename for a in message.attachments) + embed.add_field(name="Attachments", value=attachments, inline=False) + + embed.set_footer(text=f"Author ID: {message.author.id} | Message ID: {message.id}") + + await channel.send(embed=embed) + + @commands.Cog.listener() + async def on_message_edit(self, before: discord.Message, after: discord.Message) -> None: + """Called when a message is edited.""" + if before.author.bot or not before.guild: + return + + if before.content == after.content: + return + + config = await self.bot.guild_config.get_config(before.guild.id) + if not config or not config.log_channel_id: + return + + channel = before.guild.get_channel(config.log_channel_id) + if not channel or not isinstance(channel, discord.TextChannel): + return + + embed = discord.Embed( + title="Message Edited", + description=f"In {before.channel.mention} | [Jump to message]({after.jump_url})", + color=discord.Color.blue(), + timestamp=datetime.now(timezone.utc), + ) + embed.set_author(name=str(before.author), icon_url=before.author.display_avatar.url) + + before_content = before.content[:1024] if len(before.content) > 1024 else before.content + after_content = after.content[:1024] if len(after.content) > 1024 else after.content + + embed.add_field(name="Before", value=before_content or "*empty*", inline=False) + embed.add_field(name="After", value=after_content or "*empty*", inline=False) + embed.set_footer(text=f"Author ID: {before.author.id}") + + await channel.send(embed=embed) + + @commands.Cog.listener() + async def on_voice_state_update( + self, + member: discord.Member, + before: discord.VoiceState, + after: discord.VoiceState, + ) -> None: + """Called when a member's voice state changes.""" + if member.bot: + return + + config = await self.bot.guild_config.get_config(member.guild.id) + if not config or not config.log_channel_id: + return + + channel = member.guild.get_channel(config.log_channel_id) + if not channel or not isinstance(channel, discord.TextChannel): + return + + embed = None + + if before.channel is None and after.channel is not None: + embed = discord.Embed( + title="Voice Channel Joined", + description=f"{member.mention} joined {after.channel.mention}", + color=discord.Color.green(), + timestamp=datetime.now(timezone.utc), + ) + elif before.channel is not None and after.channel is None: + embed = discord.Embed( + title="Voice Channel Left", + description=f"{member.mention} left {before.channel.mention}", + color=discord.Color.orange(), + timestamp=datetime.now(timezone.utc), + ) + elif before.channel != after.channel and before.channel and after.channel: + embed = discord.Embed( + title="Voice Channel Moved", + description=f"{member.mention} moved from {before.channel.mention} to {after.channel.mention}", + color=discord.Color.blue(), + timestamp=datetime.now(timezone.utc), + ) + + if embed: + embed.set_author(name=str(member), icon_url=member.display_avatar.url) + await channel.send(embed=embed) + + @commands.Cog.listener() + async def on_member_ban(self, guild: discord.Guild, user: discord.User) -> None: + """Called when a user is banned.""" + config = await self.bot.guild_config.get_config(guild.id) + if not config or not config.mod_log_channel_id: + return + + channel = guild.get_channel(config.mod_log_channel_id) + if not channel or not isinstance(channel, discord.TextChannel): + return + + embed = discord.Embed( + title="Member Banned", + description=f"{user} ({user.id})", + color=discord.Color.dark_red(), + timestamp=datetime.now(timezone.utc), + ) + embed.set_thumbnail(url=user.display_avatar.url) + + await channel.send(embed=embed) + + @commands.Cog.listener() + async def on_member_unban(self, guild: discord.Guild, user: discord.User) -> None: + """Called when a user is unbanned.""" + config = await self.bot.guild_config.get_config(guild.id) + if not config or not config.mod_log_channel_id: + return + + channel = guild.get_channel(config.mod_log_channel_id) + if not channel or not isinstance(channel, discord.TextChannel): + return + + embed = discord.Embed( + title="Member Unbanned", + description=f"{user} ({user.id})", + color=discord.Color.green(), + timestamp=datetime.now(timezone.utc), + ) + embed.set_thumbnail(url=user.display_avatar.url) + + await channel.send(embed=embed) + + +async def setup(bot: GuardDen) -> None: + """Load the Events cog.""" + await bot.add_cog(Events(bot)) diff --git a/src/guardden/cogs/moderation.py b/src/guardden/cogs/moderation.py new file mode 100644 index 0000000..571a324 --- /dev/null +++ b/src/guardden/cogs/moderation.py @@ -0,0 +1,466 @@ +"""Moderation commands and automod features.""" + +import logging +import re +from datetime import datetime, timedelta, timezone + +import discord +from discord.ext import commands +from sqlalchemy import func, select + +from guardden.bot import GuardDen +from guardden.models import ModerationLog, Strike + +logger = logging.getLogger(__name__) + + +def parse_duration(duration_str: str) -> timedelta | None: + """Parse a duration string like '1h', '30m', '7d' into a timedelta.""" + match = re.match(r"^(\d+)([smhdw])$", duration_str.lower()) + if not match: + return None + + amount = int(match.group(1)) + unit = match.group(2) + + units = { + "s": timedelta(seconds=amount), + "m": timedelta(minutes=amount), + "h": timedelta(hours=amount), + "d": timedelta(days=amount), + "w": timedelta(weeks=amount), + } + + return units.get(unit) + + +class Moderation(commands.Cog): + """Moderation commands for server management.""" + + def __init__(self, bot: GuardDen) -> None: + self.bot = bot + + async def _log_action( + self, + guild: discord.Guild, + target: discord.Member | discord.User, + moderator: discord.Member | discord.User, + action: str, + reason: str | None = None, + duration: int | None = None, + channel: discord.TextChannel | None = None, + message: discord.Message | None = None, + is_automatic: bool = False, + ) -> None: + """Log a moderation action to the database.""" + expires_at = None + if duration: + expires_at = datetime.now(timezone.utc) + timedelta(seconds=duration) + + async with self.bot.database.session() as session: + log_entry = ModerationLog( + guild_id=guild.id, + target_id=target.id, + target_name=str(target), + moderator_id=moderator.id, + moderator_name=str(moderator), + action=action, + reason=reason, + duration=duration, + expires_at=expires_at, + channel_id=channel.id if channel else None, + message_id=message.id if message else None, + message_content=message.content if message else None, + is_automatic=is_automatic, + ) + session.add(log_entry) + + async def _get_strike_count(self, guild_id: int, user_id: int) -> int: + """Get the total active strike count for a user.""" + async with self.bot.database.session() as session: + result = await session.execute( + select(func.sum(Strike.points)).where( + Strike.guild_id == guild_id, + Strike.user_id == user_id, + Strike.is_active == True, + ) + ) + total = result.scalar() + return total or 0 + + async def _add_strike( + self, + guild: discord.Guild, + user: discord.Member, + moderator: discord.Member | discord.User, + reason: str, + points: int = 1, + ) -> int: + """Add a strike to a user and return their new total.""" + async with self.bot.database.session() as session: + strike = Strike( + guild_id=guild.id, + user_id=user.id, + user_name=str(user), + moderator_id=moderator.id, + reason=reason, + points=points, + ) + session.add(strike) + + return await self._get_strike_count(guild.id, user.id) + + @commands.command(name="warn") + @commands.has_permissions(kick_members=True) + @commands.guild_only() + async def warn( + self, ctx: commands.Context, member: discord.Member, *, reason: str = "No reason provided" + ) -> None: + """Warn a member.""" + if member.top_role >= ctx.author.top_role and ctx.author != ctx.guild.owner: + await ctx.send("You cannot warn someone with a higher or equal role.") + return + + await self._log_action(ctx.guild, member, ctx.author, "warn", reason) + + embed = discord.Embed( + title="Warning Issued", + description=f"{member.mention} has been warned.", + color=discord.Color.yellow(), + timestamp=datetime.now(timezone.utc), + ) + embed.add_field(name="Reason", value=reason, inline=False) + embed.set_footer(text=f"Moderator: {ctx.author}") + + await ctx.send(embed=embed) + + # Try to DM the user + try: + dm_embed = discord.Embed( + title=f"Warning in {ctx.guild.name}", + description=f"You have been warned.", + color=discord.Color.yellow(), + ) + dm_embed.add_field(name="Reason", value=reason) + await member.send(embed=dm_embed) + except discord.Forbidden: + pass + + @commands.command(name="strike") + @commands.has_permissions(kick_members=True) + @commands.guild_only() + async def strike( + self, + ctx: commands.Context, + member: discord.Member, + points: int = 1, + *, + reason: str = "No reason provided", + ) -> None: + """Add a strike to a member.""" + if member.top_role >= ctx.author.top_role and ctx.author != ctx.guild.owner: + await ctx.send("You cannot strike someone with a higher or equal role.") + return + + total_strikes = await self._add_strike(ctx.guild, member, ctx.author, reason, points) + await self._log_action(ctx.guild, member, ctx.author, "strike", reason) + + embed = discord.Embed( + title="Strike Added", + description=f"{member.mention} has received {points} strike(s).", + color=discord.Color.orange(), + timestamp=datetime.now(timezone.utc), + ) + embed.add_field(name="Reason", value=reason, inline=False) + embed.add_field(name="Total Strikes", value=str(total_strikes)) + embed.set_footer(text=f"Moderator: {ctx.author}") + + await ctx.send(embed=embed) + + # Check for automatic actions based on strike thresholds + config = await self.bot.guild_config.get_config(ctx.guild.id) + if config and config.strike_actions: + for threshold, action_config in sorted( + config.strike_actions.items(), key=lambda x: int(x[0]), reverse=True + ): + if total_strikes >= int(threshold): + action = action_config.get("action") + if action == "ban": + await ctx.invoke( + self.ban, member=member, reason=f"Automatic: {total_strikes} strikes" + ) + elif action == "kick": + await ctx.invoke( + self.kick, member=member, reason=f"Automatic: {total_strikes} strikes" + ) + elif action == "timeout": + duration = action_config.get("duration", 3600) + await ctx.invoke( + self.timeout, + member=member, + duration=f"{duration}s", + reason=f"Automatic: {total_strikes} strikes", + ) + break + + @commands.command(name="strikes") + @commands.has_permissions(kick_members=True) + @commands.guild_only() + async def strikes(self, ctx: commands.Context, member: discord.Member) -> None: + """View strikes for a member.""" + async with self.bot.database.session() as session: + result = await session.execute( + select(Strike) + .where( + Strike.guild_id == ctx.guild.id, + Strike.user_id == member.id, + Strike.is_active == True, + ) + .order_by(Strike.created_at.desc()) + .limit(10) + ) + user_strikes = result.scalars().all() + + total = await self._get_strike_count(ctx.guild.id, member.id) + + embed = discord.Embed( + title=f"Strikes for {member}", + description=f"Total active strikes: **{total}**", + color=discord.Color.orange(), + ) + + if user_strikes: + for strike in user_strikes: + embed.add_field( + name=f"Strike #{strike.id} ({strike.points} pts)", + value=f"{strike.reason}\n*{strike.created_at.strftime('%Y-%m-%d')}*", + inline=False, + ) + else: + embed.description = f"{member.mention} has no active strikes." + + await ctx.send(embed=embed) + + @commands.command(name="timeout", aliases=["mute"]) + @commands.has_permissions(moderate_members=True) + @commands.guild_only() + async def timeout( + self, + ctx: commands.Context, + member: discord.Member, + duration: str = "1h", + *, + reason: str = "No reason provided", + ) -> None: + """Timeout a member (e.g., !timeout @user 1h Spamming).""" + if member.top_role >= ctx.author.top_role and ctx.author != ctx.guild.owner: + await ctx.send("You cannot timeout someone with a higher or equal role.") + return + + delta = parse_duration(duration) + if not delta: + await ctx.send("Invalid duration. Use format like: 30m, 1h, 7d") + return + + if delta > timedelta(days=28): + await ctx.send("Timeout duration cannot exceed 28 days.") + return + + try: + await member.timeout(delta, reason=f"{ctx.author}: {reason}") + except discord.Forbidden: + await ctx.send("I don't have permission to timeout this user.") + return + except discord.HTTPException as e: + await ctx.send(f"Failed to timeout user: {e}") + return + + await self._log_action( + ctx.guild, member, ctx.author, "timeout", reason, int(delta.total_seconds()) + ) + + embed = discord.Embed( + title="Member Timed Out", + description=f"{member.mention} has been timed out for {duration}.", + color=discord.Color.orange(), + timestamp=datetime.now(timezone.utc), + ) + embed.add_field(name="Reason", value=reason, inline=False) + embed.set_footer(text=f"Moderator: {ctx.author}") + + await ctx.send(embed=embed) + + @commands.command(name="untimeout", aliases=["unmute"]) + @commands.has_permissions(moderate_members=True) + @commands.guild_only() + async def untimeout( + self, ctx: commands.Context, member: discord.Member, *, reason: str = "No reason provided" + ) -> None: + """Remove timeout from a member.""" + await member.timeout(None, reason=f"{ctx.author}: {reason}") + await self._log_action(ctx.guild, member, ctx.author, "unmute", reason) + + embed = discord.Embed( + title="Timeout Removed", + description=f"{member.mention}'s timeout has been removed.", + color=discord.Color.green(), + timestamp=datetime.now(timezone.utc), + ) + embed.add_field(name="Reason", value=reason, inline=False) + embed.set_footer(text=f"Moderator: {ctx.author}") + + await ctx.send(embed=embed) + + @commands.command(name="kick") + @commands.has_permissions(kick_members=True) + @commands.guild_only() + async def kick( + self, ctx: commands.Context, member: discord.Member, *, reason: str = "No reason provided" + ) -> None: + """Kick a member from the server.""" + if member.top_role >= ctx.author.top_role and ctx.author != ctx.guild.owner: + await ctx.send("You cannot kick someone with a higher or equal role.") + return + + # Try to DM the user before kicking + try: + dm_embed = discord.Embed( + title=f"Kicked from {ctx.guild.name}", + description=f"You have been kicked from the server.", + color=discord.Color.red(), + ) + dm_embed.add_field(name="Reason", value=reason) + await member.send(embed=dm_embed) + except discord.Forbidden: + pass + + await member.kick(reason=f"{ctx.author}: {reason}") + await self._log_action(ctx.guild, member, ctx.author, "kick", reason) + + embed = discord.Embed( + title="Member Kicked", + description=f"{member} has been kicked from the server.", + color=discord.Color.red(), + timestamp=datetime.now(timezone.utc), + ) + embed.add_field(name="Reason", value=reason, inline=False) + embed.set_footer(text=f"Moderator: {ctx.author}") + + await ctx.send(embed=embed) + + @commands.command(name="ban") + @commands.has_permissions(ban_members=True) + @commands.guild_only() + async def ban( + self, + ctx: commands.Context, + member: discord.Member | discord.User, + *, + reason: str = "No reason provided", + ) -> None: + """Ban a member from the server.""" + if isinstance(member, discord.Member): + if member.top_role >= ctx.author.top_role and ctx.author != ctx.guild.owner: + await ctx.send("You cannot ban someone with a higher or equal role.") + return + + # Try to DM the user before banning + try: + dm_embed = discord.Embed( + title=f"Banned from {ctx.guild.name}", + description=f"You have been banned from the server.", + color=discord.Color.dark_red(), + ) + dm_embed.add_field(name="Reason", value=reason) + await member.send(embed=dm_embed) + except discord.Forbidden: + pass + + await ctx.guild.ban(member, reason=f"{ctx.author}: {reason}", delete_message_days=0) + await self._log_action(ctx.guild, member, ctx.author, "ban", reason) + + embed = discord.Embed( + title="Member Banned", + description=f"{member} has been banned from the server.", + color=discord.Color.dark_red(), + timestamp=datetime.now(timezone.utc), + ) + embed.add_field(name="Reason", value=reason, inline=False) + embed.set_footer(text=f"Moderator: {ctx.author}") + + await ctx.send(embed=embed) + + @commands.command(name="unban") + @commands.has_permissions(ban_members=True) + @commands.guild_only() + async def unban( + self, ctx: commands.Context, user_id: int, *, reason: str = "No reason provided" + ) -> None: + """Unban a user by their ID.""" + try: + user = await self.bot.fetch_user(user_id) + await ctx.guild.unban(user, reason=f"{ctx.author}: {reason}") + await self._log_action(ctx.guild, user, ctx.author, "unban", reason) + + embed = discord.Embed( + title="User Unbanned", + description=f"{user} has been unbanned.", + color=discord.Color.green(), + timestamp=datetime.now(timezone.utc), + ) + embed.add_field(name="Reason", value=reason, inline=False) + embed.set_footer(text=f"Moderator: {ctx.author}") + + await ctx.send(embed=embed) + + except discord.NotFound: + await ctx.send("User not found or not banned.") + except discord.Forbidden: + await ctx.send("I don't have permission to unban this user.") + + @commands.command(name="purge", aliases=["clear"]) + @commands.has_permissions(manage_messages=True) + @commands.guild_only() + async def purge(self, ctx: commands.Context, amount: int) -> None: + """Delete multiple messages at once (max 100).""" + if amount < 1 or amount > 100: + await ctx.send("Please specify a number between 1 and 100.") + return + + deleted = await ctx.channel.purge(limit=amount + 1) # +1 to include the command message + + msg = await ctx.send(f"Deleted {len(deleted) - 1} message(s).") + await msg.delete(delay=3) + + @commands.command(name="modlogs", aliases=["history"]) + @commands.has_permissions(kick_members=True) + @commands.guild_only() + async def modlogs(self, ctx: commands.Context, member: discord.Member | discord.User) -> None: + """View moderation history for a user.""" + async with self.bot.database.session() as session: + result = await session.execute( + select(ModerationLog) + .where(ModerationLog.guild_id == ctx.guild.id, ModerationLog.target_id == member.id) + .order_by(ModerationLog.created_at.desc()) + .limit(10) + ) + logs = result.scalars().all() + + embed = discord.Embed( + title=f"Moderation History for {member}", + color=discord.Color.blue(), + ) + + if logs: + for log in logs: + value = f"**Reason:** {log.reason or 'None'}\n**By:** {log.moderator_name}\n*{log.created_at.strftime('%Y-%m-%d %H:%M')}*" + embed.add_field(name=f"{log.action.upper()} (#{log.id})", value=value, inline=False) + else: + embed.description = "No moderation history found." + + await ctx.send(embed=embed) + + +async def setup(bot: GuardDen) -> None: + """Load the Moderation cog.""" + await bot.add_cog(Moderation(bot)) diff --git a/src/guardden/cogs/verification.py b/src/guardden/cogs/verification.py new file mode 100644 index 0000000..5c5b72b --- /dev/null +++ b/src/guardden/cogs/verification.py @@ -0,0 +1,423 @@ +"""Verification cog for new member verification.""" + +import logging +from datetime import datetime, timezone + +import discord +from discord import ui +from discord.ext import commands, tasks + +from guardden.bot import GuardDen +from guardden.services.verification import ( + ChallengeType, + PendingVerification, + VerificationService, +) + +logger = logging.getLogger(__name__) + + +class VerifyButton(ui.Button["VerificationView"]): + """Button for simple verification.""" + + def __init__(self) -> None: + super().__init__( + style=discord.ButtonStyle.success, + label="Verify", + custom_id="verify_button", + ) + + async def callback(self, interaction: discord.Interaction) -> None: + if self.view is None: + return + + success, message = await self.view.cog.complete_verification( + interaction.guild.id, + interaction.user.id, + "verified", + ) + + if success: + await interaction.response.send_message(message, ephemeral=True) + # Disable the button + self.disabled = True + self.label = "Verified" + await interaction.message.edit(view=self.view) + else: + await interaction.response.send_message(message, ephemeral=True) + + +class EmojiButton(ui.Button["EmojiVerificationView"]): + """Button for emoji selection verification.""" + + def __init__(self, emoji: str, row: int = 0) -> None: + super().__init__( + style=discord.ButtonStyle.secondary, + label=emoji, + custom_id=f"emoji_{emoji}", + row=row, + ) + self.emoji_value = emoji + + async def callback(self, interaction: discord.Interaction) -> None: + if self.view is None: + return + + success, message = await self.view.cog.complete_verification( + interaction.guild.id, + interaction.user.id, + self.emoji_value, + ) + + if success: + await interaction.response.send_message(message, ephemeral=True) + # Disable all buttons + for item in self.view.children: + if isinstance(item, ui.Button): + item.disabled = True + await interaction.message.edit(view=self.view) + else: + await interaction.response.send_message(message, ephemeral=True) + + +class VerificationView(ui.View): + """View for button verification.""" + + def __init__(self, cog: "Verification", timeout: float = 600) -> None: + super().__init__(timeout=timeout) + self.cog = cog + self.add_item(VerifyButton()) + + +class EmojiVerificationView(ui.View): + """View for emoji selection verification.""" + + def __init__(self, cog: "Verification", options: list[str], timeout: float = 600) -> None: + super().__init__(timeout=timeout) + self.cog = cog + for i, emoji in enumerate(options): + self.add_item(EmojiButton(emoji, row=i // 4)) + + +class CaptchaModal(ui.Modal): + """Modal for captcha/math input.""" + + answer = ui.TextInput( + label="Your Answer", + placeholder="Enter the answer here...", + max_length=50, + ) + + def __init__(self, cog: "Verification", title: str = "Verification") -> None: + super().__init__(title=title) + self.cog = cog + + async def on_submit(self, interaction: discord.Interaction) -> None: + success, message = await self.cog.complete_verification( + interaction.guild.id, + interaction.user.id, + self.answer.value, + ) + await interaction.response.send_message(message, ephemeral=True) + + +class AnswerButton(ui.Button["AnswerView"]): + """Button to open the answer modal.""" + + def __init__(self) -> None: + super().__init__( + style=discord.ButtonStyle.primary, + label="Submit Answer", + custom_id="submit_answer", + ) + + async def callback(self, interaction: discord.Interaction) -> None: + if self.view is None: + return + modal = CaptchaModal(self.view.cog) + await interaction.response.send_modal(modal) + + +class AnswerView(ui.View): + """View with button to open answer modal.""" + + def __init__(self, cog: "Verification", timeout: float = 600) -> None: + super().__init__(timeout=timeout) + self.cog = cog + self.add_item(AnswerButton()) + + +class Verification(commands.Cog): + """Member verification system.""" + + def __init__(self, bot: GuardDen) -> None: + self.bot = bot + self.service = VerificationService() + self.cleanup_task.start() + + def cog_unload(self) -> None: + self.cleanup_task.cancel() + + @tasks.loop(minutes=5) + async def cleanup_task(self) -> None: + """Periodically clean up expired verifications.""" + count = self.service.cleanup_expired() + if count > 0: + logger.debug(f"Cleaned up {count} expired verifications") + + @cleanup_task.before_loop + async def before_cleanup(self) -> None: + await self.bot.wait_until_ready() + + async def complete_verification( + self, guild_id: int, user_id: int, response: str + ) -> tuple[bool, str]: + """Complete a verification and assign role if successful.""" + success, message = self.service.verify(guild_id, user_id, response) + + if success: + # Assign verified role + guild = self.bot.get_guild(guild_id) + if guild: + member = guild.get_member(user_id) + config = await self.bot.guild_config.get_config(guild_id) + + if member and config and config.verified_role_id: + role = guild.get_role(config.verified_role_id) + if role: + try: + await member.add_roles(role, reason="Verification completed") + logger.info(f"Verified {member} in {guild.name}") + except discord.Forbidden: + logger.warning(f"Cannot assign verified role in {guild.name}") + + return success, message + + async def send_verification( + self, + member: discord.Member, + channel: discord.TextChannel, + challenge_type: ChallengeType, + ) -> None: + """Send a verification challenge to a member.""" + pending = self.service.create_challenge( + user_id=member.id, + guild_id=member.guild.id, + challenge_type=challenge_type, + ) + + embed = discord.Embed( + title="Verification Required", + description=pending.challenge.question, + color=discord.Color.blue(), + timestamp=datetime.now(timezone.utc), + ) + embed.set_footer( + text=f"Expires in 10 minutes • {pending.challenge.max_attempts} attempts allowed" + ) + + # Create appropriate view based on challenge type + if challenge_type == ChallengeType.BUTTON: + view = VerificationView(self) + elif challenge_type == ChallengeType.EMOJI: + view = EmojiVerificationView(self, pending.challenge.options) + else: + # Captcha or Math - use modal + view = AnswerView(self) + + try: + # Try to DM the user first + dm_channel = await member.create_dm() + msg = await dm_channel.send(embed=embed, view=view) + pending.message_id = msg.id + pending.channel_id = dm_channel.id + except discord.Forbidden: + # Fall back to channel mention + msg = await channel.send( + content=member.mention, + embed=embed, + view=view, + ) + pending.message_id = msg.id + pending.channel_id = channel.id + + @commands.Cog.listener() + async def on_member_join(self, member: discord.Member) -> None: + """Handle new member joins for verification.""" + if member.bot: + return + + config = await self.bot.guild_config.get_config(member.guild.id) + if not config or not config.verification_enabled: + return + + # Determine verification channel + channel_id = config.welcome_channel_id or config.log_channel_id + if not channel_id: + return + + channel = member.guild.get_channel(channel_id) + if not channel or not isinstance(channel, discord.TextChannel): + return + + # Get challenge type from config + try: + challenge_type = ChallengeType(config.verification_type) + except ValueError: + challenge_type = ChallengeType.BUTTON + + await self.send_verification(member, channel, challenge_type) + + @commands.group(name="verify", invoke_without_command=True) + @commands.guild_only() + async def verify_cmd(self, ctx: commands.Context) -> None: + """Request a verification challenge.""" + config = await self.bot.guild_config.get_config(ctx.guild.id) + + if not config or not config.verification_enabled: + await ctx.send("Verification is not enabled on this server.") + return + + # Check if already verified + if config.verified_role_id: + role = ctx.guild.get_role(config.verified_role_id) + if role and role in ctx.author.roles: + await ctx.send("You are already verified!") + return + + # Check for existing pending verification + pending = self.service.get_pending(ctx.guild.id, ctx.author.id) + if pending and not pending.challenge.is_expired: + await ctx.send("You already have a pending verification. Please complete it first.") + return + + # Get challenge type + try: + challenge_type = ChallengeType(config.verification_type) + except ValueError: + challenge_type = ChallengeType.BUTTON + + await self.send_verification(ctx.author, ctx.channel, challenge_type) + await ctx.message.delete(delay=1) + + @verify_cmd.command(name="setup") + @commands.has_permissions(administrator=True) + @commands.guild_only() + async def verify_setup(self, ctx: commands.Context) -> None: + """View verification setup status.""" + config = await self.bot.guild_config.get_config(ctx.guild.id) + + embed = discord.Embed( + title="Verification Setup", + color=discord.Color.blue(), + ) + + embed.add_field( + name="Enabled", + value="✅ Yes" if config and config.verification_enabled else "❌ No", + inline=True, + ) + embed.add_field( + name="Type", + value=config.verification_type if config else "button", + inline=True, + ) + + if config and config.verified_role_id: + role = ctx.guild.get_role(config.verified_role_id) + embed.add_field( + name="Verified Role", + value=role.mention if role else "Not found", + inline=True, + ) + else: + embed.add_field(name="Verified Role", value="Not set", inline=True) + + pending_count = self.service.get_pending_count(ctx.guild.id) + embed.add_field(name="Pending Verifications", value=str(pending_count), inline=True) + + await ctx.send(embed=embed) + + @verify_cmd.command(name="enable") + @commands.has_permissions(administrator=True) + @commands.guild_only() + async def verify_enable(self, ctx: commands.Context) -> None: + """Enable verification for new members.""" + config = await self.bot.guild_config.get_config(ctx.guild.id) + + if not config or not config.verified_role_id: + await ctx.send("Please set a verified role first with `!verify role @role`") + return + + await self.bot.guild_config.update_settings(ctx.guild.id, verification_enabled=True) + await ctx.send("✅ Verification enabled for new members.") + + @verify_cmd.command(name="disable") + @commands.has_permissions(administrator=True) + @commands.guild_only() + async def verify_disable(self, ctx: commands.Context) -> None: + """Disable verification.""" + await self.bot.guild_config.update_settings(ctx.guild.id, verification_enabled=False) + await ctx.send("❌ Verification disabled.") + + @verify_cmd.command(name="role") + @commands.has_permissions(administrator=True) + @commands.guild_only() + async def verify_role(self, ctx: commands.Context, role: discord.Role) -> None: + """Set the role given upon verification.""" + await self.bot.guild_config.update_settings(ctx.guild.id, verified_role_id=role.id) + await ctx.send(f"Verified role set to {role.mention}") + + @verify_cmd.command(name="type") + @commands.has_permissions(administrator=True) + @commands.guild_only() + async def verify_type(self, ctx: commands.Context, vtype: str) -> None: + """Set verification type (button, captcha, math, emoji).""" + try: + challenge_type = ChallengeType(vtype.lower()) + except ValueError: + valid = ", ".join(t.value for t in ChallengeType if t != ChallengeType.QUESTIONS) + await ctx.send(f"Invalid type. Valid options: {valid}") + return + + await self.bot.guild_config.update_settings( + ctx.guild.id, verification_type=challenge_type.value + ) + await ctx.send(f"Verification type set to **{challenge_type.value}**") + + @verify_cmd.command(name="test") + @commands.has_permissions(administrator=True) + @commands.guild_only() + async def verify_test(self, ctx: commands.Context, vtype: str = "button") -> None: + """Test verification (sends challenge to you).""" + try: + challenge_type = ChallengeType(vtype.lower()) + except ValueError: + challenge_type = ChallengeType.BUTTON + + await self.send_verification(ctx.author, ctx.channel, challenge_type) + + @verify_cmd.command(name="reset") + @commands.has_permissions(kick_members=True) + @commands.guild_only() + async def verify_reset(self, ctx: commands.Context, member: discord.Member) -> None: + """Reset verification for a member (remove role and cancel pending).""" + # Cancel any pending verification + self.service.cancel(ctx.guild.id, member.id) + + # Remove verified role + config = await self.bot.guild_config.get_config(ctx.guild.id) + if config and config.verified_role_id: + role = ctx.guild.get_role(config.verified_role_id) + if role and role in member.roles: + try: + await member.remove_roles(role, reason=f"Verification reset by {ctx.author}") + except discord.Forbidden: + pass + + await ctx.send(f"Reset verification for {member.mention}") + + +async def setup(bot: GuardDen) -> None: + """Load the Verification cog.""" + await bot.add_cog(Verification(bot)) diff --git a/src/guardden/config.py b/src/guardden/config.py new file mode 100644 index 0000000..b1fd160 --- /dev/null +++ b/src/guardden/config.py @@ -0,0 +1,50 @@ +"""Configuration management for GuardDen.""" + +from pathlib import Path +from typing import Literal + +from pydantic import Field, SecretStr +from pydantic_settings import BaseSettings, SettingsConfigDict + + +class Settings(BaseSettings): + """Application settings loaded from environment variables.""" + + model_config = SettingsConfigDict( + env_file=".env", + env_file_encoding="utf-8", + case_sensitive=False, + env_prefix="GUARDDEN_", + ) + + # Discord settings + discord_token: SecretStr = Field(..., description="Discord bot token") + discord_prefix: str = Field(default="!", description="Default command prefix") + + # Database settings + database_url: SecretStr = Field( + default=SecretStr("postgresql://guardden:guardden@localhost:5432/guardden"), + description="PostgreSQL connection URL", + ) + database_pool_min: int = Field(default=5, description="Minimum database pool size") + database_pool_max: int = Field(default=20, description="Maximum database pool size") + + # AI settings (optional) + ai_provider: Literal["anthropic", "openai", "none"] = Field( + default="none", description="AI provider for content moderation" + ) + anthropic_api_key: SecretStr | None = Field(default=None, description="Anthropic API key") + openai_api_key: SecretStr | None = Field(default=None, description="OpenAI API key") + + # Logging + log_level: Literal["DEBUG", "INFO", "WARNING", "ERROR"] = Field( + default="INFO", description="Logging level" + ) + + # Paths + data_dir: Path = Field(default=Path("data"), description="Data directory for persistent files") + + +def get_settings() -> Settings: + """Get application settings instance.""" + return Settings() diff --git a/src/guardden/models/__init__.py b/src/guardden/models/__init__.py new file mode 100644 index 0000000..c14c949 --- /dev/null +++ b/src/guardden/models/__init__.py @@ -0,0 +1,15 @@ +"""Database models for GuardDen.""" + +from guardden.models.base import Base +from guardden.models.guild import BannedWord, Guild, GuildSettings +from guardden.models.moderation import ModerationLog, Strike, UserNote + +__all__ = [ + "Base", + "Guild", + "GuildSettings", + "BannedWord", + "ModerationLog", + "Strike", + "UserNote", +] diff --git a/src/guardden/models/base.py b/src/guardden/models/base.py new file mode 100644 index 0000000..aef6902 --- /dev/null +++ b/src/guardden/models/base.py @@ -0,0 +1,32 @@ +"""Base model and database utilities.""" + +from datetime import datetime + +from sqlalchemy import BigInteger, DateTime, func +from sqlalchemy.orm import DeclarativeBase, Mapped, mapped_column + + +class Base(DeclarativeBase): + """Base class for all database models.""" + + pass + + +class TimestampMixin: + """Mixin that adds created_at and updated_at timestamps.""" + + created_at: Mapped[datetime] = mapped_column( + DateTime(timezone=True), + server_default=func.now(), + nullable=False, + ) + updated_at: Mapped[datetime] = mapped_column( + DateTime(timezone=True), + server_default=func.now(), + onupdate=func.now(), + nullable=False, + ) + + +# Type alias for Discord snowflake IDs (64-bit integers) +SnowflakeID = BigInteger diff --git a/src/guardden/models/guild.py b/src/guardden/models/guild.py new file mode 100644 index 0000000..b146b70 --- /dev/null +++ b/src/guardden/models/guild.py @@ -0,0 +1,117 @@ +"""Guild-related database models.""" + +from datetime import datetime +from typing import TYPE_CHECKING + +from sqlalchemy import Boolean, ForeignKey, Integer, String, Text +from sqlalchemy.dialects.postgresql import JSONB +from sqlalchemy.orm import Mapped, mapped_column, relationship + +from guardden.models.base import Base, SnowflakeID, TimestampMixin + +if TYPE_CHECKING: + from guardden.models.moderation import ModerationLog, Strike + + +class Guild(Base, TimestampMixin): + """Represents a Discord guild (server) configuration.""" + + __tablename__ = "guilds" + + id: Mapped[int] = mapped_column(SnowflakeID, primary_key=True) + name: Mapped[str] = mapped_column(String(100), nullable=False) + owner_id: Mapped[int] = mapped_column(SnowflakeID, nullable=False) + premium: Mapped[bool] = mapped_column(Boolean, default=False, nullable=False) + + # Relationships + settings: Mapped["GuildSettings"] = relationship( + back_populates="guild", uselist=False, cascade="all, delete-orphan" + ) + banned_words: Mapped[list["BannedWord"]] = relationship( + back_populates="guild", cascade="all, delete-orphan" + ) + moderation_logs: Mapped[list["ModerationLog"]] = relationship( + back_populates="guild", cascade="all, delete-orphan" + ) + strikes: Mapped[list["Strike"]] = relationship( + back_populates="guild", cascade="all, delete-orphan" + ) + + +class GuildSettings(Base, TimestampMixin): + """Per-guild bot settings and configuration.""" + + __tablename__ = "guild_settings" + + guild_id: Mapped[int] = mapped_column( + SnowflakeID, ForeignKey("guilds.id", ondelete="CASCADE"), primary_key=True + ) + + # General settings + prefix: Mapped[str] = mapped_column(String(10), default="!", nullable=False) + locale: Mapped[str] = mapped_column(String(10), default="en", nullable=False) + + # Channel configuration (stored as snowflake IDs) + log_channel_id: Mapped[int | None] = mapped_column(SnowflakeID, nullable=True) + mod_log_channel_id: Mapped[int | None] = mapped_column(SnowflakeID, nullable=True) + welcome_channel_id: Mapped[int | None] = mapped_column(SnowflakeID, nullable=True) + + # Role configuration + mute_role_id: Mapped[int | None] = mapped_column(SnowflakeID, nullable=True) + verified_role_id: Mapped[int | None] = mapped_column(SnowflakeID, nullable=True) + mod_role_ids: Mapped[dict] = mapped_column(JSONB, default=list, nullable=False) + + # Moderation settings + automod_enabled: Mapped[bool] = mapped_column(Boolean, default=True, nullable=False) + anti_spam_enabled: Mapped[bool] = mapped_column(Boolean, default=True, nullable=False) + link_filter_enabled: Mapped[bool] = mapped_column(Boolean, default=False, nullable=False) + + # Strike thresholds (actions at each threshold) + strike_actions: Mapped[dict] = mapped_column( + JSONB, + default=lambda: { + "1": {"action": "warn"}, + "3": {"action": "timeout", "duration": 3600}, + "5": {"action": "kick"}, + "7": {"action": "ban"}, + }, + nullable=False, + ) + + # AI moderation settings + ai_moderation_enabled: Mapped[bool] = mapped_column(Boolean, default=False, nullable=False) + ai_sensitivity: Mapped[int] = mapped_column(Integer, default=50, nullable=False) # 0-100 scale + nsfw_detection_enabled: Mapped[bool] = mapped_column(Boolean, default=False, nullable=False) + + # Verification settings + verification_enabled: Mapped[bool] = mapped_column(Boolean, default=False, nullable=False) + verification_type: Mapped[str] = mapped_column( + String(20), default="button", nullable=False + ) # button, captcha, questions + + # Relationship + guild: Mapped["Guild"] = relationship(back_populates="settings") + + +class BannedWord(Base, TimestampMixin): + """Banned words/phrases for a guild with regex support.""" + + __tablename__ = "banned_words" + + id: Mapped[int] = mapped_column(Integer, primary_key=True, autoincrement=True) + guild_id: Mapped[int] = mapped_column( + SnowflakeID, ForeignKey("guilds.id", ondelete="CASCADE"), nullable=False + ) + + pattern: Mapped[str] = mapped_column(Text, nullable=False) + is_regex: Mapped[bool] = mapped_column(Boolean, default=False, nullable=False) + action: Mapped[str] = mapped_column( + String(20), default="delete", nullable=False + ) # delete, warn, strike + reason: Mapped[str | None] = mapped_column(Text, nullable=True) + + # Who added this and when + added_by: Mapped[int] = mapped_column(SnowflakeID, nullable=False) + + # Relationship + guild: Mapped["Guild"] = relationship(back_populates="banned_words") diff --git a/src/guardden/models/moderation.py b/src/guardden/models/moderation.py new file mode 100644 index 0000000..e4f8744 --- /dev/null +++ b/src/guardden/models/moderation.py @@ -0,0 +1,101 @@ +"""Moderation-related database models.""" + +from datetime import datetime +from enum import Enum + +from sqlalchemy import Boolean, DateTime, ForeignKey, Integer, String, Text +from sqlalchemy.orm import Mapped, mapped_column, relationship + +from guardden.models.base import Base, SnowflakeID, TimestampMixin +from guardden.models.guild import Guild + + +class ModAction(str, Enum): + """Types of moderation actions.""" + + WARN = "warn" + TIMEOUT = "timeout" + KICK = "kick" + BAN = "ban" + UNBAN = "unban" + UNMUTE = "unmute" + NOTE = "note" + STRIKE = "strike" + DELETE = "delete" + + +class ModerationLog(Base, TimestampMixin): + """Log of all moderation actions taken.""" + + __tablename__ = "moderation_logs" + + id: Mapped[int] = mapped_column(Integer, primary_key=True, autoincrement=True) + guild_id: Mapped[int] = mapped_column( + SnowflakeID, ForeignKey("guilds.id", ondelete="CASCADE"), nullable=False + ) + + # Target and moderator + target_id: Mapped[int] = mapped_column(SnowflakeID, nullable=False) + target_name: Mapped[str] = mapped_column(String(100), nullable=False) + moderator_id: Mapped[int] = mapped_column(SnowflakeID, nullable=False) + moderator_name: Mapped[str] = mapped_column(String(100), nullable=False) + + # Action details + action: Mapped[str] = mapped_column(String(20), nullable=False) + reason: Mapped[str | None] = mapped_column(Text, nullable=True) + duration: Mapped[int | None] = mapped_column(Integer, nullable=True) # Duration in seconds + expires_at: Mapped[datetime | None] = mapped_column(DateTime(timezone=True), nullable=True) + + # Context + channel_id: Mapped[int | None] = mapped_column(SnowflakeID, nullable=True) + message_id: Mapped[int | None] = mapped_column(SnowflakeID, nullable=True) + message_content: Mapped[str | None] = mapped_column(Text, nullable=True) + + # Was this an automatic action? + is_automatic: Mapped[bool] = mapped_column(Boolean, default=False, nullable=False) + + # Relationship + guild: Mapped["Guild"] = relationship(back_populates="moderation_logs") + + +class Strike(Base, TimestampMixin): + """User strikes/warnings tracking.""" + + __tablename__ = "strikes" + + id: Mapped[int] = mapped_column(Integer, primary_key=True, autoincrement=True) + guild_id: Mapped[int] = mapped_column( + SnowflakeID, ForeignKey("guilds.id", ondelete="CASCADE"), nullable=False + ) + + user_id: Mapped[int] = mapped_column(SnowflakeID, nullable=False) + user_name: Mapped[str] = mapped_column(String(100), nullable=False) + moderator_id: Mapped[int] = mapped_column(SnowflakeID, nullable=False) + + reason: Mapped[str] = mapped_column(Text, nullable=False) + points: Mapped[int] = mapped_column(Integer, default=1, nullable=False) + + # Strikes can expire + expires_at: Mapped[datetime | None] = mapped_column(DateTime(timezone=True), nullable=True) + is_active: Mapped[bool] = mapped_column(Boolean, default=True, nullable=False) + + # Reference to the moderation log entry + mod_log_id: Mapped[int | None] = mapped_column( + Integer, ForeignKey("moderation_logs.id", ondelete="SET NULL"), nullable=True + ) + + # Relationship + guild: Mapped["Guild"] = relationship(back_populates="strikes") + + +class UserNote(Base, TimestampMixin): + """Moderator notes on users.""" + + __tablename__ = "user_notes" + + id: Mapped[int] = mapped_column(Integer, primary_key=True, autoincrement=True) + guild_id: Mapped[int] = mapped_column(SnowflakeID, nullable=False) + + user_id: Mapped[int] = mapped_column(SnowflakeID, nullable=False) + moderator_id: Mapped[int] = mapped_column(SnowflakeID, nullable=False) + content: Mapped[str] = mapped_column(Text, nullable=False) diff --git a/src/guardden/services/__init__.py b/src/guardden/services/__init__.py new file mode 100644 index 0000000..7cfdd5f --- /dev/null +++ b/src/guardden/services/__init__.py @@ -0,0 +1,16 @@ +"""Services for GuardDen.""" + +from guardden.services.automod import AutomodService +from guardden.services.database import Database +from guardden.services.ratelimit import RateLimiter, get_rate_limiter, ratelimit +from guardden.services.verification import ChallengeType, VerificationService + +__all__ = [ + "AutomodService", + "ChallengeType", + "Database", + "RateLimiter", + "VerificationService", + "get_rate_limiter", + "ratelimit", +] diff --git a/src/guardden/services/ai/__init__.py b/src/guardden/services/ai/__init__.py new file mode 100644 index 0000000..1585fad --- /dev/null +++ b/src/guardden/services/ai/__init__.py @@ -0,0 +1,6 @@ +"""AI services for content moderation.""" + +from guardden.services.ai.base import AIProvider, ModerationResult +from guardden.services.ai.factory import create_ai_provider + +__all__ = ["AIProvider", "ModerationResult", "create_ai_provider"] diff --git a/src/guardden/services/ai/anthropic_provider.py b/src/guardden/services/ai/anthropic_provider.py new file mode 100644 index 0000000..3b3e9ff --- /dev/null +++ b/src/guardden/services/ai/anthropic_provider.py @@ -0,0 +1,261 @@ +"""Anthropic Claude AI provider implementation.""" + +import logging +from typing import Any + +from guardden.services.ai.base import ( + AIProvider, + ContentCategory, + ImageAnalysisResult, + ModerationResult, + PhishingAnalysisResult, +) + +logger = logging.getLogger(__name__) + +# Content moderation system prompt +MODERATION_SYSTEM_PROMPT = """You are a content moderation AI for a Discord server. Analyze the given message and determine if it violates community guidelines. + +Categories to check: +- harassment: Personal attacks, bullying, intimidation +- hate_speech: Discrimination, slurs, dehumanization based on identity +- sexual: Explicit sexual content, sexual solicitation +- violence: Threats, graphic violence, encouraging harm +- self_harm: Suicide, self-injury content or encouragement +- spam: Repetitive, promotional, or low-quality content +- scam: Phishing attempts, fraudulent offers, impersonation +- misinformation: Dangerous false information + +Respond in this exact JSON format: +{ + "is_flagged": true/false, + "confidence": 0.0-1.0, + "categories": ["category1", "category2"], + "explanation": "Brief explanation", + "suggested_action": "none/warn/delete/timeout/ban" +} + +Be balanced - flag genuinely problematic content but allow normal conversation, jokes, and mild language. Consider context.""" + +IMAGE_ANALYSIS_PROMPT = """Analyze this image for content moderation purposes. Check for: +- NSFW content (nudity, sexual content) +- Violence or gore +- Disturbing or shocking content +- Any content inappropriate for a general audience + +Respond in this exact JSON format: +{ + "is_nsfw": true/false, + "is_violent": true/false, + "is_disturbing": true/false, + "confidence": 0.0-1.0, + "description": "Brief description of the image", + "categories": ["category1", "category2"] +} + +Be accurate but not overly sensitive - artistic nudity or mild violence in appropriate contexts may be acceptable.""" + +PHISHING_ANALYSIS_PROMPT = """Analyze this URL and message context for phishing or scam indicators. + +Check for: +- Domain impersonation (typosquatting, lookalike domains) +- Urgency tactics ("act now", "limited time") +- Requests for credentials or personal info +- Too-good-to-be-true offers +- Suspicious redirects or URL shorteners +- Mismatched or hidden URLs + +Respond in this exact JSON format: +{ + "is_phishing": true/false, + "confidence": 0.0-1.0, + "risk_factors": ["factor1", "factor2"], + "explanation": "Brief explanation" +}""" + + +class AnthropicProvider(AIProvider): + """AI provider using Anthropic's Claude API.""" + + def __init__(self, api_key: str, model: str = "claude-3-haiku-20240307") -> None: + """ + Initialize Anthropic provider. + + Args: + api_key: Anthropic API key + model: Model to use (default: claude-3-haiku for speed/cost) + """ + try: + import anthropic + except ImportError: + raise ImportError("anthropic package required. Install with: pip install anthropic") + + self.client = anthropic.AsyncAnthropic(api_key=api_key) + self.model = model + logger.info(f"Initialized Anthropic provider with model: {model}") + + async def _call_api(self, system: str, user_content: Any, max_tokens: int = 500) -> str: + """Make an API call to Claude.""" + try: + message = await self.client.messages.create( + model=self.model, + max_tokens=max_tokens, + system=system, + messages=[{"role": "user", "content": user_content}], + ) + return message.content[0].text + except Exception as e: + logger.error(f"Anthropic API error: {e}") + raise + + def _parse_json_response(self, response: str) -> dict: + """Parse JSON from response, handling markdown code blocks.""" + import json + + # Remove markdown code blocks if present + text = response.strip() + if text.startswith("```"): + lines = text.split("\n") + # Remove first and last lines (```json and ```) + text = "\n".join(lines[1:-1] if lines[-1] == "```" else lines[1:]) + + return json.loads(text) + + async def moderate_text( + self, + content: str, + context: str | None = None, + sensitivity: int = 50, + ) -> ModerationResult: + """Analyze text content for policy violations.""" + # Adjust prompt based on sensitivity + sensitivity_note = "" + if sensitivity < 30: + sensitivity_note = "\n\nBe lenient - only flag clearly problematic content." + elif sensitivity > 70: + sensitivity_note = "\n\nBe strict - flag anything potentially problematic." + + system = MODERATION_SYSTEM_PROMPT + sensitivity_note + + user_message = f"Message to analyze:\n{content}" + if context: + user_message = f"Context: {context}\n\n{user_message}" + + try: + response = await self._call_api(system, user_message) + data = self._parse_json_response(response) + + categories = [ + ContentCategory(cat) + for cat in data.get("categories", []) + if cat in ContentCategory.__members__.values() + ] + + return ModerationResult( + is_flagged=data.get("is_flagged", False), + confidence=float(data.get("confidence", 0.0)), + categories=categories, + explanation=data.get("explanation", ""), + suggested_action=data.get("suggested_action", "none"), + ) + + except Exception as e: + logger.error(f"Error moderating text: {e}") + return ModerationResult( + is_flagged=False, + explanation=f"Error analyzing content: {str(e)}", + ) + + async def analyze_image( + self, + image_url: str, + sensitivity: int = 50, + ) -> ImageAnalysisResult: + """Analyze an image for NSFW or inappropriate content.""" + import base64 + + import aiohttp + + sensitivity_note = "" + if sensitivity < 30: + sensitivity_note = "\n\nBe lenient - only flag explicit content." + elif sensitivity > 70: + sensitivity_note = "\n\nBe strict - flag suggestive content as well." + + system = IMAGE_ANALYSIS_PROMPT + sensitivity_note + + try: + # Download image and convert to base64 + async with aiohttp.ClientSession() as session: + async with session.get(image_url) as resp: + if resp.status != 200: + return ImageAnalysisResult( + description=f"Failed to download image: HTTP {resp.status}" + ) + + content_type = resp.content_type or "image/jpeg" + image_data = await resp.read() + + # Check file size (max 20MB for Claude) + if len(image_data) > 20 * 1024 * 1024: + return ImageAnalysisResult(description="Image too large to analyze") + + base64_image = base64.standard_b64encode(image_data).decode("utf-8") + + # Create multimodal message + user_content = [ + { + "type": "image", + "source": { + "type": "base64", + "media_type": content_type, + "data": base64_image, + }, + }, + {"type": "text", "text": "Analyze this image for content moderation."}, + ] + + response = await self._call_api(system, user_content) + data = self._parse_json_response(response) + + return ImageAnalysisResult( + is_nsfw=data.get("is_nsfw", False), + is_violent=data.get("is_violent", False), + is_disturbing=data.get("is_disturbing", False), + confidence=float(data.get("confidence", 0.0)), + description=data.get("description", ""), + categories=data.get("categories", []), + ) + + except Exception as e: + logger.error(f"Error analyzing image: {e}") + return ImageAnalysisResult(description=f"Error analyzing image: {str(e)}") + + async def analyze_phishing( + self, + url: str, + message_content: str | None = None, + ) -> PhishingAnalysisResult: + """Analyze a URL for phishing/scam indicators.""" + user_message = f"URL to analyze: {url}" + if message_content: + user_message += f"\n\nFull message context:\n{message_content}" + + try: + response = await self._call_api(PHISHING_ANALYSIS_PROMPT, user_message) + data = self._parse_json_response(response) + + return PhishingAnalysisResult( + is_phishing=data.get("is_phishing", False), + confidence=float(data.get("confidence", 0.0)), + risk_factors=data.get("risk_factors", []), + explanation=data.get("explanation", ""), + ) + + except Exception as e: + logger.error(f"Error analyzing phishing: {e}") + return PhishingAnalysisResult(explanation=f"Error analyzing URL: {str(e)}") + + async def close(self) -> None: + """Clean up resources.""" + await self.client.close() diff --git a/src/guardden/services/ai/base.py b/src/guardden/services/ai/base.py new file mode 100644 index 0000000..cb0a231 --- /dev/null +++ b/src/guardden/services/ai/base.py @@ -0,0 +1,149 @@ +"""Base classes for AI providers.""" + +from abc import ABC, abstractmethod +from dataclasses import dataclass, field +from enum import Enum +from typing import Literal + + +class ContentCategory(str, Enum): + """Categories of problematic content.""" + + SAFE = "safe" + HARASSMENT = "harassment" + HATE_SPEECH = "hate_speech" + SEXUAL = "sexual" + VIOLENCE = "violence" + SELF_HARM = "self_harm" + SPAM = "spam" + SCAM = "scam" + MISINFORMATION = "misinformation" + + +@dataclass +class ModerationResult: + """Result of AI content moderation.""" + + is_flagged: bool = False + confidence: float = 0.0 # 0.0 to 1.0 + categories: list[ContentCategory] = field(default_factory=list) + explanation: str = "" + suggested_action: Literal["none", "warn", "delete", "timeout", "ban"] = "none" + + @property + def severity(self) -> int: + """Get severity score 0-100 based on confidence and categories.""" + if not self.is_flagged: + return 0 + + # Base severity from confidence + severity = int(self.confidence * 50) + + # Add severity based on category + high_severity = { + ContentCategory.HATE_SPEECH, + ContentCategory.SELF_HARM, + ContentCategory.SCAM, + } + medium_severity = { + ContentCategory.HARASSMENT, + ContentCategory.VIOLENCE, + ContentCategory.SEXUAL, + } + + for cat in self.categories: + if cat in high_severity: + severity += 30 + elif cat in medium_severity: + severity += 20 + else: + severity += 10 + + return min(severity, 100) + + +@dataclass +class ImageAnalysisResult: + """Result of AI image analysis.""" + + is_nsfw: bool = False + is_violent: bool = False + is_disturbing: bool = False + confidence: float = 0.0 + description: str = "" + categories: list[str] = field(default_factory=list) + + +@dataclass +class PhishingAnalysisResult: + """Result of AI phishing/scam analysis.""" + + is_phishing: bool = False + confidence: float = 0.0 + risk_factors: list[str] = field(default_factory=list) + explanation: str = "" + + +class AIProvider(ABC): + """Abstract base class for AI providers.""" + + @abstractmethod + async def moderate_text( + self, + content: str, + context: str | None = None, + sensitivity: int = 50, + ) -> ModerationResult: + """ + Analyze text content for policy violations. + + Args: + content: The text to analyze + context: Optional context about the conversation/server + sensitivity: 0-100, higher means more strict + + Returns: + ModerationResult with analysis + """ + pass + + @abstractmethod + async def analyze_image( + self, + image_url: str, + sensitivity: int = 50, + ) -> ImageAnalysisResult: + """ + Analyze an image for NSFW or inappropriate content. + + Args: + image_url: URL of the image to analyze + sensitivity: 0-100, higher means more strict + + Returns: + ImageAnalysisResult with analysis + """ + pass + + @abstractmethod + async def analyze_phishing( + self, + url: str, + message_content: str | None = None, + ) -> PhishingAnalysisResult: + """ + Analyze a URL for phishing/scam indicators. + + Args: + url: The URL to analyze + message_content: Optional full message for context + + Returns: + PhishingAnalysisResult with analysis + """ + pass + + @abstractmethod + async def close(self) -> None: + """Clean up resources.""" + pass diff --git a/src/guardden/services/ai/factory.py b/src/guardden/services/ai/factory.py new file mode 100644 index 0000000..bbb7e22 --- /dev/null +++ b/src/guardden/services/ai/factory.py @@ -0,0 +1,67 @@ +"""Factory for creating AI providers.""" + +import logging +from typing import Literal + +from guardden.services.ai.base import AIProvider + +logger = logging.getLogger(__name__) + + +class NullProvider(AIProvider): + """Null provider that does nothing (for when AI is disabled).""" + + async def moderate_text(self, content, context=None, sensitivity=50): + from guardden.services.ai.base import ModerationResult + + return ModerationResult() + + async def analyze_image(self, image_url, sensitivity=50): + from guardden.services.ai.base import ImageAnalysisResult + + return ImageAnalysisResult() + + async def analyze_phishing(self, url, message_content=None): + from guardden.services.ai.base import PhishingAnalysisResult + + return PhishingAnalysisResult() + + async def close(self): + pass + + +def create_ai_provider( + provider: Literal["anthropic", "openai", "none"], + api_key: str | None = None, +) -> AIProvider: + """ + Create an AI provider instance. + + Args: + provider: The provider type to create + api_key: API key for the provider + + Returns: + AIProvider instance + + Raises: + ValueError: If provider is unknown or API key is missing + """ + if provider == "none": + logger.info("AI moderation disabled") + return NullProvider() + + if not api_key: + raise ValueError(f"API key required for {provider} provider") + + if provider == "anthropic": + from guardden.services.ai.anthropic_provider import AnthropicProvider + + return AnthropicProvider(api_key) + + if provider == "openai": + from guardden.services.ai.openai_provider import OpenAIProvider + + return OpenAIProvider(api_key) + + raise ValueError(f"Unknown AI provider: {provider}") diff --git a/src/guardden/services/ai/openai_provider.py b/src/guardden/services/ai/openai_provider.py new file mode 100644 index 0000000..4b1c6e2 --- /dev/null +++ b/src/guardden/services/ai/openai_provider.py @@ -0,0 +1,213 @@ +"""OpenAI AI provider implementation.""" + +import logging +from typing import Any + +from guardden.services.ai.base import ( + AIProvider, + ContentCategory, + ImageAnalysisResult, + ModerationResult, + PhishingAnalysisResult, +) + +logger = logging.getLogger(__name__) + + +class OpenAIProvider(AIProvider): + """AI provider using OpenAI's API.""" + + def __init__(self, api_key: str, model: str = "gpt-4o-mini") -> None: + """ + Initialize OpenAI provider. + + Args: + api_key: OpenAI API key + model: Model to use (default: gpt-4o-mini for speed/cost) + """ + try: + import openai + except ImportError: + raise ImportError("openai package required. Install with: pip install openai") + + self.client = openai.AsyncOpenAI(api_key=api_key) + self.model = model + logger.info(f"Initialized OpenAI provider with model: {model}") + + async def _call_api( + self, + system: str, + user_content: Any, + max_tokens: int = 500, + ) -> str: + """Make an API call to OpenAI.""" + try: + response = await self.client.chat.completions.create( + model=self.model, + max_tokens=max_tokens, + messages=[ + {"role": "system", "content": system}, + {"role": "user", "content": user_content}, + ], + response_format={"type": "json_object"}, + ) + return response.choices[0].message.content or "" + except Exception as e: + logger.error(f"OpenAI API error: {e}") + raise + + def _parse_json_response(self, response: str) -> dict: + """Parse JSON from response.""" + import json + + return json.loads(response) + + async def moderate_text( + self, + content: str, + context: str | None = None, + sensitivity: int = 50, + ) -> ModerationResult: + """Analyze text content for policy violations.""" + # First, use OpenAI's built-in moderation API for quick check + try: + mod_response = await self.client.moderations.create(input=content) + results = mod_response.results[0] + + # Map OpenAI categories to our categories + category_mapping = { + "harassment": ContentCategory.HARASSMENT, + "harassment/threatening": ContentCategory.HARASSMENT, + "hate": ContentCategory.HATE_SPEECH, + "hate/threatening": ContentCategory.HATE_SPEECH, + "self-harm": ContentCategory.SELF_HARM, + "self-harm/intent": ContentCategory.SELF_HARM, + "self-harm/instructions": ContentCategory.SELF_HARM, + "sexual": ContentCategory.SEXUAL, + "sexual/minors": ContentCategory.SEXUAL, + "violence": ContentCategory.VIOLENCE, + "violence/graphic": ContentCategory.VIOLENCE, + } + + flagged_categories = [] + max_score = 0.0 + + for category, score in results.category_scores.model_dump().items(): + if score > 0.5: # Threshold + if category in category_mapping: + flagged_categories.append(category_mapping[category]) + max_score = max(max_score, score) + + # Adjust threshold based on sensitivity + threshold = 0.3 + (0.4 * (100 - sensitivity) / 100) # 0.3 to 0.7 + + if results.flagged or max_score > threshold: + return ModerationResult( + is_flagged=True, + confidence=max_score, + categories=list(set(flagged_categories)), + explanation="Content flagged by moderation API", + suggested_action="delete" if max_score > 0.8 else "warn", + ) + + return ModerationResult(is_flagged=False, confidence=1.0 - max_score) + + except Exception as e: + logger.error(f"Error moderating text: {e}") + return ModerationResult( + is_flagged=False, + explanation=f"Error analyzing content: {str(e)}", + ) + + async def analyze_image( + self, + image_url: str, + sensitivity: int = 50, + ) -> ImageAnalysisResult: + """Analyze an image for NSFW or inappropriate content.""" + system = """Analyze this image for content moderation. Respond in JSON format: +{ + "is_nsfw": true/false, + "is_violent": true/false, + "is_disturbing": true/false, + "confidence": 0.0-1.0, + "description": "Brief description", + "categories": ["category1"] +}""" + + sensitivity_note = "" + if sensitivity < 30: + sensitivity_note = " Be lenient - only flag explicit content." + elif sensitivity > 70: + sensitivity_note = " Be strict - flag suggestive content." + + try: + response = await self.client.chat.completions.create( + model="gpt-4o-mini", # Use vision-capable model + max_tokens=500, + messages=[ + {"role": "system", "content": system + sensitivity_note}, + { + "role": "user", + "content": [ + {"type": "text", "text": "Analyze this image for moderation."}, + {"type": "image_url", "image_url": {"url": image_url}}, + ], + }, + ], + response_format={"type": "json_object"}, + ) + + data = self._parse_json_response(response.choices[0].message.content or "{}") + + return ImageAnalysisResult( + is_nsfw=data.get("is_nsfw", False), + is_violent=data.get("is_violent", False), + is_disturbing=data.get("is_disturbing", False), + confidence=float(data.get("confidence", 0.0)), + description=data.get("description", ""), + categories=data.get("categories", []), + ) + + except Exception as e: + logger.error(f"Error analyzing image: {e}") + return ImageAnalysisResult(description=f"Error analyzing image: {str(e)}") + + async def analyze_phishing( + self, + url: str, + message_content: str | None = None, + ) -> PhishingAnalysisResult: + """Analyze a URL for phishing/scam indicators.""" + system = """Analyze the URL for phishing/scam indicators. Respond in JSON: +{ + "is_phishing": true/false, + "confidence": 0.0-1.0, + "risk_factors": ["factor1"], + "explanation": "Brief explanation" +} + +Check for: domain impersonation, urgency tactics, credential requests, too-good-to-be-true offers.""" + + user_message = f"URL: {url}" + if message_content: + user_message += f"\n\nMessage context: {message_content}" + + try: + response = await self._call_api(system, user_message) + data = self._parse_json_response(response) + + return PhishingAnalysisResult( + is_phishing=data.get("is_phishing", False), + confidence=float(data.get("confidence", 0.0)), + risk_factors=data.get("risk_factors", []), + explanation=data.get("explanation", ""), + ) + + except Exception as e: + logger.error(f"Error analyzing phishing: {e}") + return PhishingAnalysisResult(explanation=f"Error analyzing URL: {str(e)}") + + async def close(self) -> None: + """Clean up resources.""" + await self.client.close() diff --git a/src/guardden/services/automod.py b/src/guardden/services/automod.py new file mode 100644 index 0000000..bcdecec --- /dev/null +++ b/src/guardden/services/automod.py @@ -0,0 +1,301 @@ +"""Automod service for content filtering and spam detection.""" + +import logging +import re +from collections import defaultdict +from dataclasses import dataclass, field +from datetime import datetime, timedelta, timezone +from typing import NamedTuple + +import discord + +from guardden.models import BannedWord + +logger = logging.getLogger(__name__) + + +# Known scam/phishing patterns +SCAM_PATTERNS = [ + # Discord scam patterns + r"discord(?:[-.]?(?:gift|nitro|free|claim|steam))[\w.-]*\.(?!com|gg)[a-z]{2,}", + r"(?:free|claim|get)[-.\s]?(?:discord[-.\s]?)?nitro", + r"(?:steam|discord)[-.\s]?community[-.\s]?(?:giveaway|gift)", + # Generic phishing + r"(?:verify|confirm)[-.\s]?(?:your)?[-.\s]?account", + r"(?:suspended|locked|limited)[-.\s]?account", + r"click[-.\s]?(?:here|this)[-.\s]?(?:to[-.\s]?)?(?:verify|claim|get)", + # Crypto scams + r"(?:free|claim|airdrop)[-.\s]?(?:crypto|bitcoin|eth|nft)", + r"(?:double|2x)[-.\s]?your[-.\s]?(?:crypto|bitcoin|eth)", +] + +# Suspicious TLDs often used in phishing +SUSPICIOUS_TLDS = { + ".xyz", + ".top", + ".club", + ".work", + ".click", + ".link", + ".info", + ".ru", + ".cn", + ".tk", + ".ml", + ".ga", + ".cf", + ".gq", +} + +# URL pattern for extraction +URL_PATTERN = re.compile( + r"https?://(?:[-\w.]|(?:%[\da-fA-F]{2}))+[^\s]*|" + r"(?:www\.)?[-\w]+\.(?:com|org|net|io|gg|co|me|tv|xyz|top|club|work|click|link|info|ru|cn)[^\s]*", + re.IGNORECASE, +) + + +class SpamRecord(NamedTuple): + """Record of a message for spam tracking.""" + + content_hash: str + timestamp: datetime + + +@dataclass +class UserSpamTracker: + """Tracks spam behavior for a single user.""" + + messages: list[SpamRecord] = field(default_factory=list) + mention_count: int = 0 + last_mention_time: datetime | None = None + duplicate_count: int = 0 + last_action_time: datetime | None = None + + def cleanup(self, max_age: timedelta = timedelta(minutes=1)) -> None: + """Remove old messages from tracking.""" + cutoff = datetime.now(timezone.utc) - max_age + self.messages = [m for m in self.messages if m.timestamp > cutoff] + + +@dataclass +class AutomodResult: + """Result of automod check.""" + + should_delete: bool = False + should_warn: bool = False + should_strike: bool = False + should_timeout: bool = False + timeout_duration: int = 0 # seconds + reason: str = "" + matched_filter: str = "" + + +class AutomodService: + """Service for automatic content moderation.""" + + def __init__(self) -> None: + # Compile scam patterns + self._scam_patterns = [re.compile(p, re.IGNORECASE) for p in SCAM_PATTERNS] + + # Per-guild, per-user spam tracking + # Structure: {guild_id: {user_id: UserSpamTracker}} + self._spam_trackers: dict[int, dict[int, UserSpamTracker]] = defaultdict( + lambda: defaultdict(UserSpamTracker) + ) + + # Spam thresholds + self.message_rate_limit = 5 # messages per window + self.message_rate_window = 5 # seconds + self.duplicate_threshold = 3 # same message count + self.mention_limit = 5 # mentions per message + self.mention_rate_limit = 10 # mentions per window + self.mention_rate_window = 60 # seconds + + def _get_content_hash(self, content: str) -> str: + """Get a normalized hash of message content for duplicate detection.""" + # Normalize: lowercase, remove extra spaces, remove special chars + normalized = re.sub(r"[^\w\s]", "", content.lower()) + normalized = re.sub(r"\s+", " ", normalized).strip() + return normalized + + def check_banned_words( + self, content: str, banned_words: list[BannedWord] + ) -> AutomodResult | None: + """Check message against banned words list.""" + content_lower = content.lower() + + for banned in banned_words: + matched = False + + if banned.is_regex: + try: + if re.search(banned.pattern, content, re.IGNORECASE): + matched = True + except re.error: + logger.warning(f"Invalid regex pattern: {banned.pattern}") + continue + else: + if banned.pattern.lower() in content_lower: + matched = True + + if matched: + result = AutomodResult( + should_delete=True, + reason=banned.reason or f"Matched banned word filter", + matched_filter=f"banned_word:{banned.id}", + ) + + if banned.action == "warn": + result.should_warn = True + elif banned.action == "strike": + result.should_strike = True + + return result + + return None + + def check_scam_links(self, content: str) -> AutomodResult | None: + """Check message for scam/phishing patterns.""" + # Check for known scam patterns + for pattern in self._scam_patterns: + if pattern.search(content): + return AutomodResult( + should_delete=True, + should_warn=True, + reason="Message matched known scam/phishing pattern", + matched_filter="scam_pattern", + ) + + # Check URLs for suspicious TLDs + urls = URL_PATTERN.findall(content) + for url in urls: + url_lower = url.lower() + for tld in SUSPICIOUS_TLDS: + if tld in url_lower: + # Additional check: is it trying to impersonate a known domain? + impersonation_keywords = [ + "discord", + "steam", + "nitro", + "gift", + "free", + "login", + "verify", + ] + if any(kw in url_lower for kw in impersonation_keywords): + return AutomodResult( + should_delete=True, + should_warn=True, + reason=f"Suspicious link detected: {url[:50]}", + matched_filter="suspicious_link", + ) + + return None + + def check_spam( + self, message: discord.Message, anti_spam_enabled: bool = True + ) -> AutomodResult | None: + """Check message for spam behavior.""" + if not anti_spam_enabled: + return None + + guild_id = message.guild.id + user_id = message.author.id + tracker = self._spam_trackers[guild_id][user_id] + now = datetime.now(timezone.utc) + + # Cleanup old records + tracker.cleanup() + + # Check message rate + content_hash = self._get_content_hash(message.content) + tracker.messages.append(SpamRecord(content_hash, now)) + + # Rate limit check + recent_window = now - timedelta(seconds=self.message_rate_window) + recent_messages = [m for m in tracker.messages if m.timestamp > recent_window] + + if len(recent_messages) > self.message_rate_limit: + return AutomodResult( + should_delete=True, + should_timeout=True, + timeout_duration=60, # 1 minute timeout + reason=f"Sending messages too fast ({len(recent_messages)} in {self.message_rate_window}s)", + matched_filter="rate_limit", + ) + + # Duplicate message check + duplicate_count = sum(1 for m in tracker.messages if m.content_hash == content_hash) + if duplicate_count >= self.duplicate_threshold: + return AutomodResult( + should_delete=True, + should_warn=True, + reason=f"Duplicate message detected ({duplicate_count} times)", + matched_filter="duplicate", + ) + + # Mass mention check + mention_count = len(message.mentions) + len(message.role_mentions) + if message.mention_everyone: + mention_count += 100 # Treat @everyone as many mentions + + if mention_count > self.mention_limit: + return AutomodResult( + should_delete=True, + should_timeout=True, + timeout_duration=300, # 5 minute timeout + reason=f"Mass mentions detected ({mention_count} mentions)", + matched_filter="mass_mention", + ) + + return None + + def check_invite_links(self, content: str, allow_invites: bool = True) -> AutomodResult | None: + """Check for Discord invite links.""" + if allow_invites: + return None + + invite_pattern = re.compile( + r"(?:https?://)?(?:www\.)?(?:discord\.(?:gg|io|me|li)|discordapp\.com/invite)/[\w-]+", + re.IGNORECASE, + ) + + if invite_pattern.search(content): + return AutomodResult( + should_delete=True, + reason="Discord invite links are not allowed", + matched_filter="invite_link", + ) + + return None + + def check_all_caps( + self, content: str, threshold: float = 0.7, min_length: int = 10 + ) -> AutomodResult | None: + """Check for excessive caps usage.""" + # Only check messages with enough letters + letters = [c for c in content if c.isalpha()] + if len(letters) < min_length: + return None + + caps_count = sum(1 for c in letters if c.isupper()) + caps_ratio = caps_count / len(letters) + + if caps_ratio > threshold: + return AutomodResult( + should_delete=True, + reason="Excessive caps usage", + matched_filter="caps", + ) + + return None + + def reset_user_tracker(self, guild_id: int, user_id: int) -> None: + """Reset spam tracking for a user.""" + if guild_id in self._spam_trackers: + self._spam_trackers[guild_id].pop(user_id, None) + + def cleanup_guild(self, guild_id: int) -> None: + """Remove all tracking data for a guild.""" + self._spam_trackers.pop(guild_id, None) diff --git a/src/guardden/services/database.py b/src/guardden/services/database.py new file mode 100644 index 0000000..c3cd3c8 --- /dev/null +++ b/src/guardden/services/database.py @@ -0,0 +1,99 @@ +"""Database connection and session management.""" + +import logging +from contextlib import asynccontextmanager +from typing import AsyncGenerator + +import asyncpg +from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker, create_async_engine + +from guardden.config import Settings +from guardden.models import Base + +logger = logging.getLogger(__name__) + + +class Database: + """Manages database connections and sessions.""" + + def __init__(self, settings: Settings) -> None: + self.settings = settings + self._engine = None + self._session_factory = None + self._pool: asyncpg.Pool | None = None + + async def connect(self) -> None: + """Initialize database connection pool.""" + db_url = self.settings.database_url.get_secret_value() + + # Create SQLAlchemy async engine + # Convert postgresql:// to postgresql+asyncpg:// + if db_url.startswith("postgresql://"): + sqlalchemy_url = db_url.replace("postgresql://", "postgresql+asyncpg://", 1) + else: + sqlalchemy_url = db_url + + self._engine = create_async_engine( + sqlalchemy_url, + pool_size=self.settings.database_pool_min, + max_overflow=self.settings.database_pool_max - self.settings.database_pool_min, + echo=self.settings.log_level == "DEBUG", + ) + + self._session_factory = async_sessionmaker( + self._engine, + class_=AsyncSession, + expire_on_commit=False, + ) + + # Also create a raw asyncpg pool for performance-critical operations + self._pool = await asyncpg.create_pool( + db_url, + min_size=self.settings.database_pool_min, + max_size=self.settings.database_pool_max, + ) + + logger.info("Database connection established") + + async def disconnect(self) -> None: + """Close all database connections.""" + if self._pool: + await self._pool.close() + self._pool = None + + if self._engine: + await self._engine.dispose() + self._engine = None + + logger.info("Database connections closed") + + async def create_tables(self) -> None: + """Create all database tables.""" + if not self._engine: + raise RuntimeError("Database not connected") + + async with self._engine.begin() as conn: + await conn.run_sync(Base.metadata.create_all) + + logger.info("Database tables created") + + @asynccontextmanager + async def session(self) -> AsyncGenerator[AsyncSession, None]: + """Get a database session context manager.""" + if not self._session_factory: + raise RuntimeError("Database not connected") + + async with self._session_factory() as session: + try: + yield session + await session.commit() + except Exception: + await session.rollback() + raise + + @property + def pool(self) -> asyncpg.Pool: + """Get the raw asyncpg connection pool.""" + if not self._pool: + raise RuntimeError("Database not connected") + return self._pool diff --git a/src/guardden/services/guild_config.py b/src/guardden/services/guild_config.py new file mode 100644 index 0000000..de36c67 --- /dev/null +++ b/src/guardden/services/guild_config.py @@ -0,0 +1,145 @@ +"""Guild configuration service.""" + +import logging +from functools import lru_cache + +import discord +from sqlalchemy import select +from sqlalchemy.orm import selectinload + +from guardden.models import BannedWord, Guild, GuildSettings +from guardden.services.database import Database + +logger = logging.getLogger(__name__) + + +class GuildConfigService: + """Manages guild configurations with caching.""" + + def __init__(self, database: Database) -> None: + self.database = database + self._cache: dict[int, GuildSettings] = {} + + async def get_config(self, guild_id: int) -> GuildSettings | None: + """Get guild configuration, using cache if available.""" + if guild_id in self._cache: + return self._cache[guild_id] + + async with self.database.session() as session: + result = await session.execute( + select(GuildSettings).where(GuildSettings.guild_id == guild_id) + ) + settings = result.scalar_one_or_none() + + if settings: + self._cache[guild_id] = settings + + return settings + + async def get_guild(self, guild_id: int) -> Guild | None: + """Get full guild data including settings and banned words.""" + async with self.database.session() as session: + result = await session.execute( + select(Guild) + .options(selectinload(Guild.settings), selectinload(Guild.banned_words)) + .where(Guild.id == guild_id) + ) + return result.scalar_one_or_none() + + async def create_guild(self, guild: discord.Guild) -> Guild: + """Create a new guild entry with default settings.""" + async with self.database.session() as session: + # Check if guild already exists + existing = await session.get(Guild, guild.id) + if existing: + return existing + + # Create new guild + db_guild = Guild( + id=guild.id, + name=guild.name, + owner_id=guild.owner_id, + ) + session.add(db_guild) + await session.flush() + + # Create default settings + settings = GuildSettings(guild_id=guild.id) + session.add(settings) + + await session.commit() + + logger.info(f"Created guild config for {guild.name} (ID: {guild.id})") + return db_guild + + async def update_settings(self, guild_id: int, **kwargs) -> GuildSettings | None: + """Update guild settings.""" + async with self.database.session() as session: + result = await session.execute( + select(GuildSettings).where(GuildSettings.guild_id == guild_id) + ) + settings = result.scalar_one_or_none() + + if not settings: + return None + + for key, value in kwargs.items(): + if hasattr(settings, key): + setattr(settings, key, value) + + await session.commit() + + # Invalidate cache + self._cache.pop(guild_id, None) + + return settings + + def invalidate_cache(self, guild_id: int) -> None: + """Remove a guild from the cache.""" + self._cache.pop(guild_id, None) + + async def get_banned_words(self, guild_id: int) -> list[BannedWord]: + """Get all banned words for a guild.""" + async with self.database.session() as session: + result = await session.execute( + select(BannedWord).where(BannedWord.guild_id == guild_id) + ) + return list(result.scalars().all()) + + async def add_banned_word( + self, + guild_id: int, + pattern: str, + added_by: int, + is_regex: bool = False, + action: str = "delete", + reason: str | None = None, + ) -> BannedWord: + """Add a banned word to a guild.""" + async with self.database.session() as session: + banned_word = BannedWord( + guild_id=guild_id, + pattern=pattern, + is_regex=is_regex, + action=action, + reason=reason, + added_by=added_by, + ) + session.add(banned_word) + await session.commit() + return banned_word + + async def remove_banned_word(self, guild_id: int, word_id: int) -> bool: + """Remove a banned word from a guild.""" + async with self.database.session() as session: + result = await session.execute( + select(BannedWord).where(BannedWord.id == word_id, BannedWord.guild_id == guild_id) + ) + word = result.scalar_one_or_none() + + if word: + session.delete(word) + await session.commit() + return True + + return False diff --git a/src/guardden/services/ratelimit.py b/src/guardden/services/ratelimit.py new file mode 100644 index 0000000..e4bdfc0 --- /dev/null +++ b/src/guardden/services/ratelimit.py @@ -0,0 +1,300 @@ +"""Rate limiting service for command and action throttling.""" + +import logging +from collections import defaultdict +from dataclasses import dataclass, field +from datetime import datetime, timedelta, timezone +from enum import Enum +from typing import Callable + +logger = logging.getLogger(__name__) + + +class RateLimitScope(str, Enum): + """Scope of rate limiting.""" + + USER = "user" # Per user globally + MEMBER = "member" # Per user per guild + CHANNEL = "channel" # Per channel + GUILD = "guild" # Per guild + + +@dataclass +class RateLimitBucket: + """Tracks rate limit state for a single bucket.""" + + max_requests: int + window_seconds: float + requests: list[datetime] = field(default_factory=list) + + def cleanup(self) -> None: + """Remove expired requests from tracking.""" + cutoff = datetime.now(timezone.utc) - timedelta(seconds=self.window_seconds) + self.requests = [r for r in self.requests if r > cutoff] + + def is_limited(self) -> bool: + """Check if this bucket is rate limited.""" + self.cleanup() + return len(self.requests) >= self.max_requests + + def record(self) -> None: + """Record a request.""" + self.requests.append(datetime.now(timezone.utc)) + + def remaining(self) -> int: + """Get remaining requests in current window.""" + self.cleanup() + return max(0, self.max_requests - len(self.requests)) + + def reset_after(self) -> float: + """Get seconds until rate limit resets.""" + if not self.requests: + return 0.0 + self.cleanup() + if not self.requests: + return 0.0 + oldest = min(self.requests) + reset_time = oldest + timedelta(seconds=self.window_seconds) + remaining = (reset_time - datetime.now(timezone.utc)).total_seconds() + return max(0.0, remaining) + + +@dataclass +class RateLimitConfig: + """Configuration for a rate limit.""" + + max_requests: int + window_seconds: float + scope: RateLimitScope = RateLimitScope.MEMBER + + def create_bucket(self) -> RateLimitBucket: + return RateLimitBucket( + max_requests=self.max_requests, + window_seconds=self.window_seconds, + ) + + +@dataclass +class RateLimitResult: + """Result of a rate limit check.""" + + is_limited: bool + remaining: int + reset_after: float + bucket_key: str + + +class RateLimiter: + """General-purpose rate limiter.""" + + # Default rate limits for various actions + DEFAULT_LIMITS = { + "command": RateLimitConfig(5, 10, RateLimitScope.MEMBER), # 5 commands per 10s + "moderation": RateLimitConfig(10, 60, RateLimitScope.MEMBER), # 10 mod actions per minute + "verification": RateLimitConfig(3, 300, RateLimitScope.MEMBER), # 3 verifications per 5 min + "message": RateLimitConfig(10, 10, RateLimitScope.MEMBER), # 10 messages per 10s + "api_call": RateLimitConfig( + 30, 60, RateLimitScope.GUILD + ), # 30 API calls per minute per guild + } + + def __init__(self) -> None: + # Buckets: {action: {bucket_key: RateLimitBucket}} + self._buckets: dict[str, dict[str, RateLimitBucket]] = defaultdict(dict) + self._configs: dict[str, RateLimitConfig] = dict(self.DEFAULT_LIMITS) + + def configure(self, action: str, config: RateLimitConfig) -> None: + """Configure rate limit for an action.""" + self._configs[action] = config + # Clear existing buckets for this action + self._buckets[action].clear() + + def _get_bucket_key( + self, + scope: RateLimitScope, + user_id: int | None = None, + guild_id: int | None = None, + channel_id: int | None = None, + ) -> str: + """Generate a bucket key based on scope.""" + if scope == RateLimitScope.USER: + return f"user:{user_id}" + elif scope == RateLimitScope.MEMBER: + return f"member:{guild_id}:{user_id}" + elif scope == RateLimitScope.CHANNEL: + return f"channel:{channel_id}" + elif scope == RateLimitScope.GUILD: + return f"guild:{guild_id}" + return f"unknown:{user_id}:{guild_id}" + + def check( + self, + action: str, + user_id: int | None = None, + guild_id: int | None = None, + channel_id: int | None = None, + ) -> RateLimitResult: + """ + Check if an action is rate limited. + + Does not record the request - use `acquire()` for that. + """ + config = self._configs.get(action) + if not config: + return RateLimitResult( + is_limited=False, + remaining=999, + reset_after=0, + bucket_key="", + ) + + bucket_key = self._get_bucket_key(config.scope, user_id, guild_id, channel_id) + bucket = self._buckets[action].get(bucket_key) + + if not bucket: + return RateLimitResult( + is_limited=False, + remaining=config.max_requests, + reset_after=0, + bucket_key=bucket_key, + ) + + return RateLimitResult( + is_limited=bucket.is_limited(), + remaining=bucket.remaining(), + reset_after=bucket.reset_after(), + bucket_key=bucket_key, + ) + + def acquire( + self, + action: str, + user_id: int | None = None, + guild_id: int | None = None, + channel_id: int | None = None, + ) -> RateLimitResult: + """ + Attempt to acquire a rate limit slot. + + Records the request if not limited. + """ + config = self._configs.get(action) + if not config: + return RateLimitResult( + is_limited=False, + remaining=999, + reset_after=0, + bucket_key="", + ) + + bucket_key = self._get_bucket_key(config.scope, user_id, guild_id, channel_id) + + if bucket_key not in self._buckets[action]: + self._buckets[action][bucket_key] = config.create_bucket() + + bucket = self._buckets[action][bucket_key] + + if bucket.is_limited(): + return RateLimitResult( + is_limited=True, + remaining=0, + reset_after=bucket.reset_after(), + bucket_key=bucket_key, + ) + + bucket.record() + + return RateLimitResult( + is_limited=False, + remaining=bucket.remaining(), + reset_after=bucket.reset_after(), + bucket_key=bucket_key, + ) + + def reset( + self, + action: str, + user_id: int | None = None, + guild_id: int | None = None, + channel_id: int | None = None, + ) -> bool: + """Reset rate limit for a specific bucket.""" + config = self._configs.get(action) + if not config: + return False + + bucket_key = self._get_bucket_key(config.scope, user_id, guild_id, channel_id) + return self._buckets[action].pop(bucket_key, None) is not None + + def cleanup(self) -> int: + """Clean up empty and expired buckets. Returns count removed.""" + removed = 0 + for action in list(self._buckets.keys()): + for key in list(self._buckets[action].keys()): + bucket = self._buckets[action][key] + bucket.cleanup() + if not bucket.requests: + del self._buckets[action][key] + removed += 1 + return removed + + +# Global rate limiter instance +_rate_limiter: RateLimiter | None = None + + +def get_rate_limiter() -> RateLimiter: + """Get or create the global rate limiter instance.""" + global _rate_limiter + if _rate_limiter is None: + _rate_limiter = RateLimiter() + return _rate_limiter + + +def ratelimit( + action: str = "command", + max_requests: int | None = None, + window_seconds: float | None = None, +) -> Callable: + """ + Decorator for rate limiting commands. + + Usage: + @ratelimit("moderation", max_requests=5, window_seconds=60) + async def my_command(self, ctx): + ... + """ + + def decorator(func: Callable) -> Callable: + async def wrapper(self, ctx, *args, **kwargs): + limiter = get_rate_limiter() + + # Configure if custom limits provided + if max_requests is not None and window_seconds is not None: + limiter.configure( + action, + RateLimitConfig(max_requests, window_seconds, RateLimitScope.MEMBER), + ) + + result = limiter.acquire( + action, + user_id=ctx.author.id, + guild_id=ctx.guild.id if ctx.guild else None, + channel_id=ctx.channel.id, + ) + + if result.is_limited: + await ctx.send( + f"You're being rate limited. Try again in {result.reset_after:.1f} seconds.", + delete_after=5, + ) + return + + return await func(self, ctx, *args, **kwargs) + + # Preserve function metadata + wrapper.__name__ = func.__name__ + wrapper.__doc__ = func.__doc__ + return wrapper + + return decorator diff --git a/src/guardden/services/verification.py b/src/guardden/services/verification.py new file mode 100644 index 0000000..2b92a43 --- /dev/null +++ b/src/guardden/services/verification.py @@ -0,0 +1,300 @@ +"""Verification service for new member challenges.""" + +import asyncio +import logging +import random +import string +from abc import ABC, abstractmethod +from dataclasses import dataclass, field +from datetime import datetime, timedelta, timezone +from enum import Enum +from typing import Any + +import discord + +logger = logging.getLogger(__name__) + + +class ChallengeType(str, Enum): + """Types of verification challenges.""" + + BUTTON = "button" # Simple button click + CAPTCHA = "captcha" # Text-based captcha + MATH = "math" # Simple math problem + EMOJI = "emoji" # Select correct emoji + QUESTIONS = "questions" # Custom questions + + +@dataclass +class Challenge: + """Represents a verification challenge.""" + + challenge_type: ChallengeType + question: str + answer: str + options: list[str] = field(default_factory=list) # For multiple choice + expires_at: datetime = field( + default_factory=lambda: datetime.now(timezone.utc) + timedelta(minutes=10) + ) + attempts: int = 0 + max_attempts: int = 3 + + @property + def is_expired(self) -> bool: + return datetime.now(timezone.utc) > self.expires_at + + def check_answer(self, response: str) -> bool: + """Check if the response is correct.""" + self.attempts += 1 + return response.strip().lower() == self.answer.lower() + + +@dataclass +class PendingVerification: + """Tracks a pending verification for a user.""" + + user_id: int + guild_id: int + challenge: Challenge + message_id: int | None = None + channel_id: int | None = None + created_at: datetime = field(default_factory=lambda: datetime.now(timezone.utc)) + + +class ChallengeGenerator(ABC): + """Abstract base class for challenge generators.""" + + @abstractmethod + def generate(self) -> Challenge: + """Generate a new challenge.""" + pass + + +class ButtonChallengeGenerator(ChallengeGenerator): + """Generates simple button click challenges.""" + + def generate(self) -> Challenge: + return Challenge( + challenge_type=ChallengeType.BUTTON, + question="Click the button below to verify you're human.", + answer="verified", + ) + + +class CaptchaChallengeGenerator(ChallengeGenerator): + """Generates text-based captcha challenges.""" + + def __init__(self, length: int = 6) -> None: + self.length = length + + def generate(self) -> Challenge: + # Generate random alphanumeric code (avoiding confusing chars) + chars = "ABCDEFGHJKLMNPQRSTUVWXYZ23456789" + code = "".join(random.choices(chars, k=self.length)) + + # Create visual representation with some obfuscation + visual = self._create_visual(code) + + return Challenge( + challenge_type=ChallengeType.CAPTCHA, + question=f"Enter the code shown below:\n```\n{visual}\n```", + answer=code, + ) + + def _create_visual(self, code: str) -> str: + """Create a simple text-based visual captcha.""" + lines = [] + # Add some noise characters + noise_chars = ".-*~^" + + for _ in range(2): + lines.append("".join(random.choices(noise_chars, k=len(code) * 2))) + + # Add the code with spacing + spaced = " ".join(code) + lines.append(spaced) + + for _ in range(2): + lines.append("".join(random.choices(noise_chars, k=len(code) * 2))) + + return "\n".join(lines) + + +class MathChallengeGenerator(ChallengeGenerator): + """Generates simple math problem challenges.""" + + def generate(self) -> Challenge: + # Generate simple addition/subtraction/multiplication + operation = random.choice(["+", "-", "*"]) + + if operation == "*": + a = random.randint(2, 10) + b = random.randint(2, 10) + else: + a = random.randint(10, 50) + b = random.randint(1, 20) + + if operation == "+": + answer = a + b + elif operation == "-": + # Ensure positive result + if b > a: + a, b = b, a + answer = a - b + else: + answer = a * b + + return Challenge( + challenge_type=ChallengeType.MATH, + question=f"Solve this math problem: **{a} {operation} {b} = ?**", + answer=str(answer), + ) + + +class EmojiChallengeGenerator(ChallengeGenerator): + """Generates emoji selection challenges.""" + + EMOJI_SETS = [ + ("animals", ["🐶", "🐱", "🐭", "🐹", "🐰", "🦊", "🐻", "🐼"]), + ("fruits", ["🍎", "🍐", "🍊", "🍋", "🍌", "🍉", "🍇", "🍓"]), + ("weather", ["☀️", "🌙", "⭐", "🌧️", "❄️", "🌈", "⚡", "🌪️"]), + ("sports", ["⚽", "🏀", "🏈", "⚾", "🎾", "🏐", "🏉", "🎱"]), + ] + + def generate(self) -> Challenge: + category, emojis = random.choice(self.EMOJI_SETS) + target = random.choice(emojis) + + # Create options with the target and some others + options = [target] + other_emojis = [e for e in emojis if e != target] + options.extend(random.sample(other_emojis, min(3, len(other_emojis)))) + random.shuffle(options) + + return Challenge( + challenge_type=ChallengeType.EMOJI, + question=f"Select the {self._emoji_name(target)} emoji:", + answer=target, + options=options, + ) + + def _emoji_name(self, emoji: str) -> str: + """Get a description of the emoji.""" + names = { + "🐶": "dog", + "🐱": "cat", + "🐭": "mouse", + "🐹": "hamster", + "🐰": "rabbit", + "🦊": "fox", + "🐻": "bear", + "🐼": "panda", + "🍎": "apple", + "🍐": "pear", + "🍊": "orange", + "🍋": "lemon", + "🍌": "banana", + "🍉": "watermelon", + "🍇": "grapes", + "🍓": "strawberry", + "☀️": "sun", + "🌙": "moon", + "⭐": "star", + "🌧️": "rain", + "❄️": "snowflake", + "🌈": "rainbow", + "⚡": "lightning", + "🌪️": "tornado", + "⚽": "soccer ball", + "🏀": "basketball", + "🏈": "football", + "⚾": "baseball", + "🎾": "tennis", + "🏐": "volleyball", + "🏉": "rugby", + "🎱": "pool ball", + } + return names.get(emoji, "correct") + + +class VerificationService: + """Service for managing member verification.""" + + def __init__(self) -> None: + # Pending verifications: {(guild_id, user_id): PendingVerification} + self._pending: dict[tuple[int, int], PendingVerification] = {} + + # Challenge generators + self._generators: dict[ChallengeType, ChallengeGenerator] = { + ChallengeType.BUTTON: ButtonChallengeGenerator(), + ChallengeType.CAPTCHA: CaptchaChallengeGenerator(), + ChallengeType.MATH: MathChallengeGenerator(), + ChallengeType.EMOJI: EmojiChallengeGenerator(), + } + + def create_challenge( + self, + user_id: int, + guild_id: int, + challenge_type: ChallengeType = ChallengeType.BUTTON, + ) -> PendingVerification: + """Create a new verification challenge for a user.""" + generator = self._generators.get(challenge_type) + if not generator: + generator = self._generators[ChallengeType.BUTTON] + + challenge = generator.generate() + pending = PendingVerification( + user_id=user_id, + guild_id=guild_id, + challenge=challenge, + ) + + self._pending[(guild_id, user_id)] = pending + return pending + + def get_pending(self, guild_id: int, user_id: int) -> PendingVerification | None: + """Get a pending verification for a user.""" + return self._pending.get((guild_id, user_id)) + + def verify(self, guild_id: int, user_id: int, response: str) -> tuple[bool, str]: + """ + Attempt to verify a user's response. + + Returns: + Tuple of (success, message) + """ + pending = self._pending.get((guild_id, user_id)) + + if not pending: + return False, "No pending verification found." + + if pending.challenge.is_expired: + self._pending.pop((guild_id, user_id), None) + return False, "Verification expired. Please request a new one." + + if pending.challenge.attempts >= pending.challenge.max_attempts: + self._pending.pop((guild_id, user_id), None) + return False, "Too many failed attempts. Please request a new verification." + + if pending.challenge.check_answer(response): + self._pending.pop((guild_id, user_id), None) + return True, "Verification successful!" + + remaining = pending.challenge.max_attempts - pending.challenge.attempts + return False, f"Incorrect. {remaining} attempt(s) remaining." + + def cancel(self, guild_id: int, user_id: int) -> bool: + """Cancel a pending verification.""" + return self._pending.pop((guild_id, user_id), None) is not None + + def cleanup_expired(self) -> int: + """Remove expired verifications. Returns count of removed.""" + expired = [key for key, pending in self._pending.items() if pending.challenge.is_expired] + for key in expired: + self._pending.pop(key, None) + return len(expired) + + def get_pending_count(self, guild_id: int) -> int: + """Get count of pending verifications for a guild.""" + return sum(1 for (gid, _) in self._pending if gid == guild_id) diff --git a/src/guardden/utils/__init__.py b/src/guardden/utils/__init__.py new file mode 100644 index 0000000..e47c29a --- /dev/null +++ b/src/guardden/utils/__init__.py @@ -0,0 +1,5 @@ +"""Utility functions for GuardDen.""" + +from guardden.utils.logging import setup_logging + +__all__ = ["setup_logging"] diff --git a/src/guardden/utils/logging.py b/src/guardden/utils/logging.py new file mode 100644 index 0000000..289f5c1 --- /dev/null +++ b/src/guardden/utils/logging.py @@ -0,0 +1,27 @@ +"""Logging configuration for GuardDen.""" + +import logging +import sys +from typing import Literal + + +def setup_logging(level: Literal["DEBUG", "INFO", "WARNING", "ERROR"] = "INFO") -> None: + """Configure logging for the application.""" + log_format = "%(asctime)s | %(levelname)-8s | %(name)s | %(message)s" + date_format = "%Y-%m-%d %H:%M:%S" + + # Configure root logger + logging.basicConfig( + level=getattr(logging, level), + format=log_format, + datefmt=date_format, + handlers=[logging.StreamHandler(sys.stdout)], + ) + + # Reduce noise from third-party libraries + logging.getLogger("discord").setLevel(logging.WARNING) + logging.getLogger("discord.http").setLevel(logging.WARNING) + logging.getLogger("asyncio").setLevel(logging.WARNING) + logging.getLogger("sqlalchemy.engine").setLevel( + logging.DEBUG if level == "DEBUG" else logging.WARNING + ) diff --git a/tests/__init__.py b/tests/__init__.py new file mode 100644 index 0000000..d8623d5 --- /dev/null +++ b/tests/__init__.py @@ -0,0 +1 @@ +"""Tests for GuardDen.""" diff --git a/tests/conftest.py b/tests/conftest.py new file mode 100644 index 0000000..2b2d6a2 --- /dev/null +++ b/tests/conftest.py @@ -0,0 +1,15 @@ +"""Pytest fixtures for GuardDen tests.""" + +import pytest + + +@pytest.fixture +def sample_guild_id() -> int: + """Return a sample Discord guild ID.""" + return 123456789012345678 + + +@pytest.fixture +def sample_user_id() -> int: + """Return a sample Discord user ID.""" + return 987654321098765432 diff --git a/tests/test_ai.py b/tests/test_ai.py new file mode 100644 index 0000000..21a9a3b --- /dev/null +++ b/tests/test_ai.py @@ -0,0 +1,119 @@ +"""Tests for AI services.""" + +import pytest + +from guardden.services.ai.base import ContentCategory, ModerationResult +from guardden.services.ai.factory import NullProvider, create_ai_provider + + +class TestModerationResult: + """Tests for ModerationResult dataclass.""" + + def test_severity_not_flagged(self) -> None: + """Test severity is 0 when not flagged.""" + result = ModerationResult(is_flagged=False, confidence=0.9) + assert result.severity == 0 + + def test_severity_with_confidence(self) -> None: + """Test severity includes confidence.""" + result = ModerationResult( + is_flagged=True, + confidence=0.8, + categories=[], + ) + # 0.8 * 50 = 40 + assert result.severity == 40 + + def test_severity_high_category(self) -> None: + """Test severity with high-severity category.""" + result = ModerationResult( + is_flagged=True, + confidence=0.5, + categories=[ContentCategory.HATE_SPEECH], + ) + # 0.5 * 50 + 30 = 55 + assert result.severity == 55 + + def test_severity_medium_category(self) -> None: + """Test severity with medium-severity category.""" + result = ModerationResult( + is_flagged=True, + confidence=0.5, + categories=[ContentCategory.HARASSMENT], + ) + # 0.5 * 50 + 20 = 45 + assert result.severity == 45 + + def test_severity_multiple_categories(self) -> None: + """Test severity with multiple categories.""" + result = ModerationResult( + is_flagged=True, + confidence=0.5, + categories=[ContentCategory.HATE_SPEECH, ContentCategory.VIOLENCE], + ) + # 0.5 * 50 + 30 + 20 = 75 + assert result.severity == 75 + + def test_severity_capped_at_100(self) -> None: + """Test severity is capped at 100.""" + result = ModerationResult( + is_flagged=True, + confidence=1.0, + categories=[ + ContentCategory.HATE_SPEECH, + ContentCategory.SELF_HARM, + ContentCategory.SCAM, + ], + ) + # Would be 50 + 30 + 30 + 30 = 140, capped to 100 + assert result.severity == 100 + + +class TestNullProvider: + """Tests for NullProvider.""" + + @pytest.fixture + def provider(self) -> NullProvider: + return NullProvider() + + @pytest.mark.asyncio + async def test_moderate_text_returns_empty(self, provider: NullProvider) -> None: + """Test moderate_text returns unflagged result.""" + result = await provider.moderate_text("test content") + assert result.is_flagged is False + + @pytest.mark.asyncio + async def test_analyze_image_returns_empty(self, provider: NullProvider) -> None: + """Test analyze_image returns empty result.""" + result = await provider.analyze_image("http://example.com/image.jpg") + assert result.is_nsfw is False + + @pytest.mark.asyncio + async def test_analyze_phishing_returns_empty(self, provider: NullProvider) -> None: + """Test analyze_phishing returns empty result.""" + result = await provider.analyze_phishing("http://example.com") + assert result.is_phishing is False + + +class TestFactory: + """Tests for AI provider factory.""" + + def test_create_null_provider(self) -> None: + """Test creating null provider.""" + provider = create_ai_provider("none") + assert isinstance(provider, NullProvider) + + def test_create_anthropic_without_key(self) -> None: + """Test creating anthropic provider without key raises error.""" + with pytest.raises(ValueError, match="API key required"): + create_ai_provider("anthropic", None) + + def test_create_openai_without_key(self) -> None: + """Test creating openai provider without key raises error.""" + with pytest.raises(ValueError, match="API key required"): + create_ai_provider("openai", None) + + def test_create_unknown_provider(self) -> None: + """Test creating unknown provider raises error.""" + with pytest.raises(ValueError, match="Unknown AI provider"): + create_ai_provider("unknown", "key") # type: ignore diff --git a/tests/test_automod.py b/tests/test_automod.py new file mode 100644 index 0000000..b17b5ab --- /dev/null +++ b/tests/test_automod.py @@ -0,0 +1,153 @@ +"""Tests for the automod service.""" + +import pytest + +from guardden.models import BannedWord +from guardden.services.automod import AutomodService + + +@pytest.fixture +def automod() -> AutomodService: + """Create an automod service instance.""" + return AutomodService() + + +class TestBannedWords: + """Tests for banned word filtering.""" + + def test_simple_match(self, automod: AutomodService) -> None: + """Test simple text matching.""" + banned = [_make_banned_word("badword")] + result = automod.check_banned_words("This contains badword in it", banned) + assert result is not None + assert result.should_delete + + def test_case_insensitive(self, automod: AutomodService) -> None: + """Test case insensitive matching.""" + banned = [_make_banned_word("BadWord")] + result = automod.check_banned_words("this contains BADWORD here", banned) + assert result is not None + + def test_no_match(self, automod: AutomodService) -> None: + """Test no match returns None.""" + banned = [_make_banned_word("badword")] + result = automod.check_banned_words("This is a clean message", banned) + assert result is None + + def test_regex_pattern(self, automod: AutomodService) -> None: + """Test regex pattern matching.""" + banned = [_make_banned_word(r"bad\w+", is_regex=True)] + result = automod.check_banned_words("This is badword and badstuff", banned) + assert result is not None + + def test_action_warn(self, automod: AutomodService) -> None: + """Test warn action is set.""" + banned = [_make_banned_word("badword", action="warn")] + result = automod.check_banned_words("badword", banned) + assert result is not None + assert result.should_warn + + def test_action_strike(self, automod: AutomodService) -> None: + """Test strike action is set.""" + banned = [_make_banned_word("badword", action="strike")] + result = automod.check_banned_words("badword", banned) + assert result is not None + assert result.should_strike + + +class TestScamDetection: + """Tests for scam/phishing detection.""" + + def test_discord_nitro_scam(self, automod: AutomodService) -> None: + """Test detection of fake Discord Nitro links.""" + result = automod.check_scam_links("Free nitro at discord-nitro.gift") + assert result is not None + assert result.should_delete + + def test_steam_scam(self, automod: AutomodService) -> None: + """Test detection of Steam scam patterns.""" + result = automod.check_scam_links("Check out this steam-community-giveaway.xyz") + assert result is not None + + def test_legitimate_discord_link(self, automod: AutomodService) -> None: + """Test that legitimate Discord links pass.""" + result = automod.check_scam_links("Join us at discord.gg/example") + assert result is None + + def test_suspicious_tld_with_keyword(self, automod: AutomodService) -> None: + """Test suspicious TLD with impersonation keyword.""" + result = automod.check_scam_links("Visit discord-verify.xyz to claim") + assert result is not None + + def test_normal_url(self, automod: AutomodService) -> None: + """Test normal URLs pass.""" + result = automod.check_scam_links("Check out https://github.com/example") + assert result is None + + +class TestInviteLinks: + """Tests for Discord invite link detection.""" + + def test_discord_gg_invite(self, automod: AutomodService) -> None: + """Test discord.gg invite detection.""" + result = automod.check_invite_links("Join discord.gg/example", allow_invites=False) + assert result is not None + assert result.should_delete + + def test_discordapp_invite(self, automod: AutomodService) -> None: + """Test discordapp.com invite detection.""" + result = automod.check_invite_links( + "Join https://discordapp.com/invite/abc123", allow_invites=False + ) + assert result is not None + + def test_allowed_invites(self, automod: AutomodService) -> None: + """Test invites pass when allowed.""" + result = automod.check_invite_links("Join discord.gg/example", allow_invites=True) + assert result is None + + +class TestCapsDetection: + """Tests for excessive caps detection.""" + + def test_excessive_caps(self, automod: AutomodService) -> None: + """Test detection of all caps message.""" + result = automod.check_all_caps("THIS IS ALL CAPS MESSAGE HERE") + assert result is not None + + def test_normal_caps(self, automod: AutomodService) -> None: + """Test normal message passes.""" + result = automod.check_all_caps("This is a Normal Message with Some Caps") + assert result is None + + def test_short_message_ignored(self, automod: AutomodService) -> None: + """Test short messages are ignored.""" + result = automod.check_all_caps("HI THERE") + assert result is None + + +class MockBannedWord: + """Mock BannedWord for testing without database.""" + + def __init__( + self, + pattern: str, + is_regex: bool = False, + action: str = "delete", + ) -> None: + self.id = 1 + self.guild_id = 123 + self.pattern = pattern + self.is_regex = is_regex + self.action = action + self.reason = None + self.added_by = 456 + + +def _make_banned_word( + pattern: str, + is_regex: bool = False, + action: str = "delete", +) -> MockBannedWord: + """Create a mock BannedWord object for testing.""" + return MockBannedWord(pattern, is_regex, action) diff --git a/tests/test_ratelimit.py b/tests/test_ratelimit.py new file mode 100644 index 0000000..e2acd5c --- /dev/null +++ b/tests/test_ratelimit.py @@ -0,0 +1,130 @@ +"""Tests for rate limiting service.""" + +import pytest + +from guardden.services.ratelimit import ( + RateLimitBucket, + RateLimitConfig, + RateLimiter, + RateLimitScope, +) + + +class TestRateLimitBucket: + """Tests for RateLimitBucket.""" + + def test_not_limited_initially(self) -> None: + """Test bucket is not limited when empty.""" + bucket = RateLimitBucket(max_requests=3, window_seconds=60) + assert bucket.is_limited() is False + assert bucket.remaining() == 3 + + def test_limited_after_max_requests(self) -> None: + """Test bucket is limited after max requests.""" + bucket = RateLimitBucket(max_requests=3, window_seconds=60) + + for _ in range(3): + bucket.record() + + assert bucket.is_limited() is True + assert bucket.remaining() == 0 + + def test_remaining_decreases(self) -> None: + """Test remaining count decreases.""" + bucket = RateLimitBucket(max_requests=5, window_seconds=60) + + bucket.record() + assert bucket.remaining() == 4 + + bucket.record() + assert bucket.remaining() == 3 + + +class TestRateLimiter: + """Tests for RateLimiter.""" + + @pytest.fixture + def limiter(self) -> RateLimiter: + return RateLimiter() + + def test_check_not_limited(self, limiter: RateLimiter) -> None: + """Test check returns not limited for new bucket.""" + result = limiter.check("command", user_id=123, guild_id=456) + assert result.is_limited is False + + def test_acquire_records_request(self, limiter: RateLimiter) -> None: + """Test acquire records the request.""" + # Configure a simple limit + limiter.configure( + "test", + RateLimitConfig(max_requests=2, window_seconds=60, scope=RateLimitScope.USER), + ) + + result1 = limiter.acquire("test", user_id=123) + assert result1.is_limited is False + assert result1.remaining == 1 + + result2 = limiter.acquire("test", user_id=123) + assert result2.is_limited is False + assert result2.remaining == 0 + + result3 = limiter.acquire("test", user_id=123) + assert result3.is_limited is True + + def test_different_scopes(self, limiter: RateLimiter) -> None: + """Test different scopes create different buckets.""" + # User scope - shared across guilds + limiter.configure( + "user_action", + RateLimitConfig(max_requests=1, window_seconds=60, scope=RateLimitScope.USER), + ) + + limiter.acquire("user_action", user_id=123, guild_id=1) + result = limiter.acquire("user_action", user_id=123, guild_id=2) + assert result.is_limited is True # Same user, different guild + + # Member scope - per guild + limiter.configure( + "member_action", + RateLimitConfig(max_requests=1, window_seconds=60, scope=RateLimitScope.MEMBER), + ) + + limiter.acquire("member_action", user_id=456, guild_id=1) + result = limiter.acquire("member_action", user_id=456, guild_id=2) + assert result.is_limited is False # Same user, different guild = different bucket + + def test_reset(self, limiter: RateLimiter) -> None: + """Test resetting a bucket.""" + limiter.configure( + "test", + RateLimitConfig(max_requests=1, window_seconds=60, scope=RateLimitScope.USER), + ) + + limiter.acquire("test", user_id=123) + assert limiter.acquire("test", user_id=123).is_limited is True + + limiter.reset("test", user_id=123) + assert limiter.acquire("test", user_id=123).is_limited is False + + def test_unknown_action(self, limiter: RateLimiter) -> None: + """Test unknown action returns not limited.""" + result = limiter.acquire("unknown_action", user_id=123) + assert result.is_limited is False + assert result.remaining == 999 + + def test_guild_scope(self, limiter: RateLimiter) -> None: + """Test guild-scoped rate limiting.""" + limiter.configure( + "guild_action", + RateLimitConfig(max_requests=2, window_seconds=60, scope=RateLimitScope.GUILD), + ) + + # Different users in same guild share the limit + limiter.acquire("guild_action", user_id=1, guild_id=100) + limiter.acquire("guild_action", user_id=2, guild_id=100) + result = limiter.acquire("guild_action", user_id=3, guild_id=100) + assert result.is_limited is True + + # Different guild is not limited + result = limiter.acquire("guild_action", user_id=1, guild_id=200) + assert result.is_limited is False diff --git a/tests/test_utils.py b/tests/test_utils.py new file mode 100644 index 0000000..db10b2d --- /dev/null +++ b/tests/test_utils.py @@ -0,0 +1,48 @@ +"""Tests for utility functions.""" + +from datetime import timedelta + +import pytest + +from guardden.cogs.moderation import parse_duration + + +class TestParseDuration: + """Tests for the parse_duration function.""" + + def test_parse_seconds(self) -> None: + """Test parsing seconds.""" + assert parse_duration("30s") == timedelta(seconds=30) + assert parse_duration("1s") == timedelta(seconds=1) + + def test_parse_minutes(self) -> None: + """Test parsing minutes.""" + assert parse_duration("5m") == timedelta(minutes=5) + assert parse_duration("30m") == timedelta(minutes=30) + + def test_parse_hours(self) -> None: + """Test parsing hours.""" + assert parse_duration("1h") == timedelta(hours=1) + assert parse_duration("24h") == timedelta(hours=24) + + def test_parse_days(self) -> None: + """Test parsing days.""" + assert parse_duration("7d") == timedelta(days=7) + assert parse_duration("1d") == timedelta(days=1) + + def test_parse_weeks(self) -> None: + """Test parsing weeks.""" + assert parse_duration("2w") == timedelta(weeks=2) + assert parse_duration("1w") == timedelta(weeks=1) + + def test_invalid_format(self) -> None: + """Test invalid duration formats.""" + assert parse_duration("invalid") is None + assert parse_duration("") is None + assert parse_duration("10") is None + assert parse_duration("abc") is None + + def test_case_insensitive(self) -> None: + """Test that parsing is case insensitive.""" + assert parse_duration("1H") == timedelta(hours=1) + assert parse_duration("30M") == timedelta(minutes=30) diff --git a/tests/test_verification.py b/tests/test_verification.py new file mode 100644 index 0000000..8cdd9fb --- /dev/null +++ b/tests/test_verification.py @@ -0,0 +1,142 @@ +"""Tests for verification service.""" + +import pytest + +from guardden.services.verification import ( + ButtonChallengeGenerator, + CaptchaChallengeGenerator, + ChallengeType, + EmojiChallengeGenerator, + MathChallengeGenerator, + VerificationService, +) + + +class TestChallengeGenerators: + """Tests for challenge generators.""" + + def test_button_challenge(self) -> None: + """Test button challenge generation.""" + gen = ButtonChallengeGenerator() + challenge = gen.generate() + + assert challenge.challenge_type == ChallengeType.BUTTON + assert challenge.answer == "verified" + + def test_captcha_challenge(self) -> None: + """Test captcha challenge generation.""" + gen = CaptchaChallengeGenerator(length=6) + challenge = gen.generate() + + assert challenge.challenge_type == ChallengeType.CAPTCHA + assert len(challenge.answer) == 6 + assert challenge.answer.isalnum() + + def test_math_challenge(self) -> None: + """Test math challenge generation.""" + gen = MathChallengeGenerator() + challenge = gen.generate() + + assert challenge.challenge_type == ChallengeType.MATH + # Answer should be a number + assert challenge.answer.lstrip("-").isdigit() + + def test_emoji_challenge(self) -> None: + """Test emoji challenge generation.""" + gen = EmojiChallengeGenerator() + challenge = gen.generate() + + assert challenge.challenge_type == ChallengeType.EMOJI + assert len(challenge.options) > 0 + assert challenge.answer in challenge.options + + +class TestVerificationService: + """Tests for verification service.""" + + @pytest.fixture + def service(self) -> VerificationService: + return VerificationService() + + def test_create_challenge(self, service: VerificationService) -> None: + """Test creating a challenge.""" + pending = service.create_challenge( + user_id=123, + guild_id=456, + challenge_type=ChallengeType.BUTTON, + ) + + assert pending.user_id == 123 + assert pending.guild_id == 456 + assert pending.challenge.challenge_type == ChallengeType.BUTTON + + def test_get_pending(self, service: VerificationService) -> None: + """Test retrieving pending verification.""" + service.create_challenge(123, 456, ChallengeType.BUTTON) + + pending = service.get_pending(456, 123) + assert pending is not None + assert pending.user_id == 123 + + # Non-existent should return None + assert service.get_pending(456, 999) is None + + def test_verify_correct(self, service: VerificationService) -> None: + """Test successful verification.""" + service.create_challenge(123, 456, ChallengeType.BUTTON) + + success, message = service.verify(456, 123, "verified") + assert success is True + assert "successful" in message.lower() + + # Should be removed after success + assert service.get_pending(456, 123) is None + + def test_verify_incorrect(self, service: VerificationService) -> None: + """Test failed verification.""" + service.create_challenge(123, 456, ChallengeType.BUTTON) + + success, message = service.verify(456, 123, "wrong") + assert success is False + assert "incorrect" in message.lower() + + # Should still exist + assert service.get_pending(456, 123) is not None + + def test_verify_max_attempts(self, service: VerificationService) -> None: + """Test max attempts exceeded.""" + service.create_challenge(123, 456, ChallengeType.BUTTON) + + # Use up all attempts + for _ in range(3): + service.verify(456, 123, "wrong") + + success, message = service.verify(456, 123, "verified") + assert success is False + assert "too many" in message.lower() + + def test_verify_no_pending(self, service: VerificationService) -> None: + """Test verification with no pending challenge.""" + success, message = service.verify(456, 123, "verified") + assert success is False + assert "no pending" in message.lower() + + def test_cancel(self, service: VerificationService) -> None: + """Test canceling verification.""" + service.create_challenge(123, 456, ChallengeType.BUTTON) + + assert service.cancel(456, 123) is True + assert service.get_pending(456, 123) is None + + # Cancel non-existent returns False + assert service.cancel(456, 123) is False + + def test_pending_count(self, service: VerificationService) -> None: + """Test pending count per guild.""" + service.create_challenge(1, 456, ChallengeType.BUTTON) + service.create_challenge(2, 456, ChallengeType.BUTTON) + service.create_challenge(3, 789, ChallengeType.BUTTON) + + assert service.get_pending_count(456) == 2 + assert service.get_pending_count(789) == 1 + assert service.get_pending_count(999) == 0