Add image and GIF vision support

- Add ImageAttachment dataclass for image metadata
- Update Message to support list of image attachments
- Update all providers (OpenAI, Anthropic, Gemini, OpenRouter) for vision
- Extract images from Discord attachments and embeds in ai_chat.py
- Supports PNG, JPEG, GIF, and WebP formats
This commit is contained in:
2026-01-11 20:56:50 +01:00
parent 8f521b869b
commit 4ac123be9c
8 changed files with 187 additions and 12 deletions

View File

@@ -7,7 +7,13 @@ import discord
from discord.ext import commands from discord.ext import commands
from daemon_boyfriend.config import settings from daemon_boyfriend.config import settings
from daemon_boyfriend.services import AIService, ConversationManager, Message, SearXNGService from daemon_boyfriend.services import (
AIService,
ConversationManager,
ImageAttachment,
Message,
SearXNGService,
)
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@@ -125,6 +131,65 @@ class AIChatCog(commands.Cog):
return content.strip() return content.strip()
def _extract_image_attachments(self, message: discord.Message) -> list[ImageAttachment]:
    """Extract image attachments from a Discord message.

    Inspects both direct file attachments and embed images. Supported
    formats: PNG, JPEG, GIF and WebP.

    Args:
        message: The Discord message to inspect.

    Returns:
        List of ImageAttachment objects (empty when no images are found).
    """
    images: list[ImageAttachment] = []

    # Map Discord-reported content types to canonical media types.
    # "image/jpg" is non-standard but seen in the wild; normalize to jpeg.
    image_types = {
        "image/png": "image/png",
        "image/jpeg": "image/jpeg",
        "image/jpg": "image/jpeg",
        "image/gif": "image/gif",
        "image/webp": "image/webp",
    }

    # Check message attachments.
    for attachment in message.attachments:
        content_type = attachment.content_type or ""
        if content_type in image_types:
            images.append(
                ImageAttachment(
                    url=attachment.url,
                    media_type=image_types[content_type],
                )
            )
        # Fall back to the file extension when content_type is not set.
        # Require a dot so a bare filename like "png" (no extension)
        # is not misclassified as an image.
        elif attachment.filename and "." in attachment.filename:
            ext = attachment.filename.lower().rsplit(".", 1)[-1]
            if ext in ("png", "jpg", "jpeg", "gif", "webp"):
                media_type = "image/jpeg" if ext in ("jpg", "jpeg") else f"image/{ext}"
                images.append(
                    ImageAttachment(
                        url=attachment.url,
                        media_type=media_type,
                    )
                )

    # Check embeds for images.
    for embed in message.embeds:
        if embed.image and embed.image.url:
            # Guess the media type from the URL *path* only — Discord CDN
            # URLs carry query parameters that could contain misleading
            # substrings such as ".jpg".
            path = embed.image.url.partition("?")[0].partition("#")[0].lower()
            media_type = "image/png"  # default
            if path.endswith((".jpg", ".jpeg")):
                media_type = "image/jpeg"
            elif path.endswith(".gif"):
                media_type = "image/gif"
            elif path.endswith(".webp"):
                media_type = "image/webp"
            images.append(ImageAttachment(url=embed.image.url, media_type=media_type))

    logger.debug(f"Extracted {len(images)} images from message")
    return images
def _get_mentioned_users_context(self, message: discord.Message) -> str | None: def _get_mentioned_users_context(self, message: discord.Message) -> str | None:
"""Get context about mentioned users (excluding the bot). """Get context about mentioned users (excluding the bot).
@@ -178,8 +243,12 @@ class AIChatCog(commands.Cog):
# Get conversation history # Get conversation history
history = self.conversations.get_history(user_id) history = self.conversations.get_history(user_id)
# Add current message to history for the API call # Extract any image attachments from the message
messages = history + [Message(role="user", content=user_message)] images = self._extract_image_attachments(message)
# Add current message to history for the API call (with images if any)
current_message = Message(role="user", content=user_message, images=images)
messages = history + [current_message]
# Check if we should search the web # Check if we should search the web
search_context = await self._maybe_search(user_message) search_context = await self._maybe_search(user_message)

View File

