From 7b8ad2e3159a8873f86ab875b0f32c57217a8274 Mon Sep 17 00:00:00 2001 From: LINSTCL Date: Thu, 2 Mar 2023 16:47:50 +0800 Subject: [PATCH] =?UTF-8?q?=E4=BF=AE=E5=A4=8D=E6=A8=A1=E5=9E=8B=E5=88=87?= =?UTF-8?q?=E6=8D=A2=E8=A7=92=E8=89=B2=E6=94=B9=E5=8F=98=E5=BC=95=E8=B5=B7?= =?UTF-8?q?=E7=9A=84BUG?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- pkg/openai/modelmgr.py | 10 +++++++++- 1 file changed, 9 insertions(+), 1 deletion(-) diff --git a/pkg/openai/modelmgr.py b/pkg/openai/modelmgr.py index 62540e6b..15e50cce 100644 --- a/pkg/openai/modelmgr.py +++ b/pkg/openai/modelmgr.py @@ -47,6 +47,7 @@ class Model(): return self.ret class ChatCompletionModel(Model): + Chat_role = ['system', 'user', 'assistant'] def __init__(self, model_name, user_name): request_fun = openai.ChatCompletion.create self.can_chat = True @@ -57,6 +58,14 @@ class ChatCompletionModel(Model): self.ret = self.ret_handle(ret) self.message = self.ret["choices"][0]["message"]['content'] + def msg_handle(self, msgs): + temp_msgs = [] + for msg in msgs: + if msg['role'] not in self.Chat_role: + msg['role'] = 'user' + temp_msgs.append(msg) + return temp_msgs + def get_content(self): return self.message @@ -77,7 +86,6 @@ class CompletionModel(Model): prompt = prompt + "{}\n".format(msg['content']) else: prompt = prompt + "{}:{}\n".format(msg['role'] if msg['role']!='system' else '你的回答要遵守此规则', msg['content']) - print(prompt) return prompt def get_text(self):