Files
loyal_companion/src/daemon_boyfriend/services/proactive_service.py
2026-01-12 20:30:59 +01:00

456 lines
16 KiB
Python

"""Proactive Service - manages scheduled events and proactive behavior."""
import calendar
import json
import logging
import re
from datetime import datetime, timedelta, timezone

from sqlalchemy import select
from sqlalchemy.ext.asyncio import AsyncSession

from daemon_boyfriend.models import ScheduledEvent, User
from .providers import Message
# Module-level logger, named after this module via ``__name__``; ProactiveService
# uses it for info/warning messages about scheduling and AI failures.
logger = logging.getLogger(__name__)
class ProactiveService:
    """Manages scheduled events and proactive behavior.

    Responsibilities:

    * detect upcoming events (interviews, exams, trips, ...) in a user's
      message and schedule a ``"follow_up"`` ScheduledEvent for one day after
      the event;
    * detect birthday mentions and schedule a yearly ``"birthday"`` event;
    * list due events, render their outgoing text, and mark them triggered,
      re-scheduling recurring ones.

    Every ``trigger_at`` produced here is a timezone-aware UTC datetime. The
    service only ``add``s / ``flush``es on the injected session; committing is
    left to the caller.
    """

    # Follow-ups further out than this are not scheduled at all. Besides being
    # a sanity bound on LLM estimates, it keeps ``now + timedelta(days=...)``
    # from raising OverflowError on inputs such as "in 99999999 days".
    _MAX_FOLLOWUP_DAYS = 365

    # System prompt for LLM-based follow-up detection (expects a JSON reply).
    _FOLLOWUP_DETECTION_PROMPT = """Analyze if this message mentions a future event worth following up on.
Events like: job interviews, exams, trips, appointments, projects due, important meetings, etc.
Return JSON: {"has_event": true/false, "event_type": "...", "days_until": <number or null>, "description": "..."}
Rules:
- Only return has_event=true for significant events the speaker would appreciate being asked about later
- days_until should be your best estimate of days until the event (1 for tomorrow, 7 for next week, etc.)
- Skip casual mentions like "I might do something" or past events
- description should be a brief summary of the event
Examples:
"I have a job interview tomorrow" -> {"has_event": true, "event_type": "job interview", "days_until": 1, "description": "job interview"}
"I went to the store" -> {"has_event": false}
"My exam is next week" -> {"has_event": true, "event_type": "exam", "days_until": 7, "description": "upcoming exam"}
"""

    # (pattern, event name, default days until the event) for the regex
    # fallback. Checked in order; the first match wins. Word boundaries stop
    # substring false positives such as "example" -> exam, "latest" -> test,
    # "removing" -> moving, "cooperation" -> operation.
    _EVENT_PATTERNS: tuple[tuple[re.Pattern[str], str, int], ...] = (
        (re.compile(r"\binterview(?:s|ing)?\b"), "job interview", 1),
        (re.compile(r"\b(?:exams?|tests?|quiz(?:zes)?)\b"), "exam", 1),
        (re.compile(r"\b(?:presentations?|presenting)\b"), "presentation", 1),
        (re.compile(r"\b(?:surgery|surgeries|operations?|medical)\b"), "medical procedure", 2),
        (re.compile(r"\b(?:moving|move\s+to|new\s+apartment|new\s+house)\b"), "moving", 7),
        (re.compile(r"\b(?:weddings?|getting\s+married)\b"), "wedding", 1),
        (re.compile(r"\b(?:vacations?|holidays?|trip\s+to)\b"), "trip", 7),
        (re.compile(r"\b(?:deadlines?|due\s+date|project\s+due)\b"), "deadline", 1),
        (re.compile(r"\b(?:starting\b.*\bjobs?|new\s+job|first\s+day)\b"), "new job", 1),
        (re.compile(r"\b(?:graduation|graduating)\b"), "graduation", 1),
    )

    # (pattern, days until the event). ``None`` means "read the day count from
    # the first capture group". Checked in order; the first match wins.
    _TIME_PATTERNS: tuple[tuple[re.Pattern[str], int | None], ...] = (
        (re.compile(r"\btomorrow\b"), 1),
        (re.compile(r"\bnext\s+week\b"), 7),
        (re.compile(r"\bthis\s+week\b"), 3),
        (re.compile(r"\bin\s+(\d+)\s+days?\b"), None),
        (re.compile(r"\bnext\s+month\b"), 30),
        (re.compile(r"\bthis\s+weekend\b"), 3),
    )

    # A message must contain one of these before we try to parse a date.
    _BIRTHDAY_INDICATOR_RE = re.compile(r"\bmy\s+(?:birthday|bday)|\bborn\s+on\b")

    _MONTHS: dict[str, int] = {
        "january": 1,
        "february": 2,
        "march": 3,
        "april": 4,
        "may": 5,
        "june": 6,
        "july": 7,
        "august": 8,
        "september": 9,
        "october": 10,
        "november": 11,
        "december": 12,
        "jan": 1,
        "feb": 2,
        "mar": 3,
        "apr": 4,
        "jun": 6,
        "jul": 7,
        "aug": 8,
        "sep": 9,
        "sept": 9,
        "oct": 10,
        "nov": 11,
        "dec": 12,
    }
    # Longest names first so "september" wins over "sept"/"sep" at one position.
    _MONTH_ALTERNATION = "|".join(sorted(_MONTHS, key=len, reverse=True))
    # "march 15", "mar. 15th"; (?!\d) stops "march 150" being read as the 15th.
    _MONTH_DAY_RE = re.compile(
        rf"\b({_MONTH_ALTERNATION})\.?\s+(\d{{1,2}})(?:st|nd|rd|th)?(?!\d)"
    )
    # "15 march", "15th of march".
    _DAY_MONTH_RE = re.compile(
        rf"(?<!\d)(\d{{1,2}})(?:st|nd|rd|th)?\s+(?:of\s+)?({_MONTH_ALTERNATION})\b"
    )
    # "03/15", "15-03". The digit lookarounds stop "2024-03-15" from being read
    # as "24-03"; for ISO dates they leave the month-day tail ("03-15").
    _NUMERIC_DATE_RE = re.compile(r"(?<!\d)(\d{1,2})[/\-](\d{1,2})(?!\d)")

    def __init__(self, session: AsyncSession, ai_service=None) -> None:
        """Create the service.

        Args:
            session: Async SQLAlchemy session used for all reads and writes.
            ai_service: Optional chat provider exposing
                ``await chat(messages=..., system_prompt=...)`` that returns an
                object with a ``.content`` attribute. When ``None``, regex
                heuristics and canned messages are used instead.
        """
        self._session = session
        self._ai_service = ai_service

    # ------------------------------------------------------------------
    # Follow-up detection
    # ------------------------------------------------------------------

    async def detect_and_schedule_followup(
        self,
        user: User,
        message_content: str,
        guild_id: int | None,
        channel_id: int,
    ) -> ScheduledEvent | None:
        """Detect if a message mentions a future event worth following up on.

        With an AI service, the model classifies the message (JSON reply).
        Without one, regex heuristics are used. Any failure on the AI path is
        logged and treated as "no event" (best effort).

        Args:
            user: The user who sent the message
            message_content: The message content
            guild_id: Guild ID
            channel_id: Channel ID for the follow-up

        Returns:
            Scheduled event if one was created, None otherwise
        """
        if not self._ai_service:
            # Use simple pattern matching as fallback
            return await self._detect_followup_simple(user, message_content, guild_id, channel_id)

        try:
            response = await self._ai_service.chat(
                messages=[Message(role="user", content=message_content)],
                system_prompt=self._FOLLOWUP_DETECTION_PROMPT,
            )
            result = self._parse_json_response(response.content)
            if not result or not self._is_truthy_flag(result.get("has_event")):
                return None

            days_until = self._normalize_days(result.get("days_until"))
            if days_until is None:
                logger.debug(
                    "Skipping follow-up for user %s: unusable days_until %r",
                    user.id,
                    result.get("days_until"),
                )
                return None

            # `or` (not a .get default) so an explicit null from the model also
            # falls back instead of producing "Follow up: None".
            event_type = result.get("event_type") or "event"
            event = await self._create_followup(
                user,
                guild_id,
                channel_id,
                label=event_type,
                topic=result.get("description") or "their event",
                message_content=message_content,
                days_until=days_until,
            )
            logger.info(
                "Scheduled follow-up for user %s: %s in %d days",
                user.id,
                event_type,
                days_until + 1,
            )
            return event
        except Exception as e:
            # Deliberate best effort: detection must never break message handling.
            logger.warning("Follow-up detection failed: %s", e, exc_info=True)
            return None

    async def _detect_followup_simple(
        self,
        user: User,
        message_content: str,
        guild_id: int | None,
        channel_id: int,
    ) -> ScheduledEvent | None:
        """Simple pattern-based follow-up detection.

        The first matching entry of ``_EVENT_PATTERNS`` picks the event and its
        default distance. The first matching ``_TIME_PATTERNS`` entry, if any,
        refines that distance. Events beyond ``_MAX_FOLLOWUP_DAYS`` are skipped.
        """
        text = message_content.lower()

        detected = next(
            ((name, default_days) for pattern, name, default_days in self._EVENT_PATTERNS
             if pattern.search(text)),
            None,
        )
        if detected is None:
            return None
        event_name, days_until = detected

        # Refine timing based on time indicators
        for pattern, days in self._TIME_PATTERNS:
            found = pattern.search(text)
            if found:
                days_until = int(found.group(1)) if days is None else days
                break

        if days_until > self._MAX_FOLLOWUP_DAYS:
            logger.debug(
                "Skipping follow-up (simple) for user %s: %d days is too far out",
                user.id,
                days_until,
            )
            return None

        event = await self._create_followup(
            user,
            guild_id,
            channel_id,
            label=event_name,
            topic=event_name,
            message_content=message_content,
            days_until=days_until,
        )
        logger.info(
            "Scheduled follow-up (simple) for user %s: %s in %d days",
            user.id,
            event_name,
            days_until + 1,
        )
        return event

    async def _create_followup(
        self,
        user: User,
        guild_id: int | None,
        channel_id: int,
        *,
        label: str,
        topic: str,
        message_content: str,
        days_until: int,
    ) -> ScheduledEvent:
        """Add and flush a ``follow_up`` event due one day after the detected event."""
        event = ScheduledEvent(
            user_id=user.id,
            guild_id=guild_id,
            channel_id=channel_id,
            event_type="follow_up",
            # Ask one day *after* the event, once there is an outcome to ask about.
            trigger_at=datetime.now(timezone.utc) + timedelta(days=days_until + 1),
            title=f"Follow up: {label}",
            context={
                "original_topic": topic,
                "detected_from": message_content[:200],
            },
        )
        self._session.add(event)
        await self._session.flush()
        return event

    @staticmethod
    def _is_truthy_flag(value: object) -> bool:
        """Interpret a JSON flag, treating strings like ``"false"`` as False."""
        if isinstance(value, str):
            return value.strip().lower() in {"true", "yes"}
        return bool(value)

    @classmethod
    def _normalize_days(cls, value: object, default: int = 1) -> int | None:
        """Coerce an LLM-supplied ``days_until`` into a usable day count.

        Returns ``default`` when the value is missing or not numeric (the model
        did not commit to a date). Returns the integer when it lies within
        ``[0, _MAX_FOLLOWUP_DAYS]``; 0 means "today". Returns ``None`` when the
        value is negative (a past event), too far out, or non-finite.
        """
        if value is None or isinstance(value, bool):
            return default
        try:
            days = int(value)
        except (TypeError, ValueError):
            return default
        except OverflowError:  # e.g. Infinity, which json.loads accepts
            return None
        if 0 <= days <= cls._MAX_FOLLOWUP_DAYS:
            return days
        return None

    # ------------------------------------------------------------------
    # Birthdays
    # ------------------------------------------------------------------

    async def detect_and_schedule_birthday(
        self,
        user: User,
        message_content: str,
        guild_id: int | None,
        channel_id: int,
    ) -> ScheduledEvent | None:
        """Detect birthday mentions and schedule (or reschedule) a yearly wish.

        If a pending birthday already exists for this user/guild, its
        ``trigger_at`` and ``context`` are updated in place (no flush) and it is
        returned. Otherwise a new recurring event is added and flushed.
        """
        birthday = self._extract_birthday(message_content)
        if birthday is None:
            return None

        trigger_at = self._next_birthday(birthday)
        context = {"birthday_date": birthday.isoformat()}

        # Check if we already have a birthday scheduled for this user
        existing = await self._get_existing_birthday(user.id, guild_id)
        if existing:
            existing.trigger_at = trigger_at
            existing.context = context
            logger.info("Updated birthday for user %s: %s", user.id, birthday)
            return existing

        event = ScheduledEvent(
            user_id=user.id,
            guild_id=guild_id,
            channel_id=channel_id,
            event_type="birthday",
            trigger_at=trigger_at,
            title="Birthday wish",
            context=context,
            is_recurring=True,
            recurrence_rule="yearly",
        )
        self._session.add(event)
        await self._session.flush()
        logger.info("Scheduled birthday for user %s: %s", user.id, birthday)
        return event

    def _extract_birthday(self, message: str) -> datetime | None:
        """Extract a birthday date from a message.

        Only runs when the message contains a birthday phrase ("my birthday",
        "my bday", "born on"). Candidates are tried in this order:

        1. "Month Day";
        2. "Day Month";
        3. numeric "A/B" or "A-B", read as MM/DD unless A > 12, in which case
           it is read as DD/MM.

        Returns a naive datetime in the placeholder year 2000. Only the month
        and day matter. Year 2000 is a leap year, so Feb 29 is representable.
        """
        text = message.lower()
        if not self._BIRTHDAY_INDICATOR_RE.search(text):
            return None

        candidates: list[tuple[int, int]] = []  # (month, day)
        if found := self._MONTH_DAY_RE.search(text):
            candidates.append((self._MONTHS[found.group(1)], int(found.group(2))))
        if found := self._DAY_MONTH_RE.search(text):
            candidates.append((self._MONTHS[found.group(2)], int(found.group(1))))
        if found := self._NUMERIC_DATE_RE.search(text):
            first, second = int(found.group(1)), int(found.group(2))
            candidates.append((first, second) if first <= 12 else (second, first))

        for month, day in candidates:
            try:
                return datetime(2000, month, day)  # Year doesn't matter
            except ValueError:  # e.g. "feb 30", "13/13"
                continue
        return None

    def _next_birthday(self, birthday: datetime) -> datetime:
        """Return the next occurrence (today counts) of ``birthday`` as aware UTC.

        Feb 29 birthdays fall on Feb 28 in non-leap years instead of raising.
        A naive ``birthday`` is treated as UTC so the result matches the
        tz-aware ``trigger_at`` values used everywhere else in this service.
        """
        today = datetime.now(timezone.utc).date()
        anchor = birthday if birthday.tzinfo is not None else birthday.replace(tzinfo=timezone.utc)
        occurrence = self._replace_date_clamped(anchor, today.year, birthday.month, birthday.day)
        if occurrence.date() < today:
            occurrence = self._replace_date_clamped(
                anchor, today.year + 1, birthday.month, birthday.day
            )
        return occurrence

    async def _get_existing_birthday(
        self, user_id: int, guild_id: int | None
    ) -> ScheduledEvent | None:
        """Return the earliest pending birthday event for this user/guild, if any.

        Uses ``first()`` rather than ``scalar_one_or_none()`` so that duplicate
        pending rows (e.g. from concurrent messages) do not raise.
        """
        stmt = (
            select(ScheduledEvent)
            .where(
                ScheduledEvent.user_id == user_id,
                ScheduledEvent.guild_id == guild_id,  # SQLAlchemy renders `== None` as IS NULL
                ScheduledEvent.event_type == "birthday",
                ScheduledEvent.status == "pending",
            )
            .order_by(ScheduledEvent.trigger_at)
            .limit(1)
        )
        result = await self._session.execute(stmt)
        return result.scalars().first()

    # ------------------------------------------------------------------
    # Triggering
    # ------------------------------------------------------------------

    async def get_pending_events(self, before: datetime | None = None) -> list[ScheduledEvent]:
        """Get pending events with ``trigger_at <= before`` (default: now, UTC), oldest first."""
        cutoff = before or datetime.now(timezone.utc)
        stmt = (
            select(ScheduledEvent)
            .where(
                ScheduledEvent.status == "pending",
                ScheduledEvent.trigger_at <= cutoff,
            )
            .order_by(ScheduledEvent.trigger_at)
        )
        result = await self._session.execute(stmt)
        return list(result.scalars().all())

    async def generate_event_message(self, event: ScheduledEvent) -> str:
        """Generate the message for a triggered event, dispatching on ``event_type``."""
        if event.event_type == "birthday":
            return await self._generate_birthday_message(event)
        if event.event_type == "follow_up":
            return await self._generate_followup_message(event)
        return await self._generate_generic_message(event)

    async def _generate_birthday_message(self, event: ScheduledEvent) -> str:
        """Generate a birthday message (AI when available, canned text otherwise)."""
        if self._ai_service:
            try:
                response = await self._ai_service.chat(
                    messages=[Message(role="user", content="Generate a birthday message")],
                    system_prompt=(
                        "Generate a warm, personalized birthday wish. "
                        "Be genuine but not over the top. Keep it to 1-2 sentences. "
                        "Don't use too many emojis."
                    ),
                )
                content = (response.content or "").strip()
            except Exception:
                logger.warning("Birthday message generation failed; using fallback", exc_info=True)
            else:
                if content:
                    return content
        # Fallback
        return "Happy birthday! Hope you have an amazing day!"

    async def _generate_followup_message(self, event: ScheduledEvent) -> str:
        """Generate a follow-up question about the stored ``original_topic``."""
        context = event.context or {}
        topic = context.get("original_topic") or "that thing you mentioned"
        if self._ai_service:
            try:
                # NOTE(review): `topic` can derive from user text and is interpolated
                # into the system prompt (prompt-injection surface).
                response = await self._ai_service.chat(
                    messages=[Message(role="user", content=f"Follow up about: {topic}")],
                    system_prompt=(
                        f"Generate a natural follow-up question about '{topic}'. "
                        "Be casual and genuinely curious. Ask how it went. "
                        "Keep it to 1-2 sentences. No emojis."
                    ),
                )
                content = (response.content or "").strip()
            except Exception:
                logger.warning("Follow-up message generation failed; using fallback", exc_info=True)
            else:
                if content:
                    return content
        # Fallback
        return f"Hey! How did {topic} go?"

    async def _generate_generic_message(self, event: ScheduledEvent) -> str:
        """Generate a generic check-in message from the event title."""
        return f"Hey! Just wanted to check in - {event.title}"

    async def mark_event_triggered(self, event: ScheduledEvent) -> None:
        """Mark an event as triggered and, if recurring, add its next occurrence."""
        event.status = "triggered"
        event.triggered_at = datetime.now(timezone.utc)
        # Handle recurring events
        if event.is_recurring and event.recurrence_rule:
            await self._schedule_next_occurrence(event)

    async def _schedule_next_occurrence(self, event: ScheduledEvent) -> None:
        """Add (without flushing) the next occurrence of a recurring event.

        Supported rules are "yearly", "monthly" and "weekly". Unknown rules are
        ignored. Days that do not exist in the target month are clamped to the
        month's last day instead of raising ``ValueError``; examples are
        Jan 31 → Feb 28 and Feb 29 in a non-leap year.
        """
        current = event.trigger_at
        rule = event.recurrence_rule
        if rule == "yearly":
            # Re-anchor on the stored birthday (if any) so a Feb 29 birthday
            # returns to Feb 29 in leap years instead of drifting to Feb 28.
            month, day = self._anniversary_month_day(event)
            next_trigger = self._replace_date_clamped(current, current.year + 1, month, day)
        elif rule == "monthly":
            year = current.year + current.month // 12  # December rolls into next year
            month = current.month % 12 + 1
            next_trigger = self._replace_date_clamped(current, year, month, current.day)
        elif rule == "weekly":
            next_trigger = current + timedelta(weeks=1)
        else:
            return  # Unknown rule

        new_event = ScheduledEvent(
            user_id=event.user_id,
            guild_id=event.guild_id,
            channel_id=event.channel_id,
            event_type=event.event_type,
            trigger_at=next_trigger,
            title=event.title,
            # Copy so the two rows never share one mutable dict.
            context=dict(event.context) if event.context is not None else None,
            is_recurring=True,
            recurrence_rule=rule,
        )
        self._session.add(new_event)

    @staticmethod
    def _anniversary_month_day(event: ScheduledEvent) -> tuple[int, int]:
        """(month, day) a yearly event should recur on.

        Uses the ``birthday_date`` in ``context`` when it parses. Otherwise uses
        the month and day of ``trigger_at``.
        """
        stored = (event.context or {}).get("birthday_date")
        if isinstance(stored, str):
            try:
                original = datetime.fromisoformat(stored)
            except ValueError:
                pass
            else:
                return original.month, original.day
        return event.trigger_at.month, event.trigger_at.day

    @staticmethod
    def _replace_date_clamped(dt: datetime, year: int, month: int, day: int) -> datetime:
        """Return ``dt`` moved to year/month/day, clamping day to the month's length.

        Time of day and tzinfo are preserved (``datetime.replace``).
        """
        last_day = calendar.monthrange(year, month)[1]
        return dt.replace(year=year, month=month, day=min(day, last_day))

    async def cancel_event(self, event_id: int) -> bool:
        """Cancel a pending scheduled event; return True if it was cancelled."""
        stmt = select(ScheduledEvent).where(ScheduledEvent.id == event_id)
        result = await self._session.execute(stmt)
        event = result.scalar_one_or_none()
        if event is not None and event.status == "pending":
            event.status = "cancelled"
            return True
        return False

    # ------------------------------------------------------------------
    # Helpers
    # ------------------------------------------------------------------

    def _parse_json_response(self, response: str) -> dict | None:
        """Parse a JSON object from an AI response.

        Handles bare JSON, a fenced block (a "```json" fence is preferred over a
        plain "```" fence), unterminated fences, and an object embedded in prose.
        Returns ``None`` when the response is empty, unparsable, or not a JSON
        object.
        """
        if not isinstance(response, str):
            return None
        text = response.strip()

        marker = "```json" if "```json" in text else "```"
        fence = text.find(marker)
        if fence != -1:
            start = fence + len(marker)
            end = text.find("```", start)
            # An unterminated fence runs to the end. Slicing to -1 would drop the last char.
            text = (text[start:] if end == -1 else text[start:end]).strip()

        try:
            parsed = json.loads(text)
        except json.JSONDecodeError:
            # Last resort: the outermost {...} span, e.g. 'Sure! {"has_event": false}'.
            first, last = text.find("{"), text.rfind("}")
            if first == -1 or last <= first:
                return None
            try:
                parsed = json.loads(text[first : last + 1])
            except json.JSONDecodeError:
                return None
        return parsed if isinstance(parsed, dict) else None