@@ -2,12 +2,13 @@
from .ai_service import AIService from .ai_service import AIService
from .conversation import ConversationManager from .conversation import ConversationManager
from .providers import AIResponse, Message from .providers import AIResponse, ImageAttachment, Message
from .searxng import SearXNGService from .searxng import SearXNGService
__all__ = [ __all__ = [
"AIService", "AIService",
"AIResponse", "AIResponse",
"ImageAttachment",
"Message", "Message",
"ConversationManager", "ConversationManager",
"SearXNGService", "SearXNGService",

View File

@@ -1,7 +1,7 @@
"""AI Provider implementations.""" """AI Provider implementations."""
from .anthropic import AnthropicProvider from .anthropic import AnthropicProvider
from .base import AIProvider, AIResponse, Message from .base import AIProvider, AIResponse, ImageAttachment, Message
from .gemini import GeminiProvider from .gemini import GeminiProvider
from .openai import OpenAIProvider from .openai import OpenAIProvider
from .openrouter import OpenRouterProvider from .openrouter import OpenRouterProvider
@@ -9,6 +9,7 @@ from .openrouter import OpenRouterProvider
__all__ = [ __all__ = [
"AIProvider", "AIProvider",
"AIResponse", "AIResponse",
"ImageAttachment",
"Message", "Message",
"OpenAIProvider", "OpenAIProvider",
"OpenRouterProvider", "OpenRouterProvider",

View File

@@ -1,6 +1,7 @@
"""Anthropic (Claude) provider implementation.""" """Anthropic (Claude) provider implementation."""
import logging import logging
from typing import Any
import anthropic import anthropic
@@ -20,6 +21,29 @@ class AnthropicProvider(AIProvider):
def provider_name(self) -> str: def provider_name(self) -> str:
return "anthropic" return "anthropic"
def _build_message_content(self, message: Message) -> str | list[dict[str, Any]]:
    """Return the Anthropic message content for *message*.

    Plain text messages pass through as a string; messages carrying
    images become a list of content blocks (images first, then text).
    """
    if not message.images:
        return message.content
    # Multimodal payload in Anthropic's block format: one URL-sourced
    # image block per attachment, followed by the text block.
    blocks: list[dict[str, Any]] = [
        {
            "type": "image",
            "source": {
                "type": "url",
                "url": img.url,
            },
        }
        for img in message.images
    ]
    blocks.append({"type": "text", "text": message.content})
    return blocks
async def generate( async def generate(
self, self,
messages: list[Message], messages: list[Message],
@@ -29,7 +53,14 @@ class AnthropicProvider(AIProvider):
) -> AIResponse: ) -> AIResponse:
"""Generate a response using Claude.""" """Generate a response using Claude."""
# Build messages list (Anthropic format) # Build messages list (Anthropic format)
api_messages = [{"role": m.role, "content": m.content} for m in messages] api_messages = []
for m in messages:
api_messages.append(
{
"role": m.role,
"content": self._build_message_content(m),
}
)
logger.debug(f"Sending {len(api_messages)} messages to Anthropic") logger.debug(f"Sending {len(api_messages)} messages to Anthropic")

View File

@@ -1,7 +1,15 @@
"""Abstract base class for AI providers.""" """Abstract base class for AI providers."""
from abc import ABC, abstractmethod from abc import ABC, abstractmethod
from dataclasses import dataclass from dataclasses import dataclass, field
@dataclass
class ImageAttachment:
    """An image attachment.

    Holds the remote URL of an image plus its MIME media type so that
    providers can build multimodal request payloads from it.
    """

    # Publicly reachable URL of the image (e.g. a Discord CDN link).
    url: str
    # MIME type: image/png, image/jpeg, image/gif, or image/webp.
    media_type: str = "image/png"
@dataclass @dataclass
@@ -10,6 +18,7 @@ class Message:
role: str # "user", "assistant", "system" role: str # "user", "assistant", "system"
content: str content: str
images: list[ImageAttachment] = field(default_factory=list)
@dataclass @dataclass

View File

@@ -21,6 +21,19 @@ class GeminiProvider(AIProvider):
def provider_name(self) -> str: def provider_name(self) -> str:
return "gemini" return "gemini"
def _build_message_parts(self, message: Message) -> list[types.Part]:
    """Convert *message* into Gemini content parts.

    Image parts are emitted before the text part so the model receives
    them as context for the prompt.
    """
    image_parts = [
        types.Part.from_uri(file_uri=img.url, mime_type=img.media_type)
        for img in message.images
    ]
    return [*image_parts, types.Part(text=message.content)]
async def generate( async def generate(
self, self,
messages: list[Message], messages: list[Message],
@@ -34,7 +47,8 @@ class GeminiProvider(AIProvider):
for m in messages: for m in messages:
# Gemini uses "user" and "model" roles # Gemini uses "user" and "model" roles
role = "model" if m.role == "assistant" else m.role role = "model" if m.role == "assistant" else m.role
contents.append(types.Content(role=role, parts=[types.Part(text=m.content)])) parts = self._build_message_parts(m)
contents.append(types.Content(role=role, parts=parts))
logger.debug(f"Sending {len(contents)} messages to Gemini") logger.debug(f"Sending {len(contents)} messages to Gemini")

View File

@@ -1,6 +1,7 @@
"""OpenAI provider implementation.""" """OpenAI provider implementation."""
import logging import logging
from typing import Any
from openai import AsyncOpenAI from openai import AsyncOpenAI
@@ -20,6 +21,24 @@ class OpenAIProvider(AIProvider):
def provider_name(self) -> str: def provider_name(self) -> str:
return "openai" return "openai"
def _build_message_content(self, message: Message) -> str | list[dict[str, Any]]:
    """Return the OpenAI message content for *message*.

    Text-only messages pass through as a plain string; messages with
    images become a list of content parts (text first, then images).
    """
    if not message.images:
        return message.content
    # Multimodal payload: the text part leads, followed by one
    # image_url part per attachment.
    parts: list[dict[str, Any]] = [{"type": "text", "text": message.content}]
    parts.extend(
        {
            "type": "image_url",
            "image_url": {"url": img.url},
        }
        for img in message.images
    )
    return parts
async def generate( async def generate(
self, self,
messages: list[Message], messages: list[Message],
@@ -29,12 +48,18 @@ class OpenAIProvider(AIProvider):
) -> AIResponse: ) -> AIResponse:
"""Generate a response using OpenAI.""" """Generate a response using OpenAI."""
# Build messages list # Build messages list
api_messages: list[dict[str, str]] = [] api_messages: list[dict[str, Any]] = []
if system_prompt: if system_prompt:
api_messages.append({"role": "system", "content": system_prompt}) api_messages.append({"role": "system", "content": system_prompt})
api_messages.extend([{"role": m.role, "content": m.content} for m in messages]) for m in messages:
api_messages.append(
{
"role": m.role,
"content": self._build_message_content(m),
}
)
logger.debug(f"Sending {len(api_messages)} messages to OpenAI") logger.debug(f"Sending {len(api_messages)} messages to OpenAI")

View File

@@ -4,6 +4,7 @@ OpenRouter uses an OpenAI-compatible API, so we extend the OpenAI provider.
""" """
import logging import logging
from typing import Any
from openai import AsyncOpenAI from openai import AsyncOpenAI
@@ -32,6 +33,24 @@ class OpenRouterProvider(AIProvider):
def provider_name(self) -> str: def provider_name(self) -> str:
return "openrouter" return "openrouter"
def _build_message_content(self, message: Message) -> str | list[dict[str, Any]]:
    """Return the OpenRouter (OpenAI-compatible) content for *message*.

    Text-only messages pass through as a plain string; messages with
    images become a list of content parts (text first, then images).
    """
    if not message.images:
        return message.content
    # Multimodal payload: the text part leads, followed by one
    # image_url part per attachment.
    parts: list[dict[str, Any]] = [{"type": "text", "text": message.content}]
    parts.extend(
        {
            "type": "image_url",
            "image_url": {"url": img.url},
        }
        for img in message.images
    )
    return parts
async def generate( async def generate(
self, self,
messages: list[Message], messages: list[Message],
@@ -41,12 +60,18 @@ class OpenRouterProvider(AIProvider):
) -> AIResponse: ) -> AIResponse:
"""Generate a response using OpenRouter.""" """Generate a response using OpenRouter."""
# Build messages list # Build messages list
api_messages: list[dict[str, str]] = [] api_messages: list[dict[str, Any]] = []
if system_prompt: if system_prompt:
api_messages.append({"role": "system", "content": system_prompt}) api_messages.append({"role": "system", "content": system_prompt})
api_messages.extend([{"role": m.role, "content": m.content} for m in messages]) for m in messages:
api_messages.append(
{
"role": m.role,
"content": self._build_message_content(m),
}
)
logger.debug(f"Sending {len(api_messages)} messages to OpenRouter ({self.model})") logger.debug(f"Sending {len(api_messages)} messages to OpenRouter ({self.model})")