mirror of
https://github.com/langbot-app/LangBot.git
synced 2026-06-02 03:55:55 +00:00
feat: 截取prompt提交到API #4
This commit is contained in:
@@ -73,8 +73,11 @@ class Session:
|
|||||||
self.last_interact_timestamp = int(time.time())
|
self.last_interact_timestamp = int(time.time())
|
||||||
|
|
||||||
# 向API请求补全
|
# 向API请求补全
|
||||||
response = pkg.openai.manager.get_inst().request_completion(self.prompt, self.user_name + ':')
|
response = pkg.openai.manager.get_inst().request_completion(self.cut_out(self.prompt + self.user_name + ':' +
|
||||||
|
text + '\n' + self.bot_name + ':',
|
||||||
|
7, 1024), self.user_name + ':')
|
||||||
|
|
||||||
|
self.prompt += self.user_name + ':' + text + '\n' + self.bot_name + ':'
|
||||||
# print(response)
|
# print(response)
|
||||||
# 处理回复
|
# 处理回复
|
||||||
res_test = response["choices"][0]["text"]
|
res_test = response["choices"][0]["text"]
|
||||||
@@ -94,6 +97,29 @@ class Session:
|
|||||||
|
|
||||||
return res_ans
|
return res_ans
|
||||||
|
|
||||||
|
# 截取prompt里不多于max_rounds个回合,长度为大于max_tokens的最小整数字符串
|
||||||
|
# 保证都是完整的对话
|
||||||
|
def cut_out(self, prompt: str, max_rounds: int, max_tokens: int) -> str:
|
||||||
|
# 分隔出每个回合
|
||||||
|
rounds_spt_by_user_name = prompt.split(self.user_name + ':')
|
||||||
|
|
||||||
|
result = ''
|
||||||
|
|
||||||
|
checked_rounds = 0
|
||||||
|
# 从后往前遍历,加到result前面,检查result是否符合要求
|
||||||
|
for i in range(len(rounds_spt_by_user_name) - 1, 0, -1):
|
||||||
|
result = self.user_name + ':' + rounds_spt_by_user_name[i] + result
|
||||||
|
checked_rounds += 1
|
||||||
|
|
||||||
|
if checked_rounds >= max_rounds:
|
||||||
|
break
|
||||||
|
|
||||||
|
if len(result) > max_tokens:
|
||||||
|
break
|
||||||
|
|
||||||
|
logging.debug('cut_out: {}'.format(result))
|
||||||
|
return result
|
||||||
|
|
||||||
def persistence(self):
|
def persistence(self):
|
||||||
if self.prompt == '':
|
if self.prompt == '':
|
||||||
return
|
return
|
||||||
|
|||||||
Reference in New Issue
Block a user