diff --git a/pkg/openai/modelmgr.py b/pkg/openai/modelmgr.py index 62540e6b..15e50cce 100644 --- a/pkg/openai/modelmgr.py +++ b/pkg/openai/modelmgr.py @@ -47,6 +47,7 @@ class Model(): return self.ret class ChatCompletionModel(Model): + Chat_role = ['system', 'user', 'assistant'] def __init__(self, model_name, user_name): request_fun = openai.ChatCompletion.create self.can_chat = True @@ -57,6 +58,14 @@ class ChatCompletionModel(Model): self.ret = self.ret_handle(ret) self.message = self.ret["choices"][0]["message"]['content'] + def msg_handle(self, msgs): + temp_msgs = [] + for msg in msgs: + if msg['role'] not in self.Chat_role: + msg['role'] = 'user' + temp_msgs.append(msg) + return temp_msgs + def get_content(self): return self.message @@ -77,7 +86,6 @@ class CompletionModel(Model): prompt = prompt + "{}\n".format(msg['content']) else: prompt = prompt + "{}:{}\n".format(msg['role'] if msg['role']!='system' else '你的回答要遵守此规则', msg['content']) - print(prompt) return prompt def get_text(self):