diff --git a/pkg/openai/session.py b/pkg/openai/session.py index 633ec434..3e653701 100644 --- a/pkg/openai/session.py +++ b/pkg/openai/session.py @@ -73,8 +73,11 @@ class Session: self.last_interact_timestamp = int(time.time()) # 向API请求补全 - response = pkg.openai.manager.get_inst().request_completion(self.prompt, self.user_name + ':') + response = pkg.openai.manager.get_inst().request_completion(self.cut_out(self.prompt + self.user_name + ':' + + text + '\n' + self.bot_name + ':', + 7, 1024), self.user_name + ':') + self.prompt += self.user_name + ':' + text + '\n' + self.bot_name + ':' # print(response) # 处理回复 res_test = response["choices"][0]["text"] @@ -94,6 +97,29 @@ class Session: return res_ans + # 截取prompt里不多于max_rounds个回合,长度为大于max_tokens的最小整数字符串 + # 保证都是完整的对话 + def cut_out(self, prompt: str, max_rounds: int, max_tokens: int) -> str: + # 分隔出每个回合 + rounds_spt_by_user_name = prompt.split(self.user_name + ':') + + result = '' + + checked_rounds = 0 + # 从后往前遍历,加到result前面,检查result是否符合要求 + for i in range(len(rounds_spt_by_user_name) - 1, 0, -1): + result = self.user_name + ':' + rounds_spt_by_user_name[i] + result + checked_rounds += 1 + + if checked_rounds >= max_rounds: + break + + if len(result) > max_tokens: + break + + logging.debug('cut_out: {}'.format(result)) + return result + def persistence(self): if self.prompt == '': return