42 lines
1.2 KiB
Python
42 lines
1.2 KiB
Python
from typing import AsyncGenerator, Optional

import anthropic

from .provider_base import AIProvider
|
|
|
|
|
|
class ClaudeProvider(AIProvider):
    """AI provider backed by Anthropic's Claude models via the async SDK."""

    # Default cap on generated tokens; callers may override per request.
    DEFAULT_MAX_TOKENS = 4000

    def __init__(self, api_key: str, model: str):
        """Initialize the provider and create an async Anthropic client.

        Args:
            api_key: Anthropic API key used to authenticate the client.
            model: Model identifier passed on every request (e.g. a Claude
                model name — exact accepted values come from the Anthropic API).
        """
        super().__init__(api_key, model)
        self.client = anthropic.AsyncAnthropic(api_key=api_key)

    def _build_request(
        self, message: str, system_prompt: Optional[str], max_tokens: int
    ) -> dict:
        """Assemble the keyword arguments shared by chat() and chat_stream().

        The system prompt is included only when truthy, preserving the
        original behavior (an empty string is treated the same as None).
        """
        kwargs = {
            "model": self.model,
            "max_tokens": max_tokens,
            "messages": [{"role": "user", "content": message}],
        }
        if system_prompt:
            kwargs["system"] = system_prompt
        return kwargs

    async def chat(
        self,
        message: str,
        system_prompt: Optional[str] = None,
        max_tokens: int = DEFAULT_MAX_TOKENS,
    ) -> str:
        """Send one user message and return the complete response text.

        Args:
            message: User message content for a single-turn conversation.
            system_prompt: Optional system instructions; skipped when falsy.
            max_tokens: Upper bound on tokens in the generated response.

        Returns:
            The text of the first content block of the API response.
        """
        kwargs = self._build_request(message, system_prompt, max_tokens)
        response = await self.client.messages.create(**kwargs)
        return response.content[0].text

    async def chat_stream(
        self,
        message: str,
        system_prompt: Optional[str] = None,
        max_tokens: int = DEFAULT_MAX_TOKENS,
    ) -> AsyncGenerator[str, None]:
        """Send one user message and yield response text incrementally.

        Args:
            message: User message content for a single-turn conversation.
            system_prompt: Optional system instructions; skipped when falsy.
            max_tokens: Upper bound on tokens in the generated response.

        Yields:
            Text deltas as they arrive from the streaming API.
        """
        kwargs = self._build_request(message, system_prompt, max_tokens)
        # The context manager guarantees the HTTP stream is closed even if
        # the consumer abandons this generator early.
        async with self.client.messages.stream(**kwargs) as stream:
            async for text in stream.text_stream:
                yield text

    def get_provider_name(self) -> str:
        """Return the canonical short name identifying this provider."""
        return "claude"
|