Add image and GIF vision support
- Add ImageAttachment dataclass for image metadata
- Update Message to support a list of image attachments
- Update all providers (OpenAI, Anthropic, Gemini, OpenRouter) for vision
- Extract images from Discord attachments and embeds in ai_chat.py
- Supports PNG, JPEG, GIF, and WebP formats
This commit is contained in:
@@ -7,7 +7,13 @@ import discord
|
|||||||
from discord.ext import commands
|
from discord.ext import commands
|
||||||
|
|
||||||
from daemon_boyfriend.config import settings
|
from daemon_boyfriend.config import settings
|
||||||
from daemon_boyfriend.services import AIService, ConversationManager, Message, SearXNGService
|
from daemon_boyfriend.services import (
|
||||||
|
AIService,
|
||||||
|
ConversationManager,
|
||||||
|
ImageAttachment,
|
||||||
|
Message,
|
||||||
|
SearXNGService,
|
||||||
|
)
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
@@ -125,6 +131,65 @@ class AIChatCog(commands.Cog):
|
|||||||
|
|
||||||
return content.strip()
|
return content.strip()
|
||||||
|
|
||||||
|
def _extract_image_attachments(self, message: discord.Message) -> list[ImageAttachment]:
    """Extract image attachments from a Discord message.

    Uploaded attachments are inspected first, then embed images. An
    attachment is accepted when its declared content type is a supported
    image type or, failing that, when its filename ends in a supported
    image extension. Embed images are always accepted; their media type is
    guessed from the URL and defaults to PNG.

    Args:
        message: The Discord message

    Returns:
        List of ImageAttachment objects (attachments first, then embeds)
    """
    images = []

    # Supported content types, mapped to the media type we report
    image_types = {
        "image/png": "image/png",
        "image/jpeg": "image/jpeg",
        "image/jpg": "image/jpeg",
        "image/gif": "image/gif",
        "image/webp": "image/webp",
    }
    # Supported filename extensions (without the dot)
    extension_types = {
        "png": "image/png",
        "jpg": "image/jpeg",
        "jpeg": "image/jpeg",
        "gif": "image/gif",
        "webp": "image/webp",
    }

    # Check message attachments
    for attachment in message.attachments:
        # Drop any "; param=value" suffix and normalise case so that e.g.
        # "image/png; charset=binary" still matches the table.
        content_type = (attachment.content_type or "").split(";", 1)[0].strip().lower()
        media_type = image_types.get(content_type)

        if media_type is None and attachment.filename:
            # Fall back to the file extension. rpartition leaves `dot` empty
            # when the name has no "." at all, so a file literally named
            # "png" is not mistaken for an image.
            _, dot, ext = attachment.filename.lower().rpartition(".")
            if dot:
                media_type = extension_types.get(ext)

        if media_type is not None:
            images.append(ImageAttachment(url=attachment.url, media_type=media_type))

    # Check embeds for images
    for embed in message.embeds:
        if embed.image and embed.image.url:
            # Guess media type from URL
            url = embed.image.url.lower()
            media_type = "image/png"  # default
            if ".jpg" in url or ".jpeg" in url:
                media_type = "image/jpeg"
            elif ".gif" in url:
                media_type = "image/gif"
            elif ".webp" in url:
                media_type = "image/webp"
            images.append(ImageAttachment(url=embed.image.url, media_type=media_type))

    logger.debug("Extracted %d images from message", len(images))
    return images
