From 77df3d1ae5a31d755d325bb15754b70164c46e79 Mon Sep 17 00:00:00 2001 From: LINSTCL Date: Thu, 2 Mar 2023 23:50:51 +0800 Subject: [PATCH] =?UTF-8?q?=E4=BF=AE=E5=A4=8D=E4=BD=BF=E7=94=A8=E6=96=87?= =?UTF-8?q?=E6=9C=AC=E5=AE=8C=E6=88=90=E6=A8=A1=E5=9E=8B=E7=94=9F=E6=88=90?= =?UTF-8?q?=E5=AF=B9=E8=AF=9D=E5=9E=8B=E6=96=87=E6=9C=AC=E6=97=B6=E8=BE=93?= =?UTF-8?q?=E5=87=BA=E9=9A=8F=E6=9C=BAAI=E5=90=8D=E7=9A=84=E9=97=AE?= =?UTF-8?q?=E9=A2=98?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- pkg/openai/modelmgr.py | 21 +++++++++++++++------ 1 file changed, 15 insertions(+), 6 deletions(-) diff --git a/pkg/openai/modelmgr.py b/pkg/openai/modelmgr.py index 66b3c080..8c94dc35 100644 --- a/pkg/openai/modelmgr.py +++ b/pkg/openai/modelmgr.py @@ -42,8 +42,8 @@ class ModelRequest(): def msg_handle(self, msg): return msg - def ret_handle(self, ret): - return ret + def ret_handle(self): + return def get_total_tokens(self): return self.ret['usage']['total_tokens'] @@ -64,8 +64,8 @@ class ChatCompletionModel(ModelRequest): super().__init__(model_name, user_name, request_fun) def request(self, messages, **kwargs): - ret = self.request_fun(messages = self.msg_handle(messages), **kwargs, user=self.user_name) - self.ret = self.ret_handle(ret) + self.ret = self.request_fun(messages = self.msg_handle(messages), **kwargs, user=self.user_name) + self.ret_handle() self.message = self.ret["choices"][0]["message"]['content'] def msg_handle(self, msgs): @@ -87,8 +87,8 @@ class CompletionModel(ModelRequest): super().__init__(model_name, user_name, request_fun) def request(self, prompt, **kwargs): - ret = self.request_fun(prompt = self.msg_handle(prompt), **kwargs) - self.ret = self.ret_handle(ret) + self.ret = self.request_fun(prompt = self.msg_handle(prompt), **kwargs) + self.ret_handle() self.message = self.ret["choices"][0]["text"] def msg_handle(self, msgs): @@ -100,6 +100,15 @@ class CompletionModel(ModelRequest): prompt = prompt + "{}:{}\n".format(msg['role'] , msg['content']) return prompt + def ret_handle(self): + temp_text:str = self.ret["choices"][0]["text"] + texts = temp_text.split(':') + if len(texts) >= 1: + temp_text = "" + for text in texts[1:]: + temp_text = temp_text + text + self.ret["choices"][0]["text"] = temp_text + def get_text(self): return self.message