From dd75f98d85d8ef1d09e3e679c70d9e2950ece775 Mon Sep 17 00:00:00 2001 From: RockChinQ <1010553892@qq.com> Date: Fri, 4 Aug 2023 18:41:04 +0800 Subject: [PATCH] =?UTF-8?q?feat:=20=E4=B8=96=E7=95=8C=E4=B8=8A=E6=9C=80?= =?UTF-8?q?=E5=85=88=E8=BF=9B=E7=9A=84=E8=B0=83=E7=94=A8=E6=B5=81=E7=A8=8B?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- pkg/openai/api/chat_completion.py | 18 +++++++++++------- pkg/openai/session.py | 20 ++++++++++++-------- 2 files changed, 23 insertions(+), 15 deletions(-) diff --git a/pkg/openai/api/chat_completion.py b/pkg/openai/api/chat_completion.py index 4c375f7e..1ea46ac6 100644 --- a/pkg/openai/api/chat_completion.py +++ b/pkg/openai/api/chat_completion.py @@ -30,7 +30,7 @@ class ChatCompletionRequest(RequestBase): ) self.pending_msg = "" - def append_message(self, role: str, content: str, name: str=None): + def append_message(self, role: str, content: str, name: str=None, function_call: dict=None): msg = { "role": role, "content": content @@ -39,6 +39,9 @@ class ChatCompletionRequest(RequestBase): if name is not None: msg['name'] = name + if function_call is not None: + msg['function_call'] = function_call + self.messages.append(msg) def __init__( @@ -87,16 +90,17 @@ class ChatCompletionRequest(RequestBase): choice0 = resp["choices"][0] # 如果不是函数调用,且finish_reason为stop,则停止迭代 - if 'function_call' not in choice0['message']: # and choice0["finish_reason"] == "stop" + if choice0['finish_reason'] == 'stop': # and choice0["finish_reason"] == "stop" self.stopped = True if 'function_call' in choice0['message']: self.pending_func_call = choice0['message']['function_call'] - # self.append_message( - # role="assistant", - # content="function call: "+json.dumps(self.pending_func_call, ensure_ascii=False) - # ) + self.append_message( + role="assistant", + content=None, + function_call=choice0['message']['function_call'] + ) return { "id": resp["id"], @@ -106,7 +110,7 @@ class ChatCompletionRequest(RequestBase): "message": { "role": "assistant", "type": "function_call", - "content": None, + "content": choice0['message']['content'], "function_call": choice0['message']['function_call'] }, "finish_reason": "function_call" diff --git a/pkg/openai/session.py b/pkg/openai/session.py index 1ab3fd88..2b8a145b 100644 --- a/pkg/openai/session.py +++ b/pkg/openai/session.py @@ -259,17 +259,21 @@ class Session: finish_reason = resp['choices'][0]['finish_reason'] - if resp['choices'][0]['message']['type'] == 'text': # 普通回复 - res_text += resp['choices'][0]['message']['content'] + if resp['choices'][0]['message']['role'] == "assistant" and resp['choices'][0]['message']['content'] != None: # 包含纯文本响应 + + res_text += resp['choices'][0]['message']['content'] + "\n" total_tokens += resp['usage']['total_tokens'] - pending_msgs.append( - { - "role": "assistant", - "content": resp['choices'][0]['message']['content'] - } - ) + msg = { + "role": "assistant", + "content": resp['choices'][0]['message']['content'] + } + + if 'function_call' in resp['choices'][0]['message']: + msg['function_call'] = resp['choices'][0]['message']['function_call'] + + pending_msgs.append(msg) elif resp['choices'][0]['message']['type'] == 'function_call': # self.prompt.append(