from typing import AsyncGenerator, Dict, List, Optional

from openai import AsyncOpenAI

from .provider_base import AIProvider


class OpenAIProvider(AIProvider):
    """AIProvider implementation backed by the OpenAI Chat Completions API."""

    # Default value sent as ``max_tokens`` on every completion request.
    DEFAULT_MAX_TOKENS = 4000

    def __init__(
        self, api_key: str, model: str, max_tokens: int = DEFAULT_MAX_TOKENS
    ):
        """Create the provider and its async OpenAI client.

        Args:
            api_key: OpenAI API key; forwarded to the base class and used to
                construct the ``AsyncOpenAI`` client.
            model: Model name sent with every request.
            max_tokens: Value passed as ``max_tokens`` on each request
                (defaults to ``DEFAULT_MAX_TOKENS``, the previous hard-coded 4000).
        """
        super().__init__(api_key, model)
        self.client = AsyncOpenAI(api_key=api_key)
        self.max_tokens = max_tokens

    @staticmethod
    def _build_messages(
        message: str, system_prompt: Optional[str]
    ) -> List[Dict[str, str]]:
        """Return the chat message list: an optional system message, then the user message.

        A falsy ``system_prompt`` (``None`` or ``""``) adds no system message.
        """
        messages: List[Dict[str, str]] = []
        if system_prompt:
            messages.append({"role": "system", "content": system_prompt})
        messages.append({"role": "user", "content": message})
        return messages

    async def chat(self, message: str, system_prompt: Optional[str] = None) -> str:
        """Send one user message and return the complete reply (non-streaming).

        Args:
            message: The user message.
            system_prompt: Optional system instruction placed before the message.

        Returns:
            The text of the first choice, or ``""`` if the API returned no text
            content.
        """
        response = await self.client.chat.completions.create(
            model=self.model,
            messages=self._build_messages(message, system_prompt),
            max_tokens=self.max_tokens,
        )
        # ``message.content`` can be None (e.g. for tool-call responses);
        # normalize so the declared ``str`` return type holds.
        content = response.choices[0].message.content
        return content if content is not None else ""

    async def chat_stream(
        self, message: str, system_prompt: Optional[str] = None
    ) -> AsyncGenerator[str, None]:
        """Send one user message and yield the reply incrementally as text deltas.

        Args:
            message: The user message.
            system_prompt: Optional system instruction placed before the message.

        Yields:
            Non-empty text fragments from the first choice of each streamed chunk.
        """
        stream = await self.client.chat.completions.create(
            model=self.model,
            messages=self._build_messages(message, system_prompt),
            max_tokens=self.max_tokens,
            stream=True,
        )
        async for chunk in stream:
            # Some chunks arrive with an empty ``choices`` list (e.g. the
            # usage-only final chunk when include_usage is enabled); indexing
            # [0] on them would raise IndexError, so skip them.
            if not chunk.choices:
                continue
            content = chunk.choices[0].delta.content
            # Role-only or finish chunks carry no text; emit only real content.
            if content:
                yield content

    def get_provider_name(self) -> str:
        """Return the identifier of this provider, ``"openai"``."""
        return "openai"