mirror of
https://github.com/langbot-app/LangBot.git
synced 2026-06-02 12:05:54 +00:00
feat: 适配completion和chat_completions
This commit is contained in:
@@ -1,4 +1,5 @@
|
||||
import openai
|
||||
from openai.types.chat import chat_completion_message
|
||||
import json
|
||||
import logging
|
||||
|
||||
@@ -13,13 +14,14 @@ class ChatCompletionRequest(RequestBase):
|
||||
此类保证每一次返回的角色为assistant的信息的finish_reason一定为stop。
|
||||
若有函数调用响应,本类的返回瀑布是:函数调用请求->函数调用结果->...->assistant的信息->stop。
|
||||
"""
|
||||
|
||||
model: str
|
||||
messages: list[dict[str, str]]
|
||||
kwargs: dict
|
||||
|
||||
stopped: bool = False
|
||||
|
||||
pending_func_call: dict = None
|
||||
pending_func_call: chat_completion_message.FunctionCall = None
|
||||
|
||||
pending_msg: str
|
||||
|
||||
@@ -46,16 +48,18 @@ class ChatCompletionRequest(RequestBase):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
client: openai.Client,
|
||||
model: str,
|
||||
messages: list[dict[str, str]],
|
||||
**kwargs
|
||||
):
|
||||
self.client = client
|
||||
self.model = model
|
||||
self.messages = messages.copy()
|
||||
|
||||
self.kwargs = kwargs
|
||||
|
||||
self.req_func = openai.ChatCompletion.acreate
|
||||
self.req_func = self.client.chat.completions.create
|
||||
|
||||
self.pending_func_call = None
|
||||
|
||||
@@ -84,39 +88,48 @@ class ChatCompletionRequest(RequestBase):
|
||||
|
||||
# 拼接kwargs
|
||||
args = {**args, **self.kwargs}
|
||||
|
||||
from openai.types.chat import chat_completion
|
||||
|
||||
resp = self._req(**args)
|
||||
resp: chat_completion.ChatCompletion = self._req(**args)
|
||||
|
||||
choice0 = resp["choices"][0]
|
||||
choice0 = resp.choices[0]
|
||||
|
||||
# 如果不是函数调用,且finish_reason为stop,则停止迭代
|
||||
if choice0['finish_reason'] == 'stop': # and choice0["finish_reason"] == "stop"
|
||||
if choice0.finish_reason == 'stop': # and choice0["finish_reason"] == "stop"
|
||||
self.stopped = True
|
||||
|
||||
if 'function_call' in choice0['message']:
|
||||
self.pending_func_call = choice0['message']['function_call']
|
||||
if hasattr(choice0.message, 'function_call') and choice0.message.function_call is not None:
|
||||
self.pending_func_call = choice0.message.function_call
|
||||
|
||||
self.append_message(
|
||||
role="assistant",
|
||||
content=choice0['message']['content'],
|
||||
function_call=choice0['message']['function_call']
|
||||
content=choice0.message.content,
|
||||
function_call=choice0.message.function_call
|
||||
)
|
||||
|
||||
return {
|
||||
"id": resp["id"],
|
||||
"id": resp.id,
|
||||
"choices": [
|
||||
{
|
||||
"index": choice0["index"],
|
||||
"index": choice0.index,
|
||||
"message": {
|
||||
"role": "assistant",
|
||||
"type": "function_call",
|
||||
"content": choice0['message']['content'],
|
||||
"function_call": choice0['message']['function_call']
|
||||
"content": choice0.message.content,
|
||||
"function_call": {
|
||||
"name": choice0.message.function_call.name,
|
||||
"arguments": choice0.message.function_call.arguments
|
||||
}
|
||||
},
|
||||
"finish_reason": "function_call"
|
||||
}
|
||||
],
|
||||
"usage": resp["usage"]
|
||||
"usage": {
|
||||
"prompt_tokens": resp.usage.prompt_tokens,
|
||||
"completion_tokens": resp.usage.completion_tokens,
|
||||
"total_tokens": resp.usage.total_tokens
|
||||
}
|
||||
}
|
||||
else:
|
||||
|
||||
@@ -124,19 +137,23 @@ class ChatCompletionRequest(RequestBase):
|
||||
# 普通回复一定处于最后方,故不用再追加进内部messages
|
||||
|
||||
return {
|
||||
"id": resp["id"],
|
||||
"id": resp.id,
|
||||
"choices": [
|
||||
{
|
||||
"index": choice0["index"],
|
||||
"index": choice0.index,
|
||||
"message": {
|
||||
"role": "assistant",
|
||||
"type": "text",
|
||||
"content": choice0['message']['content']
|
||||
"content": choice0.message.content
|
||||
},
|
||||
"finish_reason": choice0["finish_reason"]
|
||||
"finish_reason": choice0.finish_reason
|
||||
}
|
||||
],
|
||||
"usage": resp["usage"]
|
||||
"usage": {
|
||||
"prompt_tokens": resp.usage.prompt_tokens,
|
||||
"completion_tokens": resp.usage.completion_tokens,
|
||||
"total_tokens": resp.usage.total_tokens
|
||||
}
|
||||
}
|
||||
else: # 处理函数调用请求
|
||||
|
||||
@@ -144,20 +161,20 @@ class ChatCompletionRequest(RequestBase):
|
||||
|
||||
self.pending_func_call = None
|
||||
|
||||
func_name = cp_pending_func_call['name']
|
||||
func_name = cp_pending_func_call.name
|
||||
arguments = {}
|
||||
|
||||
try:
|
||||
|
||||
try:
|
||||
arguments = json.loads(cp_pending_func_call['arguments'])
|
||||
arguments = json.loads(cp_pending_func_call.arguments)
|
||||
# 若不是json格式的异常处理
|
||||
except json.decoder.JSONDecodeError:
|
||||
# 获取函数的参数列表
|
||||
func_schema = get_func_schema(func_name)
|
||||
|
||||
arguments = {
|
||||
func_schema['parameters']['required'][0]: cp_pending_func_call['arguments']
|
||||
func_schema['parameters']['required'][0]: cp_pending_func_call.arguments
|
||||
}
|
||||
|
||||
logging.info("执行函数调用: name={}, arguments={}".format(func_name, arguments))
|
||||
|
||||
@@ -1,4 +1,5 @@
|
||||
import openai
|
||||
from openai.types import completion, completion_choice
|
||||
|
||||
from .model import RequestBase
|
||||
|
||||
@@ -17,10 +18,12 @@ class CompletionRequest(RequestBase):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
client: openai.Client,
|
||||
model: str,
|
||||
messages: list[dict[str, str]],
|
||||
**kwargs
|
||||
):
|
||||
self.client = client
|
||||
self.model = model
|
||||
self.prompt = ""
|
||||
|
||||
@@ -31,7 +34,7 @@ class CompletionRequest(RequestBase):
|
||||
|
||||
self.kwargs = kwargs
|
||||
|
||||
self.req_func = openai.Completion.acreate
|
||||
self.req_func = self.client.completions.create
|
||||
|
||||
def __iter__(self):
|
||||
return self
|
||||
@@ -63,49 +66,35 @@ class CompletionRequest(RequestBase):
|
||||
if self.stopped:
|
||||
raise StopIteration()
|
||||
|
||||
resp = self._req(
|
||||
resp: completion.Completion = self._req(
|
||||
model=self.model,
|
||||
prompt=self.prompt,
|
||||
**self.kwargs
|
||||
)
|
||||
|
||||
if resp["choices"][0]["finish_reason"] == "stop":
|
||||
if resp.choices[0].finish_reason == "stop":
|
||||
self.stopped = True
|
||||
|
||||
choice0 = resp["choices"][0]
|
||||
choice0: completion_choice.CompletionChoice = resp.choices[0]
|
||||
|
||||
self.prompt += choice0["text"]
|
||||
self.prompt += choice0.text
|
||||
|
||||
return {
|
||||
"id": resp["id"],
|
||||
"id": resp.id,
|
||||
"choices": [
|
||||
{
|
||||
"index": choice0["index"],
|
||||
"index": choice0.index,
|
||||
"message": {
|
||||
"role": "assistant",
|
||||
"type": "text",
|
||||
"content": choice0["text"]
|
||||
"content": choice0.text
|
||||
},
|
||||
"finish_reason": choice0["finish_reason"]
|
||||
"finish_reason": choice0.finish_reason
|
||||
}
|
||||
],
|
||||
"usage": resp["usage"]
|
||||
}
|
||||
|
||||
if __name__ == "__main__":
|
||||
import os
|
||||
|
||||
openai.api_key = os.environ["OPENAI_API_KEY"]
|
||||
|
||||
for resp in CompletionRequest(
|
||||
model="text-davinci-003",
|
||||
messages=[
|
||||
{
|
||||
"role": "user",
|
||||
"content": "Hello, who are you?"
|
||||
"usage": {
|
||||
"prompt_tokens": resp.usage.prompt_tokens,
|
||||
"completion_tokens": resp.usage.completion_tokens,
|
||||
"total_tokens": resp.usage.total_tokens
|
||||
}
|
||||
]
|
||||
):
|
||||
print(resp)
|
||||
if resp["choices"][0]["finish_reason"] == "stop":
|
||||
break
|
||||
}
|
||||
|
||||
@@ -8,6 +8,8 @@ import openai
|
||||
|
||||
class RequestBase:
|
||||
|
||||
client: openai.Client
|
||||
|
||||
req_func: callable
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
@@ -17,41 +19,17 @@ class RequestBase:
|
||||
import pkg.utils.context as context
|
||||
switched, name = context.get_openai_manager().key_mgr.auto_switch()
|
||||
logging.debug("切换api-key: switched={}, name={}".format(switched, name))
|
||||
openai.api_key = context.get_openai_manager().key_mgr.get_using_key()
|
||||
self.client.api_key = context.get_openai_manager().key_mgr.get_using_key()
|
||||
|
||||
def _req(self, **kwargs):
|
||||
"""处理代理问题"""
|
||||
import config
|
||||
|
||||
ret: dict = {}
|
||||
exception: Exception = None
|
||||
ret = self.req_func(**kwargs)
|
||||
logging.debug("接口请求返回:%s", str(ret))
|
||||
|
||||
async def awrapper(**kwargs):
|
||||
nonlocal ret, exception
|
||||
|
||||
try:
|
||||
ret = await self.req_func(**kwargs)
|
||||
logging.debug("接口请求返回:%s", str(ret))
|
||||
|
||||
if config.switch_strategy == 'active':
|
||||
self._next_key()
|
||||
|
||||
return ret
|
||||
except Exception as e:
|
||||
exception = e
|
||||
|
||||
loop = asyncio.new_event_loop()
|
||||
|
||||
thr = threading.Thread(
|
||||
target=loop.run_until_complete,
|
||||
args=(awrapper(**kwargs),)
|
||||
)
|
||||
|
||||
thr.start()
|
||||
thr.join()
|
||||
|
||||
if exception is not None:
|
||||
raise exception
|
||||
if config.switch_strategy == 'active':
|
||||
self._next_key()
|
||||
|
||||
return ret
|
||||
|
||||
|
||||
@@ -24,6 +24,8 @@ class OpenAIInteract:
|
||||
"size": "256x256",
|
||||
}
|
||||
|
||||
client: openai.Client = None
|
||||
|
||||
def __init__(self, api_key: str):
|
||||
|
||||
self.key_mgr = pkg.openai.keymgr.KeysManager(api_key)
|
||||
@@ -31,7 +33,9 @@ class OpenAIInteract:
|
||||
|
||||
# logging.info("文字总使用量:%d", self.audit_mgr.get_total_text_length())
|
||||
|
||||
openai.api_key = self.key_mgr.get_using_key()
|
||||
self.client = openai.Client(
|
||||
api_key=self.key_mgr.get_using_key()
|
||||
)
|
||||
|
||||
pkg.utils.context.set_openai_manager(self)
|
||||
|
||||
@@ -48,7 +52,7 @@ class OpenAIInteract:
|
||||
cp_parmas = config.completion_api_params.copy()
|
||||
del cp_parmas['model']
|
||||
|
||||
request = select_request_cls(model, messages, cp_parmas)
|
||||
request = select_request_cls(self.client, model, messages, cp_parmas)
|
||||
|
||||
# 请求接口
|
||||
for resp in request:
|
||||
|
||||
@@ -6,6 +6,7 @@ Completion - text-davinci-003 等模型
|
||||
此模块封装此两个接口的请求实现,为上层提供统一的调用方式
|
||||
"""
|
||||
import tiktoken
|
||||
import openai
|
||||
|
||||
from pkg.openai.api.model import RequestBase
|
||||
from pkg.openai.api.completion import CompletionRequest
|
||||
@@ -51,11 +52,11 @@ IMAGE_MODELS = {
|
||||
}
|
||||
|
||||
|
||||
def select_request_cls(model_name: str, messages: list, args: dict) -> RequestBase:
|
||||
def select_request_cls(client: openai.Client, model_name: str, messages: list, args: dict) -> RequestBase:
|
||||
if model_name in CHAT_COMPLETION_MODELS:
|
||||
return ChatCompletionRequest(model_name, messages, **args)
|
||||
return ChatCompletionRequest(client, model_name, messages, **args)
|
||||
elif model_name in COMPLETION_MODELS:
|
||||
return CompletionRequest(model_name, messages, **args)
|
||||
return CompletionRequest(client, model_name, messages, **args)
|
||||
raise ValueError("不支持模型[{}],请检查配置文件".format(model_name))
|
||||
|
||||
|
||||
|
||||
@@ -278,7 +278,7 @@ class Session:
|
||||
if resp['choices'][0]['message']['role'] == "assistant" and resp['choices'][0]['message']['content'] != None: # 包含纯文本响应
|
||||
|
||||
if not trace_func_calls:
|
||||
res_text += resp['choices'][0]['message']['content'] + "\n"
|
||||
res_text += resp['choices'][0]['message']['content']
|
||||
else:
|
||||
res_text = resp['choices'][0]['message']['content']
|
||||
pending_res_text = resp['choices'][0]['message']['content']
|
||||
|
||||
Reference in New Issue
Block a user