76 lines
2.5 KiB
Python
76 lines
2.5 KiB
Python
from typing import Optional
|
|
|
|
from ..config import settings
|
|
from .provider_base import AIProvider
|
|
from .provider_claude import ClaudeProvider
|
|
from .provider_openai import OpenAIProvider
|
|
|
|
|
|
class ProviderManager:
    """Manage the configured AI providers and fallback between them.

    Providers are registered once, at construction time, from ``settings``.
    A provider is only registered when its API key is set to a non-blank
    value.
    """

    def __init__(self):
        # Maps provider name ("claude" / "openai") to its provider instance,
        # in registration order.
        self.providers = {}
        self._initialize_providers()

    def _initialize_providers(self):
        """Register every provider whose API key is configured.

        A key that is missing, empty, or whitespace-only is treated as unset.
        The key handed to the provider is the stripped value. The blank-check
        already relies on stripping, and stray surrounding whitespace
        (for instance a trailing newline) is not part of a valid key.
        """
        anthropic_key = (settings.ANTHROPIC_API_KEY or "").strip()
        if anthropic_key:
            self.providers["claude"] = ClaudeProvider(
                api_key=anthropic_key, model=settings.CLAUDE_MODEL
            )

        openai_key = (settings.OPENAI_API_KEY or "").strip()
        if openai_key:
            self.providers["openai"] = OpenAIProvider(
                api_key=openai_key, model=settings.OPENAI_MODEL
            )

    def get_provider(self, provider_name: Optional[str] = None) -> "AIProvider":
        """Return the provider registered under ``provider_name``.

        Args:
            provider_name: Name of the provider to look up. When falsy,
                ``settings.DEFAULT_PROVIDER`` is used instead.

        Returns:
            The registered provider instance.

        Raises:
            ValueError: If no provider is registered under the resolved name.
        """
        name = provider_name or settings.DEFAULT_PROVIDER

        if name not in self.providers:
            raise ValueError(
                f"Provider '{name}' not available. "
                f"Available: {list(self.providers.keys())}"
            )

        return self.providers[name]

    def get_available_providers(self) -> list[str]:
        """Return the names of all registered providers, in registration order."""
        return list(self.providers)

    async def chat_with_fallback(
        self, message: str, preferred_provider: Optional[str] = None
    ) -> tuple[str, str]:
        """Send ``message`` to the preferred provider, falling back on failure.

        The preferred provider is tried first; ``settings.DEFAULT_PROVIDER``
        is used when none is given. Every other registered provider is then
        tried in registration order. The first successful reply is returned.

        Args:
            message: Text passed to each provider's ``chat`` coroutine.
            preferred_provider: Name of the provider to try first.

        Returns:
            A ``(response, provider_used)`` tuple.

        Raises:
            Exception: If every attempted provider failed. The last failure
                is chained as ``__cause__`` so its traceback is preserved.
        """
        first = preferred_provider or settings.DEFAULT_PROVIDER
        # The preferred name is attempted even if it is not registered:
        # get_provider then raises ValueError, which counts as a failure.
        providers_to_try = [first] + [p for p in self.providers if p != first]

        last_error = None
        for provider_name in providers_to_try:
            try:
                provider = self.get_provider(provider_name)
                response = await provider.chat(message)
            except Exception as e:
                # Deliberately broad: any provider failure triggers fallback.
                last_error = e
                continue
            return response, provider_name

        raise Exception(
            f"All providers failed. Last error: {last_error}"
        ) from last_error
|
|
|
|
|
|
# Shared module-level ProviderManager, created once when this module is
# first imported (acts as the module's singleton).
provider_manager: ProviderManager = ProviderManager()
|