234 lines
8.0 KiB
Python
234 lines
8.0 KiB
Python
"""Opinion Service - manages bot opinion formation on topics."""
|
|
|
|
import logging
import re
from datetime import datetime, timezone

from sqlalchemy import select
from sqlalchemy.ext.asyncio import AsyncSession

from daemon_boyfriend.models import BotOpinion
|
|
|
|
# Logger for this module; records are emitted under the module's dotted name.
logger: logging.Logger = logging.getLogger(__name__)
|
|
|
|
|
|
class OpinionService:
    """Manages bot opinion formation and topic preferences.

    Opinions are ``BotOpinion`` rows identified by a lower-cased topic and a
    guild ID, where a guild ID of ``None`` denotes the global opinion.
    """

    # Weight of a new observation when it is blended into the stored
    # sentiment / interest level; the stored value keeps the remainder.
    UPDATE_WEIGHT = 0.2

    def __init__(self, session: AsyncSession) -> None:
        """Bind the service to an async SQLAlchemy session."""
        self._session = session

    @staticmethod
    def _normalize_topic(topic: str) -> str:
        """Return the canonical stored form of *topic* (lower-cased)."""
        return topic.lower()

    @staticmethod
    def _clamp(value: float, low: float, high: float) -> float:
        """Return *value* limited to the closed interval [low, high]."""
        return max(low, min(high, value))

    async def get_opinion(self, topic: str, guild_id: int | None = None) -> BotOpinion | None:
        """Get the bot's opinion on a topic.

        Args:
            topic: Topic name; lower-cased before the lookup.
            guild_id: Guild ID, or None for the global opinion.

        Returns:
            The matching opinion, or None if none has been recorded.
        """
        stmt = select(BotOpinion).where(
            BotOpinion.topic == self._normalize_topic(topic),
            BotOpinion.guild_id == guild_id,
        )
        result = await self._session.execute(stmt)
        return result.scalar_one_or_none()

    async def get_or_create_opinion(self, topic: str, guild_id: int | None = None) -> BotOpinion:
        """Get or create an opinion on a topic.

        A newly created opinion starts neutral (sentiment 0.0, interest 0.5,
        zero discussions); it is added to the session and flushed, but not
        committed.
        """
        opinion = await self.get_opinion(topic, guild_id)
        if opinion is not None:
            return opinion

        opinion = BotOpinion(
            topic=self._normalize_topic(topic),
            guild_id=guild_id,
            sentiment=0.0,
            interest_level=0.5,
            discussion_count=0,
        )
        self._session.add(opinion)
        await self._session.flush()
        return opinion

    async def record_topic_discussion(
        self,
        topic: str,
        guild_id: int | None,
        sentiment: float,
        engagement_level: float,
    ) -> BotOpinion:
        """Record a discussion about a topic, updating the bot's opinion.

        Sentiment and interest are updated as a weighted moving average: the
        new observation contributes ``UPDATE_WEIGHT`` and the stored value the
        remainder. Inputs are clamped to their documented ranges first, so a
        single out-of-range value cannot saturate the stored opinion.

        Args:
            topic: The topic discussed
            guild_id: Guild ID or None for global
            sentiment: How positive the discussion was (-1 to 1)
            engagement_level: How engaging the discussion was (0 to 1)

        Returns:
            Updated opinion (this method does not commit the change).
        """
        opinion = await self.get_or_create_opinion(topic, guild_id)
        weight = self.UPDATE_WEIGHT

        opinion.discussion_count += 1

        sentiment = self._clamp(sentiment, -1.0, 1.0)
        engagement_level = self._clamp(engagement_level, 0.0, 1.0)

        # Blend in the new observation; the outer clamp still guards against a
        # stored value that was already out of range.
        opinion.sentiment = self._clamp(
            (opinion.sentiment * (1 - weight)) + (sentiment * weight), -1.0, 1.0
        )
        opinion.interest_level = self._clamp(
            (opinion.interest_level * (1 - weight)) + (engagement_level * weight), 0.0, 1.0
        )

        opinion.last_reinforced_at = datetime.now(timezone.utc)

        logger.debug(
            "Updated opinion on '%s': sentiment=%.2f, interest=%.2f, discussions=%s",
            topic,
            opinion.sentiment,
            opinion.interest_level,
            opinion.discussion_count,
        )

        return opinion

    async def set_opinion_reasoning(self, topic: str, guild_id: int | None, reasoning: str) -> None:
        """Set the reasoning for an opinion (AI-generated explanation).

        Creates a neutral opinion first if none exists. The reasoning is set on
        the session-tracked object; this method does not flush or commit.
        """
        opinion = await self.get_or_create_opinion(topic, guild_id)
        opinion.reasoning = reasoning

    async def get_top_interests(
        self, guild_id: int | None = None, limit: int = 5, min_discussions: int = 3
    ) -> list[BotOpinion]:
        """Get the bot's top interests.

        Topics are ranked by ``interest_level + sentiment`` (highest first), so
        disliked topics rank low even when discussed often. Sentiment is not
        filtered, only used in the ranking.

        Args:
            guild_id: Guild ID, or None for global opinions.
            limit: Maximum number of opinions to return; ``<= 0`` yields ``[]``.
            min_discussions: Minimum ``discussion_count`` for a topic to qualify.

        Returns:
            Up to ``limit`` opinions, best-ranked first.
        """
        if limit <= 0:
            return []

        stmt = (
            select(BotOpinion)
            .where(
                BotOpinion.guild_id == guild_id,
                BotOpinion.discussion_count >= min_discussions,
            )
            .order_by((BotOpinion.interest_level + BotOpinion.sentiment).desc())
            .limit(limit)
        )
        result = await self._session.execute(stmt)
        return list(result.scalars().all())

    async def get_relevant_opinions(
        self, topics: list[str], guild_id: int | None = None
    ) -> list[BotOpinion]:
        """Get opinions relevant to a list of topics.

        Topics are lower-cased before matching; topics with no stored opinion
        are simply absent from the result. No ordering is applied.
        """
        if not topics:
            return []

        normalized = [self._normalize_topic(t) for t in topics]
        stmt = select(BotOpinion).where(
            BotOpinion.topic.in_(normalized),
            BotOpinion.guild_id == guild_id,
        )
        result = await self._session.execute(stmt)
        return list(result.scalars().all())

    def get_opinion_prompt_modifier(
        self, opinions: list[BotOpinion], max_opinions: int = 3
    ) -> str:
        """Generate prompt text based on relevant opinions.

        Only the first ``max_opinions`` opinions are considered. Each one with
        a clear sentiment becomes a statement, with its reasoning (if any)
        attached in parentheses. Neutral opinions (sentiment from -0.3 to 0.2
        inclusive) have no statement and contribute nothing, including their
        reasoning. Statements are joined with "; ".

        Args:
            opinions: Opinions to describe; may be empty or None.
            max_opinions: Cap on how many opinions are considered.

        Returns:
            The prompt fragment, or "" when nothing is worth saying.
        """
        if not opinions:
            return ""

        parts: list[str] = []
        for op in opinions[: max(0, max_opinions)]:
            if op.sentiment > 0.5:
                statement = f"You really enjoy discussing {op.topic}"
            elif op.sentiment > 0.2:
                statement = f"You find {op.topic} interesting"
            elif op.sentiment < -0.3:
                statement = f"You're not particularly enthusiastic about {op.topic}"
            else:
                # Neutral: there is no statement to attach reasoning to.
                continue

            if op.reasoning:
                statement = f"{statement} ({op.reasoning})"
            parts.append(statement)

        return "; ".join(parts)

    async def get_all_opinions(self, guild_id: int | None = None) -> list[BotOpinion]:
        """Get all opinions for a guild, most-discussed first."""
        stmt = (
            select(BotOpinion)
            .where(BotOpinion.guild_id == guild_id)
            .order_by(BotOpinion.discussion_count.desc())
        )
        result = await self._session.execute(stmt)
        return list(result.scalars().all())