|
||||||
|
|
||||||
def _get_mentioned_users_context(self, message: discord.Message) -> str | None:
|
def _get_mentioned_users_context(self, message: discord.Message) -> str | None:
|
||||||
"""Get context about mentioned users (excluding the bot).
|
"""Get context about mentioned users (excluding the bot).
|
||||||
|
|
||||||
@@ -178,8 +243,12 @@ class AIChatCog(commands.Cog):
|
|||||||
# Get conversation history
|
# Get conversation history
|
||||||
history = self.conversations.get_history(user_id)
|
history = self.conversations.get_history(user_id)
|
||||||
|
|
||||||
# Add current message to history for the API call
|
# Extract any image attachments from the message
|
||||||
messages = history + [Message(role="user", content=user_message)]
|
images = self._extract_image_attachments(message)
|
||||||
|
|
||||||
|
# Add current message to history for the API call (with images if any)
|
||||||
|
current_message = Message(role="user", content=user_message, images=images)
|
||||||
|
messages = history + [current_message]
|
||||||
|
|
||||||
# Check if we should search the web
|
# Check if we should search the web
|
||||||
search_context = await self._maybe_search(user_message)
|
search_context = await self._maybe_search(user_message)
|
||||||
|
|||||||
@@ -2,12 +2,13 @@
|
|||||||
|
|
||||||
from .ai_service import AIService
|
from .ai_service import AIService
|
||||||
from .conversation import ConversationManager
|
from .conversation import ConversationManager
|
||||||
from .providers import AIResponse, Message
|
from .providers import AIResponse, ImageAttachment, Message
|
||||||
from .searxng import SearXNGService
|
from .searxng import SearXNGService
|
||||||
|
|
||||||
__all__ = [
|
__all__ = [
|
||||||
"AIService",
|
"AIService",
|
||||||
"AIResponse",
|
"AIResponse",
|
||||||
|
"ImageAttachment",
|
||||||
"Message",
|
"Message",
|
||||||
"ConversationManager",
|
"ConversationManager",
|
||||||
"SearXNGService",
|
"SearXNGService",
|
||||||
|
|||||||
@@ -1,7 +1,7 @@
|
|||||||
"""AI Provider implementations."""
|
"""AI Provider implementations."""
|
||||||
|
|
||||||
from .anthropic import AnthropicProvider
|
from .anthropic import AnthropicProvider
|
||||||
from .base import AIProvider, AIResponse, Message
|
from .base import AIProvider, AIResponse, ImageAttachment, Message
|
||||||
from .gemini import GeminiProvider
|
from .gemini import GeminiProvider
|
||||||
from .openai import OpenAIProvider
|
from .openai import OpenAIProvider
|
||||||
from .openrouter import OpenRouterProvider
|
from .openrouter import OpenRouterProvider
|
||||||
@@ -9,6 +9,7 @@ from .openrouter import OpenRouterProvider
|
|||||||
__all__ = [
|
__all__ = [
|
||||||
"AIProvider",
|
"AIProvider",
|
||||||
"AIResponse",
|
"AIResponse",
|
||||||
|
"ImageAttachment",
|
||||||
"Message",
|
"Message",
|
||||||
"OpenAIProvider",
|
"OpenAIProvider",
|
||||||
"OpenRouterProvider",
|
"OpenRouterProvider",
|
||||||
|
|||||||
@@ -1,6 +1,7 @@
|
|||||||
"""Anthropic (Claude) provider implementation."""
|
"""Anthropic (Claude) provider implementation."""
|
||||||
|
|
||||||
import logging
|
import logging
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
import anthropic
|
import anthropic
|
||||||
|
|
||||||
@@ -20,6 +21,29 @@ class AnthropicProvider(AIProvider):
|
|||||||
def provider_name(self) -> str:
|
def provider_name(self) -> str:
|
||||||
return "anthropic"
|
return "anthropic"
|
||||||
|
|
||||||
|
def _build_message_content(self, message: Message) -> str | list[dict[str, Any]]:
    """Build message content, handling images if present.

    Args:
        message: The message to convert.

    Returns:
        The plain text content when the message has no images; otherwise a
        list of content blocks (Anthropic format) with one URL-sourced image
        block per attachment followed by the text block.
    """
    if not message.images:
        return message.content

    # Build multimodal content (Anthropic format): images first, then text
    content: list[dict[str, Any]] = [
        {
            "type": "image",
            "source": {
                "type": "url",
                "url": image.url,
            },
        }
        for image in message.images
    ]

    # The API rejects blank text blocks, so an image-only message is sent
    # with the image blocks alone.
    if message.content and message.content.strip():
        content.append({"type": "text", "text": message.content})

    return content
|
||||||
|
|
||||||
async def generate(
|
async def generate(
|
||||||
self,
|
self,
|
||||||
messages: list[Message],
|
messages: list[Message],
|
||||||
@@ -29,7 +53,14 @@ class AnthropicProvider(AIProvider):
|
|||||||
) -> AIResponse:
|
) -> AIResponse:
|
||||||
"""Generate a response using Claude."""
|
"""Generate a response using Claude."""
|
||||||
# Build messages list (Anthropic format)
|
# Build messages list (Anthropic format)
|
||||||
api_messages = [{"role": m.role, "content": m.content} for m in messages]
|
api_messages = []
|
||||||
|
for m in messages:
|
||||||
|
api_messages.append(
|
||||||
|
{
|
||||||
|
"role": m.role,
|
||||||
|
"content": self._build_message_content(m),
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
logger.debug(f"Sending {len(api_messages)} messages to Anthropic")
|
logger.debug(f"Sending {len(api_messages)} messages to Anthropic")
|
||||||
|
|
||||||
|
|||||||
@@ -1,7 +1,15 @@
|
|||||||
"""Abstract base class for AI providers."""
|
"""Abstract base class for AI providers."""
|
||||||
|
|
||||||
from abc import ABC, abstractmethod
|
from abc import ABC, abstractmethod
|
||||||
from dataclasses import dataclass
|
from dataclasses import dataclass, field
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
class ImageAttachment:
    """An image attachment.

    Identifies an image by URL together with its media type.
    """

    # Location of the image
    url: str
    # image/png, image/jpeg, image/gif, image/webp; PNG when not specified
    media_type: str = "image/png"
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
@@ -10,6 +18,7 @@ class Message:
|
|||||||
|
|
||||||
role: str # "user", "assistant", "system"
|
role: str # "user", "assistant", "system"
|
||||||
content: str
|
content: str
|
||||||
|
images: list[ImageAttachment] = field(default_factory=list)
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
|
|||||||
@@ -21,6 +21,19 @@ class GeminiProvider(AIProvider):
|
|||||||
def provider_name(self) -> str:
|
def provider_name(self) -> str:
|
||||||
return "gemini"
|
return "gemini"
|
||||||
|
|
||||||
|
def _build_message_parts(self, message: Message) -> list[types.Part]:
    """Build message parts, handling images if present.

    Args:
        message: The message to convert.

    Returns:
        One URI part per image attachment, followed by a single text part
        holding the message content.
    """
    # Add images first
    # NOTE(review): the image URLs are passed straight through as file URIs;
    # confirm the Gemini backend in use can fetch them.
    parts = [
        types.Part.from_uri(file_uri=image.url, mime_type=image.media_type)
        for image in message.images
    ]

    # Add text content
    parts.append(types.Part(text=message.content))

    return parts
|
||||||
|
|
||||||
async def generate(
|
async def generate(
|
||||||
self,
|
self,
|
||||||
messages: list[Message],
|
messages: list[Message],
|
||||||
@@ -34,7 +47,8 @@ class GeminiProvider(AIProvider):
|
|||||||
for m in messages:
|
for m in messages:
|
||||||
# Gemini uses "user" and "model" roles
|
# Gemini uses "user" and "model" roles
|
||||||
role = "model" if m.role == "assistant" else m.role
|
role = "model" if m.role == "assistant" else m.role
|
||||||
contents.append(types.Content(role=role, parts=[types.Part(text=m.content)]))
|
parts = self._build_message_parts(m)
|
||||||
|
contents.append(types.Content(role=role, parts=parts))
|
||||||
|
|
||||||
logger.debug(f"Sending {len(contents)} messages to Gemini")
|
logger.debug(f"Sending {len(contents)} messages to Gemini")
|
||||||
|
|
||||||
|
|||||||
@@ -1,6 +1,7 @@
|
|||||||
"""OpenAI provider implementation."""
|
"""OpenAI provider implementation."""
|
||||||
|
|
||||||
import logging
|
import logging
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
from openai import AsyncOpenAI
|
from openai import AsyncOpenAI
|
||||||
|
|
||||||
@@ -20,6 +21,24 @@ class OpenAIProvider(AIProvider):
|
|||||||
def provider_name(self) -> str:
|
def provider_name(self) -> str:
|
||||||
return "openai"
|
return "openai"
|
||||||
|
|
||||||
|
def _build_message_content(self, message: Message) -> str | list[dict[str, Any]]:
    """Build message content, handling images if present.

    Args:
        message: The message to convert.

    Returns:
        The plain text content when the message has no images; otherwise a
        list of content parts with the text part first, followed by one
        ``image_url`` part per attachment.
    """
    if not message.images:
        return message.content

    # Build multimodal content: text part first, then the images
    content: list[dict[str, Any]] = [{"type": "text", "text": message.content}]
    content.extend(
        {
            "type": "image_url",
            "image_url": {"url": image.url},
        }
        for image in message.images
    )

    return content
|
||||||
|
|
||||||
async def generate(
|
async def generate(
|
||||||
self,
|
self,
|
||||||
messages: list[Message],
|
messages: list[Message],
|
||||||
@@ -29,12 +48,18 @@ class OpenAIProvider(AIProvider):
|
|||||||
) -> AIResponse:
|
) -> AIResponse:
|
||||||
"""Generate a response using OpenAI."""
|
"""Generate a response using OpenAI."""
|
||||||
# Build messages list
|
# Build messages list
|
||||||
api_messages: list[dict[str, str]] = []
|
api_messages: list[dict[str, Any]] = []
|
||||||
|
|
||||||
if system_prompt:
|
if system_prompt:
|
||||||
api_messages.append({"role": "system", "content": system_prompt})
|
api_messages.append({"role": "system", "content": system_prompt})
|
||||||
|
|
||||||
api_messages.extend([{"role": m.role, "content": m.content} for m in messages])
|
for m in messages:
|
||||||
|
api_messages.append(
|
||||||
|
{
|
||||||
|
"role": m.role,
|
||||||
|
"content": self._build_message_content(m),
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
logger.debug(f"Sending {len(api_messages)} messages to OpenAI")
|
logger.debug(f"Sending {len(api_messages)} messages to OpenAI")
|
||||||
|
|
||||||
|
|||||||
@@ -4,6 +4,7 @@ OpenRouter uses an OpenAI-compatible API, so we extend the OpenAI provider.
|
|||||||
"""
|
"""
|
||||||
|
|
||||||
import logging
|
import logging
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
from openai import AsyncOpenAI
|
from openai import AsyncOpenAI
|
||||||
|
|
||||||
@@ -32,6 +33,24 @@ class OpenRouterProvider(AIProvider):
|
|||||||
def provider_name(self) -> str:
|
def provider_name(self) -> str:
|
||||||
return "openrouter"
|
return "openrouter"
|
||||||
|
|
||||||
|
def _build_message_content(self, message: Message) -> str | list[dict[str, Any]]:
    """Build message content, handling images if present.

    Args:
        message: The message to convert.

    Returns:
        The plain text content when the message has no images; otherwise a
        list of content parts with the text part first, followed by one
        ``image_url`` part per attachment.
    """
    if not message.images:
        return message.content

    # Build multimodal content: text part first, then the images
    content: list[dict[str, Any]] = [{"type": "text", "text": message.content}]
    content.extend(
        {
            "type": "image_url",
            "image_url": {"url": image.url},
        }
        for image in message.images
    )

    return content
|
||||||
|
|
||||||
async def generate(
|
async def generate(
|
||||||
self,
|
self,
|
||||||
messages: list[Message],
|
messages: list[Message],
|
||||||
@@ -41,12 +60,18 @@ class OpenRouterProvider(AIProvider):
|
|||||||
) -> AIResponse:
|
) -> AIResponse:
|
||||||
"""Generate a response using OpenRouter."""
|
"""Generate a response using OpenRouter."""
|
||||||
# Build messages list
|
# Build messages list
|
||||||
api_messages: list[dict[str, str]] = []
|
api_messages: list[dict[str, Any]] = []
|
||||||
|
|
||||||
if system_prompt:
|
if system_prompt:
|
||||||
api_messages.append({"role": "system", "content": system_prompt})
|
api_messages.append({"role": "system", "content": system_prompt})
|
||||||
|
|
||||||
api_messages.extend([{"role": m.role, "content": m.content} for m in messages])
|
for m in messages:
|
||||||
|
api_messages.append(
|
||||||
|
{
|
||||||
|
"role": m.role,
|
||||||
|
"content": self._build_message_content(m),
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
logger.debug(f"Sending {len(api_messages)} messages to OpenRouter ({self.model})")
|
logger.debug(f"Sending {len(api_messages)} messages to OpenRouter ({self.model})")
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user