From 9b77a2221b891ad8c25f27443a64f61a16abb242 Mon Sep 17 00:00:00 2001 From: Rock Chin Date: Fri, 9 Dec 2022 16:17:04 +0800 Subject: [PATCH] =?UTF-8?q?feat:=20=E6=88=AA=E5=8F=96prompt=E6=8F=90?= =?UTF-8?q?=E4=BA=A4=E5=88=B0API=20#4?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- pkg/openai/session.py | 28 +++++++++++++++++++++++++++- 1 file changed, 27 insertions(+), 1 deletion(-) diff --git a/pkg/openai/session.py b/pkg/openai/session.py index 633ec434..3e653701 100644 --- a/pkg/openai/session.py +++ b/pkg/openai/session.py @@ -73,8 +73,11 @@ class Session: self.last_interact_timestamp = int(time.time()) # 向API请求补全 - response = pkg.openai.manager.get_inst().request_completion(self.prompt, self.user_name + ':') + response = pkg.openai.manager.get_inst().request_completion(self.cut_out(self.prompt + self.user_name + ':' + + text + '\n' + self.bot_name + ':', + 7, 1024), self.user_name + ':') + self.prompt += self.user_name + ':' + text + '\n' + self.bot_name + ':' # print(response) # 处理回复 res_test = response["choices"][0]["text"] @@ -94,6 +97,29 @@ class Session: return res_ans + # 截取prompt里不多于max_rounds个回合,长度为大于max_tokens的最小整数字符串 + # 保证都是完整的对话 + def cut_out(self, prompt: str, max_rounds: int, max_tokens: int) -> str: + # 分隔出每个回合 + rounds_spt_by_user_name = prompt.split(self.user_name + ':') + + result = '' + + checked_rounds = 0 + # 从后往前遍历,加到result前面,检查result是否符合要求 + for i in range(len(rounds_spt_by_user_name) - 1, 0, -1): + result = self.user_name + ':' + rounds_spt_by_user_name[i] + result + checked_rounds += 1 + + if checked_rounds >= max_rounds: + break + + if len(result) > max_tokens: + break + + logging.debug('cut_out: {}'.format(result)) + return result + def persistence(self): if self.prompt == '': return