AI provider implementation for OpenAI and Claude.
backend/app/services/__init__.py (new file, 0 lines)
backend/app/services/provider_base.py (new file, 27 lines)
@@ -0,0 +1,27 @@
from abc import ABC, abstractmethod
from typing import AsyncGenerator


class AIProvider(ABC):
    """Abstract base class for AI providers"""

    def __init__(self, api_key: str, model: str):
        self.api_key = api_key
        self.model = model

    @abstractmethod
    async def chat(self, message: str, system_prompt: str = None) -> str:
        """Non-streaming chat"""
        pass

    @abstractmethod
    async def chat_stream(
        self, message: str, system_prompt: str = None
    ) -> AsyncGenerator[str, None]:
        """Streaming chat - yields chunks of text"""
        pass

    @abstractmethod
    def get_provider_name(self) -> str:
        """Return provider identifier"""
        pass
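The streaming method above yields plain text chunks rather than returning a single string. A minimal consumer sketch (not part of this commit; `provider` stands for any concrete AIProvider instance you have already constructed):

    import asyncio


    async def print_stream(provider):
        # Chunks arrive as plain strings; concatenating them reproduces the full reply.
        async for chunk in provider.chat_stream("Hello", system_prompt="Be brief"):
            print(chunk, end="", flush=True)

    # asyncio.run(print_stream(provider))  # run from a synchronous entry point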
backend/app/services/provider_claude.py (new file, 41 lines)
@@ -0,0 +1,41 @@
from typing import AsyncGenerator

import anthropic

from .provider_base import AIProvider


class ClaudeProvider(AIProvider):
    def __init__(self, api_key: str, model: str):
        super().__init__(api_key, model)
        self.client = anthropic.AsyncAnthropic(api_key=api_key)

    async def chat(self, message: str, system_prompt: str = None) -> str:
        """Non-streaming chat"""
        messages = [{"role": "user", "content": message}]

        kwargs = {"model": self.model, "max_tokens": 4000, "messages": messages}

        if system_prompt:
            kwargs["system"] = system_prompt

        response = await self.client.messages.create(**kwargs)
        return response.content[0].text

    async def chat_stream(
        self, message: str, system_prompt: str = None
    ) -> AsyncGenerator[str, None]:
        """Streaming chat"""
        messages = [{"role": "user", "content": message}]

        kwargs = {"model": self.model, "max_tokens": 4000, "messages": messages}

        if system_prompt:
            kwargs["system"] = system_prompt

        async with self.client.messages.stream(**kwargs) as stream:
            async for text in stream.text_stream:
                yield text

    def get_provider_name(self) -> str:
        return "claude"
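For reference, a standalone call to this provider could look roughly like the sketch below (not part of the commit; the import path assumes backend/ is the import root, and the key and model strings are placeholders):

    import asyncio

    from app.services.provider_claude import ClaudeProvider


    async def main():
        # Placeholder credentials; real values come from configuration.
        provider = ClaudeProvider(api_key="sk-ant-...", model="claude-model-id")
        reply = await provider.chat("Summarize this repo", system_prompt="Be concise")
        print(reply)

    asyncio.run(main())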
backend/app/services/provider_manager.py (new file, 75 lines)
@@ -0,0 +1,75 @@
from typing import Optional

from ..config import settings
from .provider_base import AIProvider
from .provider_claude import ClaudeProvider
from .provider_openai import OpenAIProvider


class ProviderManager:
    """Manages provider selection and fallback logic"""

    def __init__(self):
        self.providers = {}
        self._initialize_providers()

    def _initialize_providers(self):
        """Initialize available providers based on API keys"""
        if settings.ANTHROPIC_API_KEY and settings.ANTHROPIC_API_KEY.strip():
            self.providers["claude"] = ClaudeProvider(
                api_key=settings.ANTHROPIC_API_KEY, model=settings.CLAUDE_MODEL
            )

        if settings.OPENAI_API_KEY and settings.OPENAI_API_KEY.strip():
            self.providers["openai"] = OpenAIProvider(
                api_key=settings.OPENAI_API_KEY, model=settings.OPENAI_MODEL
            )

    def get_provider(self, provider_name: Optional[str] = None) -> AIProvider:
        """
        Get a provider by name, or use default.
        Raises ValueError if provider not available.
        """
        name = provider_name or settings.DEFAULT_PROVIDER

        if name not in self.providers:
            raise ValueError(
                f"Provider '{name}' not available. "
                f"Available: {list(self.providers.keys())}"
            )

        return self.providers[name]

    def get_available_providers(self) -> list[str]:
        """Return list of available provider names"""
        return list(self.providers.keys())

    async def chat_with_fallback(
        self, message: str, preferred_provider: Optional[str] = None
    ) -> tuple[str, str]:
        """
        Try to chat with preferred provider, fallback to others if it fails.
        Returns (response, provider_used)
        """
        providers_to_try = [preferred_provider or settings.DEFAULT_PROVIDER] + [
            p
            for p in self.providers.keys()
            if p != (preferred_provider or settings.DEFAULT_PROVIDER)
        ]

        last_error = None

        for provider_name in providers_to_try:
            try:
                provider = self.get_provider(provider_name)
                response = await provider.chat(message)
                return response, provider_name
            except Exception as e:
                last_error = e
                continue

        raise Exception(f"All providers failed. Last error: {last_error}")


# Singleton instance
provider_manager = ProviderManager()
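A rough usage sketch for the singleton above, as it might be called from an async request handler. Only `provider_manager` and its methods come from this commit; the surrounding function is hypothetical:

    from app.services.provider_manager import provider_manager


    async def handle_chat(message: str, provider: Optional[str] = None) -> dict:
        # Tries the preferred (or default) provider first, then falls back to
        # any other configured provider if that call raises.
        response, used = await provider_manager.chat_with_fallback(
            message, preferred_provider=provider
        )
        return {"response": response, "provider": used}

`get_available_providers()` can be exposed the same way, e.g. to let a client list which backends are configured.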
backend/app/services/provider_openai.py (new file, 48 lines)
@@ -0,0 +1,48 @@
from typing import AsyncGenerator

from openai import AsyncOpenAI

from .provider_base import AIProvider


class OpenAIProvider(AIProvider):
    def __init__(self, api_key: str, model: str):
        super().__init__(api_key, model)
        self.client = AsyncOpenAI(api_key=api_key)

    async def chat(self, message: str, system_prompt: str = None) -> str:
        """Non-streaming chat"""
        messages = []

        if system_prompt:
            messages.append({"role": "system", "content": system_prompt})

        messages.append({"role": "user", "content": message})

        response = await self.client.chat.completions.create(
            model=self.model, messages=messages, max_tokens=4000
        )

        return response.choices[0].message.content

    async def chat_stream(
        self, message: str, system_prompt: str = None
    ) -> AsyncGenerator[str, None]:
        """Streaming chat"""
        messages = []

        if system_prompt:
            messages.append({"role": "system", "content": system_prompt})

        messages.append({"role": "user", "content": message})

        stream = await self.client.chat.completions.create(
            model=self.model, messages=messages, max_tokens=4000, stream=True
        )

        async for chunk in stream:
            if chunk.choices[0].delta.content:
                yield chunk.choices[0].delta.content

    def get_provider_name(self) -> str:
        return "openai"
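The manager imports `settings` from `..config`, which this diff does not include. The fields it reads are ANTHROPIC_API_KEY, OPENAI_API_KEY, CLAUDE_MODEL, OPENAI_MODEL and DEFAULT_PROVIDER. One possible shape for that module is a pydantic-settings class; this is purely an assumption about backend/app/config.py, and the model defaults below are placeholders:

    # backend/app/config.py (assumed shape, not part of this commit)
    from pydantic_settings import BaseSettings


    class Settings(BaseSettings):
        ANTHROPIC_API_KEY: str = ""
        OPENAI_API_KEY: str = ""
        CLAUDE_MODEL: str = "claude-3-5-sonnet-latest"  # placeholder default
        OPENAI_MODEL: str = "gpt-4o"  # placeholder default
        DEFAULT_PROVIDER: str = "claude"


    settings = Settings()

With this shape, providers whose API key is unset or blank are simply skipped during `_initialize_providers()`.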