Add image and GIF vision support
- Add ImageAttachment dataclass for image metadata - Update Message to support list of image attachments - Update all providers (OpenAI, Anthropic, Gemini, OpenRouter) for vision - Extract images from Discord attachments and embeds in ai_chat.py - Supports PNG, JPEG, GIF, and WebP formats
This commit is contained in:
@@ -7,7 +7,13 @@ import discord
|
||||
from discord.ext import commands
|
||||
|
||||
from daemon_boyfriend.config import settings
|
||||
from daemon_boyfriend.services import AIService, ConversationManager, Message, SearXNGService
|
||||
from daemon_boyfriend.services import (
|
||||
AIService,
|
||||
ConversationManager,
|
||||
ImageAttachment,
|
||||
Message,
|
||||
SearXNGService,
|
||||
)
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@@ -125,6 +131,65 @@ class AIChatCog(commands.Cog):
|
||||
|
||||
return content.strip()
|
||||
|
||||
def _extract_image_attachments(self, message: discord.Message) -> list[ImageAttachment]:
|
||||
"""Extract image attachments from a Discord message.
|
||||
|
||||
Args:
|
||||
message: The Discord message
|
||||
|
||||
Returns:
|
||||
List of ImageAttachment objects
|
||||
"""
|
||||
images = []
|
||||
|
||||
# Supported image types
|
||||
image_types = {
|
||||
"image/png": "image/png",
|
||||
"image/jpeg": "image/jpeg",
|
||||
"image/jpg": "image/jpeg",
|
||||
"image/gif": "image/gif",
|
||||
"image/webp": "image/webp",
|
||||
}
|
||||
|
||||
# Check message attachments
|
||||
for attachment in message.attachments:
|
||||
content_type = attachment.content_type or ""
|
||||
if content_type in image_types:
|
||||
images.append(
|
||||
ImageAttachment(
|
||||
url=attachment.url,
|
||||
media_type=image_types[content_type],
|
||||
)
|
||||
)
|
||||
# Also check by file extension if content_type not set
|
||||
elif attachment.filename:
|
||||
ext = attachment.filename.lower().split(".")[-1]
|
||||
if ext in ("png", "jpg", "jpeg", "gif", "webp"):
|
||||
media_type = f"image/{ext}" if ext != "jpg" else "image/jpeg"
|
||||
images.append(
|
||||
ImageAttachment(
|
||||
url=attachment.url,
|
||||
media_type=media_type,
|
||||
)
|
||||
)
|
||||
|
||||
# Check embeds for images
|
||||
for embed in message.embeds:
|
||||
if embed.image and embed.image.url:
|
||||
# Guess media type from URL
|
||||
url = embed.image.url.lower()
|
||||
media_type = "image/png" # default
|
||||
if ".jpg" in url or ".jpeg" in url:
|
||||
media_type = "image/jpeg"
|
||||
elif ".gif" in url:
|
||||
media_type = "image/gif"
|
||||
elif ".webp" in url:
|
||||
media_type = "image/webp"
|
||||
images.append(ImageAttachment(url=embed.image.url, media_type=media_type))
|
||||
|
||||
logger.debug(f"Extracted {len(images)} images from message")
|
||||
return images
|
||||
|
||||
def _get_mentioned_users_context(self, message: discord.Message) -> str | None:
|
||||
"""Get context about mentioned users (excluding the bot).
|
||||
|
||||
@@ -178,8 +243,12 @@ class AIChatCog(commands.Cog):
|
||||
# Get conversation history
|
||||
history = self.conversations.get_history(user_id)
|
||||
|
||||
# Add current message to history for the API call
|
||||
messages = history + [Message(role="user", content=user_message)]
|
||||
# Extract any image attachments from the message
|
||||
images = self._extract_image_attachments(message)
|
||||
|
||||
# Add current message to history for the API call (with images if any)
|
||||
current_message = Message(role="user", content=user_message, images=images)
|
||||
messages = history + [current_message]
|
||||
|
||||
# Check if we should search the web
|
||||
search_context = await self._maybe_search(user_message)
|
||||
|
||||
@@ -2,12 +2,13 @@
|
||||
|
||||
from .ai_service import AIService
|
||||
from .conversation import ConversationManager
|
||||
from .providers import AIResponse, Message
|
||||
from .providers import AIResponse, ImageAttachment, Message
|
||||
from .searxng import SearXNGService
|
||||
|
||||
__all__ = [
|
||||
"AIService",
|
||||
"AIResponse",
|
||||
"ImageAttachment",
|
||||
"Message",
|
||||
"ConversationManager",
|
||||
"SearXNGService",
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
"""AI Provider implementations."""
|
||||
|
||||
from .anthropic import AnthropicProvider
|
||||
from .base import AIProvider, AIResponse, Message
|
||||
from .base import AIProvider, AIResponse, ImageAttachment, Message
|
||||
from .gemini import GeminiProvider
|
||||
from .openai import OpenAIProvider
|
||||
from .openrouter import OpenRouterProvider
|
||||
@@ -9,6 +9,7 @@ from .openrouter import OpenRouterProvider
|
||||
__all__ = [
|
||||
"AIProvider",
|
||||
"AIResponse",
|
||||
"ImageAttachment",
|
||||
"Message",
|
||||
"OpenAIProvider",
|
||||
"OpenRouterProvider",
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
"""Anthropic (Claude) provider implementation."""
|
||||
|
||||
import logging
|
||||
from typing import Any
|
||||
|
||||
import anthropic
|
||||
|
||||
@@ -20,6 +21,29 @@ class AnthropicProvider(AIProvider):
|
||||
def provider_name(self) -> str:
|
||||
return "anthropic"
|
||||
|
||||
def _build_message_content(self, message: Message) -> str | list[dict[str, Any]]:
|
||||
"""Build message content, handling images if present."""
|
||||
if not message.images:
|
||||
return message.content
|
||||
|
||||
# Build multimodal content (Anthropic format)
|
||||
content: list[dict[str, Any]] = []
|
||||
|
||||
for image in message.images:
|
||||
content.append(
|
||||
{
|
||||
"type": "image",
|
||||
"source": {
|
||||
"type": "url",
|
||||
"url": image.url,
|
||||
},
|
||||
}
|
||||
)
|
||||
|
||||
content.append({"type": "text", "text": message.content})
|
||||
|
||||
return content
|
||||
|
||||
async def generate(
|
||||
self,
|
||||
messages: list[Message],
|
||||
@@ -29,7 +53,14 @@ class AnthropicProvider(AIProvider):
|
||||
) -> AIResponse:
|
||||
"""Generate a response using Claude."""
|
||||
# Build messages list (Anthropic format)
|
||||
api_messages = [{"role": m.role, "content": m.content} for m in messages]
|
||||
api_messages = []
|
||||
for m in messages:
|
||||
api_messages.append(
|
||||
{
|
||||
"role": m.role,
|
||||
"content": self._build_message_content(m),
|
||||
}
|
||||
)
|
||||
|
||||
logger.debug(f"Sending {len(api_messages)} messages to Anthropic")
|
||||
|
||||
|
||||
@@ -1,7 +1,15 @@
|
||||
"""Abstract base class for AI providers."""
|
||||
|
||||
from abc import ABC, abstractmethod
|
||||
from dataclasses import dataclass
|
||||
from dataclasses import dataclass, field
|
||||
|
||||
|
||||
@dataclass
|
||||
class ImageAttachment:
|
||||
"""An image attachment."""
|
||||
|
||||
url: str
|
||||
media_type: str = "image/png" # image/png, image/jpeg, image/gif, image/webp
|
||||
|
||||
|
||||
@dataclass
|
||||
@@ -10,6 +18,7 @@ class Message:
|
||||
|
||||
role: str # "user", "assistant", "system"
|
||||
content: str
|
||||
images: list[ImageAttachment] = field(default_factory=list)
|
||||
|
||||
|
||||
@dataclass
|
||||
|
||||
@@ -21,6 +21,19 @@ class GeminiProvider(AIProvider):
|
||||
def provider_name(self) -> str:
|
||||
return "gemini"
|
||||
|
||||
def _build_message_parts(self, message: Message) -> list[types.Part]:
|
||||
"""Build message parts, handling images if present."""
|
||||
parts = []
|
||||
|
||||
# Add images first
|
||||
for image in message.images:
|
||||
parts.append(types.Part.from_uri(file_uri=image.url, mime_type=image.media_type))
|
||||
|
||||
# Add text content
|
||||
parts.append(types.Part(text=message.content))
|
||||
|
||||
return parts
|
||||
|
||||
async def generate(
|
||||
self,
|
||||
messages: list[Message],
|
||||
@@ -34,7 +47,8 @@ class GeminiProvider(AIProvider):
|
||||
for m in messages:
|
||||
# Gemini uses "user" and "model" roles
|
||||
role = "model" if m.role == "assistant" else m.role
|
||||
contents.append(types.Content(role=role, parts=[types.Part(text=m.content)]))
|
||||
parts = self._build_message_parts(m)
|
||||
contents.append(types.Content(role=role, parts=parts))
|
||||
|
||||
logger.debug(f"Sending {len(contents)} messages to Gemini")
|
||||
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
"""OpenAI provider implementation."""
|
||||
|
||||
import logging
|
||||
from typing import Any
|
||||
|
||||
from openai import AsyncOpenAI
|
||||
|
||||
@@ -20,6 +21,24 @@ class OpenAIProvider(AIProvider):
|
||||
def provider_name(self) -> str:
|
||||
return "openai"
|
||||
|
||||
def _build_message_content(self, message: Message) -> str | list[dict[str, Any]]:
|
||||
"""Build message content, handling images if present."""
|
||||
if not message.images:
|
||||
return message.content
|
||||
|
||||
# Build multimodal content
|
||||
content: list[dict[str, Any]] = [{"type": "text", "text": message.content}]
|
||||
|
||||
for image in message.images:
|
||||
content.append(
|
||||
{
|
||||
"type": "image_url",
|
||||
"image_url": {"url": image.url},
|
||||
}
|
||||
)
|
||||
|
||||
return content
|
||||
|
||||
async def generate(
|
||||
self,
|
||||
messages: list[Message],
|
||||
@@ -29,12 +48,18 @@ class OpenAIProvider(AIProvider):
|
||||
) -> AIResponse:
|
||||
"""Generate a response using OpenAI."""
|
||||
# Build messages list
|
||||
api_messages: list[dict[str, str]] = []
|
||||
api_messages: list[dict[str, Any]] = []
|
||||
|
||||
if system_prompt:
|
||||
api_messages.append({"role": "system", "content": system_prompt})
|
||||
|
||||
api_messages.extend([{"role": m.role, "content": m.content} for m in messages])
|
||||
for m in messages:
|
||||
api_messages.append(
|
||||
{
|
||||
"role": m.role,
|
||||
"content": self._build_message_content(m),
|
||||
}
|
||||
)
|
||||
|
||||
logger.debug(f"Sending {len(api_messages)} messages to OpenAI")
|
||||
|
||||
|
||||
@@ -4,6 +4,7 @@ OpenRouter uses an OpenAI-compatible API, so we extend the OpenAI provider.
|
||||
"""
|
||||
|
||||
import logging
|
||||
from typing import Any
|
||||
|
||||
from openai import AsyncOpenAI
|
||||
|
||||
@@ -32,6 +33,24 @@ class OpenRouterProvider(AIProvider):
|
||||
def provider_name(self) -> str:
|
||||
return "openrouter"
|
||||
|
||||
def _build_message_content(self, message: Message) -> str | list[dict[str, Any]]:
|
||||
"""Build message content, handling images if present."""
|
||||
if not message.images:
|
||||
return message.content
|
||||
|
||||
# Build multimodal content
|
||||
content: list[dict[str, Any]] = [{"type": "text", "text": message.content}]
|
||||
|
||||
for image in message.images:
|
||||
content.append(
|
||||
{
|
||||
"type": "image_url",
|
||||
"image_url": {"url": image.url},
|
||||
}
|
||||
)
|
||||
|
||||
return content
|
||||
|
||||
async def generate(
|
||||
self,
|
||||
messages: list[Message],
|
||||
@@ -41,12 +60,18 @@ class OpenRouterProvider(AIProvider):
|
||||
) -> AIResponse:
|
||||
"""Generate a response using OpenRouter."""
|
||||
# Build messages list
|
||||
api_messages: list[dict[str, str]] = []
|
||||
api_messages: list[dict[str, Any]] = []
|
||||
|
||||
if system_prompt:
|
||||
api_messages.append({"role": "system", "content": system_prompt})
|
||||
|
||||
api_messages.extend([{"role": m.role, "content": m.content} for m in messages])
|
||||
for m in messages:
|
||||
api_messages.append(
|
||||
{
|
||||
"role": m.role,
|
||||
"content": self._build_message_content(m),
|
||||
}
|
||||
)
|
||||
|
||||
logger.debug(f"Sending {len(api_messages)} messages to OpenRouter ({self.model})")
|
||||
|
||||
|
||||
Reference in New Issue
Block a user