|
|
|
|
|
|
# Keyword lists per topic. Matching is case-insensitive and whole-word (see
# ``_TOPIC_PATTERNS``), so short keywords such as "ai" or "art" no longer fire
# inside unrelated words like "said" or "start".
_TOPIC_KEYWORDS: dict[str, tuple[str, ...]] = {
    # Hobbies
    "gaming": (
        "game",
        "gaming",
        "video game",
        "play",
        "xbox",
        "playstation",
        "nintendo",
        "steam",
    ),
    "music": (
        "music",
        "song",
        "band",
        "album",
        "concert",
        "listen",
        "spotify",
        "guitar",
        "piano",
    ),
    "movies": ("movie", "film", "cinema", "watch", "netflix", "show", "series", "tv"),
    "reading": ("book", "read", "novel", "author", "library", "kindle"),
    "sports": (
        "sports",
        "football",
        "soccer",
        "basketball",
        "tennis",
        "golf",
        "gym",
        "workout",
    ),
    "cooking": ("cook", "recipe", "food", "restaurant", "meal", "kitchen", "baking"),
    "travel": ("travel", "trip", "vacation", "flight", "hotel", "country", "visit"),
    "art": ("art", "painting", "drawing", "museum", "gallery", "creative"),
    # Tech
    "programming": (
        "code",
        "programming",
        "developer",
        "software",
        "python",
        "javascript",
        "api",
    ),
    # "technology" is listed explicitly because whole-word matching no longer
    # lets the keyword "tech" match inside it.
    "technology": ("tech", "technology", "computer", "phone", "app", "website", "internet"),
    "ai": ("ai", "artificial intelligence", "machine learning", "chatgpt", "gpt"),
    # Life
    "work": ("work", "job", "office", "career", "boss", "colleague", "meeting"),
    "family": ("family", "parents", "mom", "dad", "brother", "sister", "kids"),
    "pets": ("pet", "dog", "cat", "puppy", "kitten", "animal"),
    "health": ("health", "doctor", "exercise", "diet", "sleep", "medical"),
    # Interests
    "philosophy": ("philosophy", "meaning", "life", "existence", "think", "believe"),
    "science": ("science", "research", "study", "experiment", "discovery"),
    "nature": ("nature", "outdoor", "hiking", "camping", "mountain", "beach", "forest"),
}

