AI implementation for openai and claude.
This commit is contained in:
48
backend/app/services/provider_openai.py
Normal file
48
backend/app/services/provider_openai.py
Normal file
@@ -0,0 +1,48 @@
|
||||
from typing import AsyncGenerator
|
||||
|
||||
from openai import AsyncOpenAI
|
||||
|
||||
from .provider_base import AIProvider
|
||||
|
||||
|
||||
class OpenAIProvider(AIProvider):
|
||||
def __init__(self, api_key: str, model: str):
|
||||
super().__init__(api_key, model)
|
||||
self.client = AsyncOpenAI(api_key=api_key)
|
||||
|
||||
async def chat(self, message: str, system_prompt: str = None) -> str:
|
||||
"""Non-streaming chat"""
|
||||
messages = []
|
||||
|
||||
if system_prompt:
|
||||
messages.append({"role": "system", "content": system_prompt})
|
||||
|
||||
messages.append({"role": "user", "content": message})
|
||||
|
||||
response = await self.client.chat.completions.create(
|
||||
model=self.model, messages=messages, max_tokens=4000
|
||||
)
|
||||
|
||||
return response.choices[0].message.content
|
||||
|
||||
async def chat_stream(
|
||||
self, message: str, system_prompt: str = None
|
||||
) -> AsyncGenerator[str, None]:
|
||||
"""Streaming chat"""
|
||||
messages = []
|
||||
|
||||
if system_prompt:
|
||||
messages.append({"role": "system", "content": system_prompt})
|
||||
|
||||
messages.append({"role": "user", "content": message})
|
||||
|
||||
stream = await self.client.chat.completions.create(
|
||||
model=self.model, messages=messages, max_tokens=4000, stream=True
|
||||
)
|
||||
|
||||
async for chunk in stream:
|
||||
if chunk.choices[0].delta.content:
|
||||
yield chunk.choices[0].delta.content
|
||||
|
||||
def get_provider_name(self) -> str:
|
||||
return "openai"
|
||||
Reference in New Issue
Block a user