diff --git a/pkg/openai/modelmgr.py b/pkg/openai/modelmgr.py index 66b3c080..8c94dc35 100644 --- a/pkg/openai/modelmgr.py +++ b/pkg/openai/modelmgr.py @@ -42,8 +42,8 @@ class ModelRequest(): def msg_handle(self, msg): return msg - def ret_handle(self, ret): - return ret + def ret_handle(self): + return def get_total_tokens(self): return self.ret['usage']['total_tokens'] @@ -64,8 +64,8 @@ class ChatCompletionModel(ModelRequest): super().__init__(model_name, user_name, request_fun) def request(self, messages, **kwargs): - ret = self.request_fun(messages = self.msg_handle(messages), **kwargs, user=self.user_name) - self.ret = self.ret_handle(ret) + self.ret = self.request_fun(messages = self.msg_handle(messages), **kwargs, user=self.user_name) + self.ret_handle() self.message = self.ret["choices"][0]["message"]['content'] def msg_handle(self, msgs): @@ -87,8 +87,8 @@ class CompletionModel(ModelRequest): super().__init__(model_name, user_name, request_fun) def request(self, prompt, **kwargs): - ret = self.request_fun(prompt = self.msg_handle(prompt), **kwargs) - self.ret = self.ret_handle(ret) + self.ret = self.request_fun(prompt = self.msg_handle(prompt), **kwargs) + self.ret_handle() self.message = self.ret["choices"][0]["text"] def msg_handle(self, msgs): @@ -100,6 +100,15 @@ class CompletionModel(ModelRequest): prompt = prompt + "{}:{}\n".format(msg['role'] , msg['content']) return prompt + def ret_handle(self): + temp_text:str = self.ret["choices"][0]["text"] + texts = temp_text.split(':') + if len(texts) >= 1: + temp_text = "" + for text in texts[1:]: + temp_text = temp_text + text + self.ret["choices"][0]["text"] = temp_text + def get_text(self): return self.message