# Optional inflection allowed after a keyword: "games", "watches", "played",
# "player(s)", "playing"; and, only after a keyword ending in "e", a bare
# "d"/"r"/"rs" ("coded", "gamer", "gamers"). The lookbehind keeps "ai" from
# matching "aid" or "air".
_KEYWORD_SUFFIX = r"(?:s|es|ed|ers?|ing|(?<=e)(?:d|rs?))?"


def _compile_topic_pattern(keywords: tuple[str, ...]) -> re.Pattern[str]:
    """Compile one regex matching any of *keywords* as a whole (inflected) word."""
    alternation = "|".join(re.escape(keyword) for keyword in keywords)
    return re.compile(rf"\b(?:{alternation}){_KEYWORD_SUFFIX}\b")


# Compiled once at import time; insertion order matches _TOPIC_KEYWORDS.
_TOPIC_PATTERNS: dict[str, re.Pattern[str]] = {
    topic: _compile_topic_pattern(keywords) for topic, keywords in _TOPIC_KEYWORDS.items()
}


def extract_topics_from_message(message: str) -> list[str]:
    """Extract potential topics from a message.

    This is a simple keyword-based extraction. In production,
    you might want to use NLP or an LLM for better extraction.

    Keywords are matched case-insensitively as whole words, optionally
    followed by a simple English inflection ("games", "playing", "cooked",
    "gamer"). As a result, short keywords do not match inside unrelated words
    (e.g. "ai" in "said", "art" in "start", "play" in "display").

    Args:
        message: Free-form message text; an empty message yields no topics.

    Returns:
        Matching topic names, each at most once, in the fixed order in which
        the topics are declared in ``_TOPIC_KEYWORDS``.
    """
    if not message:
        return []

    text = message.lower()
    return [topic for topic, pattern in _TOPIC_PATTERNS.items() if pattern.search(text)]
|