"""Fact Extraction Service - autonomous extraction of facts from conversations."""
|
|
|
|
import json
|
|
import logging
|
|
import random
|
|
from datetime import datetime, timezone
|
|
|
|
from sqlalchemy import select
|
|
from sqlalchemy.ext.asyncio import AsyncSession
|
|
|
|
from daemon_boyfriend.config import settings
|
|
from daemon_boyfriend.models import User, UserFact
|
|
|
|
from .providers import Message
|
|
|
|
logger = logging.getLogger(__name__)
|
|
|
|
|
|
class FactExtractionService:
    """Autonomous extraction of facts from conversations.

    Uses the injected AI chat service to pull concrete, durable facts
    (hobbies, work, family, preferences, ...) out of user messages and
    persists them as ``UserFact`` rows, deduplicating against facts
    already on record for the user.
    """

    # Minimum message length (characters) to consider for extraction.
    MIN_MESSAGE_LENGTH = 20

    # Maximum facts to extract per message.
    MAX_FACTS_PER_MESSAGE = 3

    # Greetings/acknowledgements that carry no extractable information.
    _SHORT_PHRASES = frozenset({
        "hi", "hello", "hey", "yo", "sup", "bye", "goodbye",
        "thanks", "thank you", "ok", "okay", "yes", "no", "yeah",
        "nah", "lol", "lmao", "haha", "hehe", "nice", "cool", "wow",
    })

    # Fact categories the extraction model is allowed to emit
    # (must match the types listed in the extraction prompt).
    _VALID_FACT_TYPES = frozenset({
        "hobby", "work", "family", "preference",
        "location", "event", "relationship", "general",
    })

    def __init__(self, session: AsyncSession, ai_service=None) -> None:
        self._session = session
        self._ai_service = ai_service

    async def maybe_extract_facts(
        self,
        user: User,
        message_content: str,
        discord_message_id: int | None = None,
    ) -> list[UserFact]:
        """Maybe extract facts from a message based on rate limiting.

        Args:
            user: The user who sent the message
            message_content: The message content
            discord_message_id: Optional Discord message ID for reference

        Returns:
            List of newly extracted facts (may be empty)
        """
        if not settings.fact_extraction_enabled:
            return []

        # Rate limit: only extract from a percentage of messages.
        if random.random() > settings.fact_extraction_rate:
            return []

        return await self.extract_facts(user, message_content, discord_message_id)

    async def extract_facts(
        self,
        user: User,
        message_content: str,
        discord_message_id: int | None = None,
    ) -> list[UserFact]:
        """Extract facts from a message.

        Args:
            user: The user who sent the message
            message_content: The message content
            discord_message_id: Optional Discord message ID for reference

        Returns:
            List of newly extracted facts
        """
        # Skip messages that are too short or likely not informative.
        if not self._is_extractable(message_content):
            return []

        if not self._ai_service:
            logger.warning("No AI service available for fact extraction")
            return []

        try:
            # Get existing facts to avoid duplicates.
            existing_facts = await self._get_user_facts(user)
            existing_summary = self._summarize_existing_facts(existing_facts)

            # Build extraction prompt.
            extraction_prompt = self._build_extraction_prompt(existing_summary)

            # Use AI to extract facts.
            response = await self._ai_service.chat(
                messages=[Message(role="user", content=message_content)],
                system_prompt=extraction_prompt,
            )

            # Parse extracted facts.
            facts_data = self._parse_extraction_response(response.content)

            if not facts_data:
                return []

            # Deduplicate and save new facts.
            new_facts = await self._save_new_facts(
                user=user,
                facts_data=facts_data,
                existing_facts=existing_facts,
                discord_message_id=discord_message_id,
                extraction_context=message_content[:200],
            )

            if new_facts:
                logger.info(
                    "Extracted %d facts for user %s", len(new_facts), user.discord_id
                )

            return new_facts

        except Exception as e:
            # Best-effort feature: extraction failures must never break
            # message handling, so log and return nothing.
            logger.warning("Fact extraction failed: %s", e)
            return []

    def _is_extractable(self, content: str) -> bool:
        """Check if a message is worth extracting facts from."""
        # Too short.
        if len(content) < self.MIN_MESSAGE_LENGTH:
            return False

        # Just emoji or symbols.
        alpha_ratio = sum(c.isalpha() for c in content) / max(len(content), 1)
        if alpha_ratio < 0.5:
            return False

        # Looks like a bot command.
        if content.startswith(("!", "/", "?", ".")):
            return False

        # Just a greeting or very short phrase.
        if content.lower().strip() in self._SHORT_PHRASES:
            return False

        return True

    def _build_extraction_prompt(self, existing_summary: str) -> str:
        """Build the extraction prompt for the AI."""
        return f"""You are a fact extraction assistant. Extract factual information about the user from their message.

ALREADY KNOWN FACTS:
{existing_summary if existing_summary else "(None yet)"}

RULES:
1. Only extract CONCRETE facts, not opinions or transient states
2. Skip if the fact is already known (listed above)
3. Skip greetings, questions, or meta-conversation
4. Skip vague statements like "I like stuff" - be specific
5. Focus on: hobbies, work, family, preferences, locations, events, relationships
6. Keep fact content concise (under 100 characters)
7. Maximum {self.MAX_FACTS_PER_MESSAGE} facts per message

OUTPUT FORMAT:
Return a JSON array of facts, or empty array [] if no extractable facts.
Each fact should have:
- "type": one of "hobby", "work", "family", "preference", "location", "event", "relationship", "general"
- "content": the fact itself (concise, third person, e.g., "loves hiking")
- "confidence": 0.6 (implied), 0.8 (stated), 1.0 (explicit)
- "importance": 0.3 (trivial), 0.5 (normal), 0.8 (significant), 1.0 (very important)
- "temporal": "past", "present", "future", or "timeless"

EXAMPLE INPUT: "I just got promoted to senior engineer at Google last week!"
EXAMPLE OUTPUT: [{{"type": "work", "content": "works as senior engineer at Google", "confidence": 1.0, "importance": 0.8, "temporal": "present"}}, {{"type": "event", "content": "recently got promoted", "confidence": 1.0, "importance": 0.7, "temporal": "past"}}]

EXAMPLE INPUT: "hey what's up"
EXAMPLE OUTPUT: []

Return ONLY the JSON array, no other text."""

    def _parse_extraction_response(self, response: str) -> list[dict]:
        """Parse the AI response into validated fact dictionaries.

        Returns an empty list when the response is not a JSON array of
        valid facts; at most ``MAX_FACTS_PER_MESSAGE`` entries are kept.
        """
        try:
            response = response.strip()

            # Handle markdown code blocks. If the model forgot the closing
            # fence, find() returns -1 and slicing to [-1] would silently
            # drop the last character -- take the rest of the string instead.
            if "```json" in response:
                start = response.find("```json") + 7
                end = response.find("```", start)
                response = (
                    response[start:end] if end != -1 else response[start:]
                ).strip()
            elif "```" in response:
                start = response.find("```") + 3
                end = response.find("```", start)
                response = (
                    response[start:end] if end != -1 else response[start:]
                ).strip()

            # Parse JSON.
            facts = json.loads(response)

            if not isinstance(facts, list):
                return []

            # Validate each fact, capped at the per-message maximum.
            return [
                fact
                for fact in facts[: self.MAX_FACTS_PER_MESSAGE]
                if self._validate_fact(fact)
            ]

        except json.JSONDecodeError:
            logger.debug(
                "Failed to parse fact extraction response: %s", response[:100]
            )
            return []

    def _validate_fact(self, fact: dict) -> bool:
        """Validate a single fact dictionary produced by the model."""
        # The model may emit malformed entries (strings, numbers, null);
        # reject anything that is not a dict outright.
        if not isinstance(fact, dict):
            return False

        # Check type is one of the allowed categories.
        if fact.get("type") not in self._VALID_FACT_TYPES:
            return False

        # Content must be a non-trivial string, not excessively long.
        content = fact.get("content")
        if not isinstance(content, str):
            return False
        if len(content) < 3 or len(content) > 200:
            return False

        return True

    async def _get_user_facts(self, user: User) -> list[UserFact]:
        """Get up to 50 most recently learned active facts for a user."""
        stmt = (
            select(UserFact)
            .where(UserFact.user_id == user.id, UserFact.is_active.is_(True))
            .order_by(UserFact.learned_at.desc())
            .limit(50)
        )
        result = await self._session.execute(stmt)
        return list(result.scalars().all())

    def _summarize_existing_facts(self, facts: list[UserFact]) -> str:
        """Summarize existing facts for the extraction prompt."""
        if not facts:
            return ""

        # Limit to the 20 most recent facts to keep the prompt small.
        return "\n".join(
            f"- [{fact.fact_type}] {fact.fact_content}" for fact in facts[:20]
        )

    async def _save_new_facts(
        self,
        user: User,
        facts_data: list[dict],
        existing_facts: list[UserFact],
        discord_message_id: int | None,
        extraction_context: str,
    ) -> list[UserFact]:
        """Save new facts, avoiding duplicates.

        Adds new ``UserFact`` rows to the session and flushes once at the
        end; the caller is responsible for committing.
        """
        # Build set of existing fact content for deduplication.
        existing_content = {f.fact_content.lower() for f in existing_facts}

        new_facts = []
        for fact_data in facts_data:
            content = fact_data["content"]

            # Skip if too similar to existing.
            if self._is_duplicate(content, existing_content):
                continue

            # Create new fact.
            fact = UserFact(
                user_id=user.id,
                fact_type=fact_data["type"],
                fact_content=content,
                confidence=fact_data.get("confidence", 0.8),
                source="auto_extraction",
                is_active=True,
                learned_at=datetime.now(timezone.utc),
                # New fields from Living AI
                category=fact_data["type"],
                importance=fact_data.get("importance", 0.5),
                temporal_relevance=fact_data.get("temporal", "timeless"),
                extracted_from_message_id=discord_message_id,
                extraction_context=extraction_context,
            )

            self._session.add(fact)
            new_facts.append(fact)
            # Also dedupe within this batch.
            existing_content.add(content.lower())

        if new_facts:
            await self._session.flush()

        return new_facts

    def _is_duplicate(self, new_content: str, existing_content: set[str]) -> bool:
        """Check if a fact is a duplicate of existing facts."""
        new_lower = new_content.lower()

        # Exact match.
        if new_lower in existing_content:
            return True

        # Check for high similarity (simple substring check).
        for existing in existing_content:
            # If one contains the other (with some buffer).
            if len(new_lower) > 10 and len(existing) > 10:
                if new_lower in existing or existing in new_lower:
                    return True

            # Simple word overlap check.
            new_words = set(new_lower.split())
            existing_words = set(existing.split())
            if len(new_words) > 2 and len(existing_words) > 2:
                overlap = len(new_words & existing_words)
                min_len = min(len(new_words), len(existing_words))
                if overlap / min_len > 0.7:  # 70% word overlap
                    return True

        return False
|