mirror of
https://github.com/langbot-app/LangBot.git
synced 2026-06-27 16:04:21 +00:00
Merge pull request #300 from RockChinQ/token-process
[Perf] Tokens相关处理逻辑优化
This commit is contained in:
+25
-2
@@ -82,7 +82,30 @@ default_prompt = {
|
|||||||
# 情景预设格式
|
# 情景预设格式
|
||||||
# 参考值:旧版本方式:default | 完整情景:full_scenario
|
# 参考值:旧版本方式:default | 完整情景:full_scenario
|
||||||
# 旧版本的格式为上述default_prompt中的内容,或prompts目录下的文件名
|
# 旧版本的格式为上述default_prompt中的内容,或prompts目录下的文件名
|
||||||
# 完整情景预设的格式为JSON,在JSON文件中列出对话的每个回合,编写方法见scenario/default-template.json
|
#
|
||||||
|
# 完整情景预设的格式为JSON,在scenario目录下的JSON文件中列出对话的每个回合,编写方法见scenario/default-template.json
|
||||||
|
# 编写方法例如:
|
||||||
|
# {
|
||||||
|
# "prompt": [
|
||||||
|
# {
|
||||||
|
# "role": "user",
|
||||||
|
# "content": "之后当我需要帮助时,请说“输入!help获取帮助”"
|
||||||
|
# },{
|
||||||
|
# "role": "assistant",
|
||||||
|
# "content": "好的,当你之后需要帮助时,我会说“输入!help获取帮助”"
|
||||||
|
# },{
|
||||||
|
# "role": "user",
|
||||||
|
# "content": "帮助"
|
||||||
|
# },{
|
||||||
|
# "role": "assistant",
|
||||||
|
# "content": "输入!help获取帮助"
|
||||||
|
# }
|
||||||
|
# ]
|
||||||
|
# }
|
||||||
|
#
|
||||||
|
# 您可以按照上述格式编写自己的情景预设,在prompt中列出对话的每个回合,
|
||||||
|
# role为user或assistant,分别表示用户和机器人的回复
|
||||||
|
# 每个JSON文件是一个情景预设,文件名即为情景预设的名称
|
||||||
preset_mode = "default"
|
preset_mode = "default"
|
||||||
|
|
||||||
# 群内响应规则
|
# 群内响应规则
|
||||||
@@ -139,7 +162,7 @@ encourage_sponsor_at_start = True
|
|||||||
# 每次向OpenAI接口发送对话记录上下文的字符数
|
# 每次向OpenAI接口发送对话记录上下文的字符数
|
||||||
# 最大不超过(4096 - max_tokens)个字符,max_tokens为下方completion_api_params中的max_tokens
|
# 最大不超过(4096 - max_tokens)个字符,max_tokens为下方completion_api_params中的max_tokens
|
||||||
# 注意:较大的prompt_submit_length会导致OpenAI账户额度消耗更快
|
# 注意:较大的prompt_submit_length会导致OpenAI账户额度消耗更快
|
||||||
prompt_submit_length = 1024
|
prompt_submit_length = 2048
|
||||||
|
|
||||||
# OpenAI补全API的参数
|
# OpenAI补全API的参数
|
||||||
# 请在下方填写模型,程序自动选择接口
|
# 请在下方填写模型,程序自动选择接口
|
||||||
|
|||||||
+31
-16
@@ -54,20 +54,27 @@ class DatabaseManager:
|
|||||||
`last_interact_timestamp` bigint not null,
|
`last_interact_timestamp` bigint not null,
|
||||||
`status` varchar(255) not null default 'on_going',
|
`status` varchar(255) not null default 'on_going',
|
||||||
`default_prompt` text not null default '',
|
`default_prompt` text not null default '',
|
||||||
`prompt` text not null
|
`prompt` text not null,
|
||||||
|
`token_counts` text not null default '[]'
|
||||||
)
|
)
|
||||||
""")
|
""")
|
||||||
|
|
||||||
# 检查sessions表是否存在`default_prompt`字段
|
# 检查sessions表是否存在`default_prompt`字段, 检查是否存在`token_counts`字段
|
||||||
self.__execute__("PRAGMA table_info('sessions')")
|
self.__execute__("PRAGMA table_info('sessions')")
|
||||||
columns = self.cursor.fetchall()
|
columns = self.cursor.fetchall()
|
||||||
has_default_prompt = False
|
has_default_prompt = False
|
||||||
|
has_token_counts = False
|
||||||
for field in columns:
|
for field in columns:
|
||||||
if field[1] == 'default_prompt':
|
if field[1] == 'default_prompt':
|
||||||
has_default_prompt = True
|
has_default_prompt = True
|
||||||
|
if field[1] == 'token_counts':
|
||||||
|
has_token_counts = True
|
||||||
|
if has_default_prompt and has_token_counts:
|
||||||
break
|
break
|
||||||
if not has_default_prompt:
|
if not has_default_prompt:
|
||||||
self.__execute__("alter table `sessions` add column `default_prompt` text not null default ''")
|
self.__execute__("alter table `sessions` add column `default_prompt` text not null default ''")
|
||||||
|
if not has_token_counts:
|
||||||
|
self.__execute__("alter table `sessions` add column `token_counts` text not null default '[]'")
|
||||||
|
|
||||||
|
|
||||||
self.__execute__("""
|
self.__execute__("""
|
||||||
@@ -89,7 +96,7 @@ class DatabaseManager:
|
|||||||
|
|
||||||
# session持久化
|
# session持久化
|
||||||
def persistence_session(self, subject_type: str, subject_number: int, create_timestamp: int,
|
def persistence_session(self, subject_type: str, subject_number: int, create_timestamp: int,
|
||||||
last_interact_timestamp: int, prompt: str, default_prompt: str = ''):
|
last_interact_timestamp: int, prompt: str, default_prompt: str = '', token_counts: str = ''):
|
||||||
"""持久化指定session"""
|
"""持久化指定session"""
|
||||||
|
|
||||||
# 检查是否已经有了此name和create_timestamp的session
|
# 检查是否已经有了此name和create_timestamp的session
|
||||||
@@ -102,20 +109,20 @@ class DatabaseManager:
|
|||||||
if count == 0:
|
if count == 0:
|
||||||
|
|
||||||
sql = """
|
sql = """
|
||||||
insert into `sessions` (`name`, `type`, `number`, `create_timestamp`, `last_interact_timestamp`, `prompt`, `default_prompt`)
|
insert into `sessions` (`name`, `type`, `number`, `create_timestamp`, `last_interact_timestamp`, `prompt`, `default_prompt`, `token_counts`)
|
||||||
values (?, ?, ?, ?, ?, ?, ?)
|
values (?, ?, ?, ?, ?, ?, ?, ?)
|
||||||
"""
|
"""
|
||||||
|
|
||||||
self.__execute__(sql,
|
self.__execute__(sql,
|
||||||
("{}_{}".format(subject_type, subject_number), subject_type, subject_number, create_timestamp,
|
("{}_{}".format(subject_type, subject_number), subject_type, subject_number, create_timestamp,
|
||||||
last_interact_timestamp, prompt, default_prompt))
|
last_interact_timestamp, prompt, default_prompt, token_counts))
|
||||||
else:
|
else:
|
||||||
sql = """
|
sql = """
|
||||||
update `sessions` set `last_interact_timestamp` = ?, `prompt` = ?
|
update `sessions` set `last_interact_timestamp` = ?, `prompt` = ?, `token_counts` = ?
|
||||||
where `type` = ? and `number` = ? and `create_timestamp` = ?
|
where `type` = ? and `number` = ? and `create_timestamp` = ?
|
||||||
"""
|
"""
|
||||||
|
|
||||||
self.__execute__(sql, (last_interact_timestamp, prompt, subject_type,
|
self.__execute__(sql, (last_interact_timestamp, prompt, token_counts, subject_type,
|
||||||
subject_number, create_timestamp))
|
subject_number, create_timestamp))
|
||||||
|
|
||||||
# 显式关闭一个session
|
# 显式关闭一个session
|
||||||
@@ -140,7 +147,7 @@ class DatabaseManager:
|
|||||||
# 从数据库中加载所有还没过期的session
|
# 从数据库中加载所有还没过期的session
|
||||||
config = pkg.utils.context.get_config()
|
config = pkg.utils.context.get_config()
|
||||||
self.__execute__("""
|
self.__execute__("""
|
||||||
select `name`, `type`, `number`, `create_timestamp`, `last_interact_timestamp`, `prompt`, `status`, `default_prompt`
|
select `name`, `type`, `number`, `create_timestamp`, `last_interact_timestamp`, `prompt`, `status`, `default_prompt`, `token_counts`
|
||||||
from `sessions` where `last_interact_timestamp` > {}
|
from `sessions` where `last_interact_timestamp` > {}
|
||||||
""".format(int(time.time()) - config.session_expire_time))
|
""".format(int(time.time()) - config.session_expire_time))
|
||||||
results = self.cursor.fetchall()
|
results = self.cursor.fetchall()
|
||||||
@@ -154,6 +161,7 @@ class DatabaseManager:
|
|||||||
prompt = result[5]
|
prompt = result[5]
|
||||||
status = result[6]
|
status = result[6]
|
||||||
default_prompt = result[7]
|
default_prompt = result[7]
|
||||||
|
token_counts = result[8]
|
||||||
|
|
||||||
# 当且仅当最后一个该对象的会话是on_going状态时,才会被加载
|
# 当且仅当最后一个该对象的会话是on_going状态时,才会被加载
|
||||||
if status == 'on_going':
|
if status == 'on_going':
|
||||||
@@ -163,7 +171,8 @@ class DatabaseManager:
|
|||||||
'create_timestamp': create_timestamp,
|
'create_timestamp': create_timestamp,
|
||||||
'last_interact_timestamp': last_interact_timestamp,
|
'last_interact_timestamp': last_interact_timestamp,
|
||||||
'prompt': prompt,
|
'prompt': prompt,
|
||||||
'default_prompt': default_prompt
|
'default_prompt': default_prompt,
|
||||||
|
'token_counts': token_counts
|
||||||
}
|
}
|
||||||
else:
|
else:
|
||||||
if session_name in sessions:
|
if session_name in sessions:
|
||||||
@@ -175,7 +184,7 @@ class DatabaseManager:
|
|||||||
def last_session(self, session_name: str, cursor_timestamp: int):
|
def last_session(self, session_name: str, cursor_timestamp: int):
|
||||||
|
|
||||||
self.__execute__("""
|
self.__execute__("""
|
||||||
select `name`, `type`, `number`, `create_timestamp`, `last_interact_timestamp`, `prompt`, `status`, `default_prompt`
|
select `name`, `type`, `number`, `create_timestamp`, `last_interact_timestamp`, `prompt`, `status`, `default_prompt`, `token_counts`
|
||||||
from `sessions` where `name` = '{}' and `last_interact_timestamp` < {} order by `last_interact_timestamp` desc
|
from `sessions` where `name` = '{}' and `last_interact_timestamp` < {} order by `last_interact_timestamp` desc
|
||||||
limit 1
|
limit 1
|
||||||
""".format(session_name, cursor_timestamp))
|
""".format(session_name, cursor_timestamp))
|
||||||
@@ -192,6 +201,7 @@ class DatabaseManager:
|
|||||||
prompt = result[5]
|
prompt = result[5]
|
||||||
status = result[6]
|
status = result[6]
|
||||||
default_prompt = result[7]
|
default_prompt = result[7]
|
||||||
|
token_counts = result[8]
|
||||||
|
|
||||||
return {
|
return {
|
||||||
'subject_type': subject_type,
|
'subject_type': subject_type,
|
||||||
@@ -199,14 +209,15 @@ class DatabaseManager:
|
|||||||
'create_timestamp': create_timestamp,
|
'create_timestamp': create_timestamp,
|
||||||
'last_interact_timestamp': last_interact_timestamp,
|
'last_interact_timestamp': last_interact_timestamp,
|
||||||
'prompt': prompt,
|
'prompt': prompt,
|
||||||
'default_prompt': default_prompt
|
'default_prompt': default_prompt,
|
||||||
|
'token_counts': token_counts
|
||||||
}
|
}
|
||||||
|
|
||||||
# 获取此session_name后一个session的数据
|
# 获取此session_name后一个session的数据
|
||||||
def next_session(self, session_name: str, cursor_timestamp: int):
|
def next_session(self, session_name: str, cursor_timestamp: int):
|
||||||
|
|
||||||
self.__execute__("""
|
self.__execute__("""
|
||||||
select `name`, `type`, `number`, `create_timestamp`, `last_interact_timestamp`, `prompt`, `status`, `default_prompt`
|
select `name`, `type`, `number`, `create_timestamp`, `last_interact_timestamp`, `prompt`, `status`, `default_prompt`, `token_counts`
|
||||||
from `sessions` where `name` = '{}' and `last_interact_timestamp` > {} order by `last_interact_timestamp` asc
|
from `sessions` where `name` = '{}' and `last_interact_timestamp` > {} order by `last_interact_timestamp` asc
|
||||||
limit 1
|
limit 1
|
||||||
""".format(session_name, cursor_timestamp))
|
""".format(session_name, cursor_timestamp))
|
||||||
@@ -223,6 +234,7 @@ class DatabaseManager:
|
|||||||
prompt = result[5]
|
prompt = result[5]
|
||||||
status = result[6]
|
status = result[6]
|
||||||
default_prompt = result[7]
|
default_prompt = result[7]
|
||||||
|
token_counts = result[8]
|
||||||
|
|
||||||
return {
|
return {
|
||||||
'subject_type': subject_type,
|
'subject_type': subject_type,
|
||||||
@@ -230,13 +242,14 @@ class DatabaseManager:
|
|||||||
'create_timestamp': create_timestamp,
|
'create_timestamp': create_timestamp,
|
||||||
'last_interact_timestamp': last_interact_timestamp,
|
'last_interact_timestamp': last_interact_timestamp,
|
||||||
'prompt': prompt,
|
'prompt': prompt,
|
||||||
'default_prompt': default_prompt
|
'default_prompt': default_prompt,
|
||||||
|
'token_counts': token_counts
|
||||||
}
|
}
|
||||||
|
|
||||||
# 列出与某个对象的所有对话session
|
# 列出与某个对象的所有对话session
|
||||||
def list_history(self, session_name: str, capacity: int, page: int):
|
def list_history(self, session_name: str, capacity: int, page: int):
|
||||||
self.__execute__("""
|
self.__execute__("""
|
||||||
select `name`, `type`, `number`, `create_timestamp`, `last_interact_timestamp`, `prompt`, `status`, `default_prompt`
|
select `name`, `type`, `number`, `create_timestamp`, `last_interact_timestamp`, `prompt`, `status`, `default_prompt`, `token_counts`
|
||||||
from `sessions` where `name` = '{}' order by `last_interact_timestamp` desc limit {} offset {}
|
from `sessions` where `name` = '{}' order by `last_interact_timestamp` desc limit {} offset {}
|
||||||
""".format(session_name, capacity, capacity * page))
|
""".format(session_name, capacity, capacity * page))
|
||||||
results = self.cursor.fetchall()
|
results = self.cursor.fetchall()
|
||||||
@@ -250,6 +263,7 @@ class DatabaseManager:
|
|||||||
prompt = result[5]
|
prompt = result[5]
|
||||||
status = result[6]
|
status = result[6]
|
||||||
default_prompt = result[7]
|
default_prompt = result[7]
|
||||||
|
token_counts = result[8]
|
||||||
|
|
||||||
sessions.append({
|
sessions.append({
|
||||||
'subject_type': subject_type,
|
'subject_type': subject_type,
|
||||||
@@ -257,7 +271,8 @@ class DatabaseManager:
|
|||||||
'create_timestamp': create_timestamp,
|
'create_timestamp': create_timestamp,
|
||||||
'last_interact_timestamp': last_interact_timestamp,
|
'last_interact_timestamp': last_interact_timestamp,
|
||||||
'prompt': prompt,
|
'prompt': prompt,
|
||||||
'default_prompt': default_prompt
|
'default_prompt': default_prompt,
|
||||||
|
'token_counts': token_counts
|
||||||
})
|
})
|
||||||
|
|
||||||
return sessions
|
return sessions
|
||||||
|
|||||||
@@ -34,7 +34,7 @@ class OpenAIInteract:
|
|||||||
pkg.utils.context.set_openai_manager(self)
|
pkg.utils.context.set_openai_manager(self)
|
||||||
|
|
||||||
# 请求OpenAI Completion
|
# 请求OpenAI Completion
|
||||||
def request_completion(self, prompts) -> str:
|
def request_completion(self, prompts) -> tuple[str, int]:
|
||||||
"""请求补全接口回复
|
"""请求补全接口回复
|
||||||
|
|
||||||
Parameters:
|
Parameters:
|
||||||
@@ -60,14 +60,18 @@ class OpenAIInteract:
|
|||||||
|
|
||||||
logging.debug("OpenAI response: %s", response)
|
logging.debug("OpenAI response: %s", response)
|
||||||
|
|
||||||
|
# 记录使用量
|
||||||
|
current_round_token = 0
|
||||||
if 'model' in config.completion_api_params:
|
if 'model' in config.completion_api_params:
|
||||||
self.audit_mgr.report_text_model_usage(config.completion_api_params['model'],
|
self.audit_mgr.report_text_model_usage(config.completion_api_params['model'],
|
||||||
ai.get_total_tokens())
|
ai.get_total_tokens())
|
||||||
|
current_round_token = ai.get_total_tokens()
|
||||||
elif 'engine' in config.completion_api_params:
|
elif 'engine' in config.completion_api_params:
|
||||||
self.audit_mgr.report_text_model_usage(config.completion_api_params['engine'],
|
self.audit_mgr.report_text_model_usage(config.completion_api_params['engine'],
|
||||||
response['usage']['total_tokens'])
|
response['usage']['total_tokens'])
|
||||||
|
current_round_token = response['usage']['total_tokens']
|
||||||
|
|
||||||
return ai.get_message()
|
return ai.get_message(), current_round_token
|
||||||
|
|
||||||
def request_image(self, prompt) -> dict:
|
def request_image(self, prompt) -> dict:
|
||||||
"""请求图片接口回复
|
"""请求图片接口回复
|
||||||
|
|||||||
+70
-24
@@ -72,6 +72,7 @@ def load_sessions():
|
|||||||
temp_session.last_interact_timestamp = session_data[session_name]['last_interact_timestamp']
|
temp_session.last_interact_timestamp = session_data[session_name]['last_interact_timestamp']
|
||||||
try:
|
try:
|
||||||
temp_session.prompt = json.loads(session_data[session_name]['prompt'])
|
temp_session.prompt = json.loads(session_data[session_name]['prompt'])
|
||||||
|
temp_session.token_counts = json.loads(session_data[session_name]['token_counts'])
|
||||||
except Exception:
|
except Exception:
|
||||||
temp_session.prompt = reset_session_prompt(session_name, session_data[session_name]['prompt'])
|
temp_session.prompt = reset_session_prompt(session_name, session_data[session_name]['prompt'])
|
||||||
temp_session.persistence()
|
temp_session.persistence()
|
||||||
@@ -106,6 +107,9 @@ class Session:
|
|||||||
prompt = []
|
prompt = []
|
||||||
"""使用list来保存会话中的回合"""
|
"""使用list来保存会话中的回合"""
|
||||||
|
|
||||||
|
token_counts = []
|
||||||
|
"""每个回合的token数量"""
|
||||||
|
|
||||||
default_prompt = []
|
default_prompt = []
|
||||||
"""本session的默认prompt"""
|
"""本session的默认prompt"""
|
||||||
|
|
||||||
@@ -146,6 +150,8 @@ class Session:
|
|||||||
self.name = name
|
self.name = name
|
||||||
self.create_timestamp = int(time.time())
|
self.create_timestamp = int(time.time())
|
||||||
self.last_interact_timestamp = int(time.time())
|
self.last_interact_timestamp = int(time.time())
|
||||||
|
self.prompt = []
|
||||||
|
self.token_counts = []
|
||||||
self.schedule()
|
self.schedule()
|
||||||
|
|
||||||
self.response_lock = threading.Lock()
|
self.response_lock = threading.Lock()
|
||||||
@@ -209,9 +215,16 @@ class Session:
|
|||||||
config = pkg.utils.context.get_config()
|
config = pkg.utils.context.get_config()
|
||||||
max_length = config.prompt_submit_length if hasattr(config, "prompt_submit_length") else 1024
|
max_length = config.prompt_submit_length if hasattr(config, "prompt_submit_length") else 1024
|
||||||
|
|
||||||
|
prompts, counts = self.cut_out(text, max_length)
|
||||||
|
|
||||||
|
# 计算请求前的prompt数量
|
||||||
|
total_token_before_query = 0
|
||||||
|
for token_count in counts:
|
||||||
|
total_token_before_query += token_count
|
||||||
|
|
||||||
# 向API请求补全
|
# 向API请求补全
|
||||||
message = pkg.utils.context.get_openai_manager().request_completion(
|
message, total_token = pkg.utils.context.get_openai_manager().request_completion(
|
||||||
self.cut_out(text, max_length),
|
prompts,
|
||||||
)
|
)
|
||||||
|
|
||||||
# 成功获取,处理回复
|
# 成功获取,处理回复
|
||||||
@@ -228,6 +241,10 @@ class Session:
|
|||||||
self.prompt.append({'role': 'user', 'content': text})
|
self.prompt.append({'role': 'user', 'content': text})
|
||||||
self.prompt.append({'role': 'assistant', 'content': res_ans})
|
self.prompt.append({'role': 'assistant', 'content': res_ans})
|
||||||
|
|
||||||
|
# 向token_counts中添加本回合的token数量
|
||||||
|
self.token_counts.append(total_token-total_token_before_query)
|
||||||
|
logging.debug("本回合使用token: {}, session counts: {}".format(total_token-total_token_before_query, self.token_counts))
|
||||||
|
|
||||||
if self.just_switched_to_exist_session:
|
if self.just_switched_to_exist_session:
|
||||||
self.just_switched_to_exist_session = False
|
self.just_switched_to_exist_session = False
|
||||||
self.set_ongoing()
|
self.set_ongoing()
|
||||||
@@ -244,39 +261,65 @@ class Session:
|
|||||||
|
|
||||||
question = self.prompt[-2]['content']
|
question = self.prompt[-2]['content']
|
||||||
self.prompt = self.prompt[:-2]
|
self.prompt = self.prompt[:-2]
|
||||||
|
self.token_counts = self.token_counts[:-1]
|
||||||
|
|
||||||
# 返回上一回合的问题
|
# 返回上一回合的问题
|
||||||
return question
|
return question
|
||||||
|
|
||||||
# 构建对话体
|
# 构建对话体
|
||||||
def cut_out(self, msg: str, max_tokens: int) -> list:
|
def cut_out(self, msg: str, max_tokens: int) -> tuple[list, list]:
|
||||||
"""将现有prompt进行切割处理,使得新的prompt长度不超过max_tokens"""
|
"""将现有prompt进行切割处理,使得新的prompt长度不超过max_tokens
|
||||||
# 如果用户消息长度超过max_tokens,直接返回
|
|
||||||
temp_prompt: list = []
|
:return: (新的prompt, 新的token_counts)
|
||||||
temp_prompt += self.default_prompt
|
"""
|
||||||
temp_prompt.append(
|
|
||||||
|
# 最终由三个部分组成
|
||||||
|
# - default_prompt 情景预设固定值
|
||||||
|
# - changable_prompts 可变部分, 此会话中的历史对话回合
|
||||||
|
# - current_question 当前问题
|
||||||
|
|
||||||
|
# 包装目前的对话回合内容
|
||||||
|
changable_prompts = []
|
||||||
|
changable_counts = []
|
||||||
|
# 倒着来, 遍历prompt的步长为2, 遍历tokens_counts的步长为1
|
||||||
|
changable_index = len(self.prompt) - 1
|
||||||
|
token_count_index = len(self.token_counts) - 1
|
||||||
|
|
||||||
|
packed_tokens = 0
|
||||||
|
|
||||||
|
print(self.prompt)
|
||||||
|
|
||||||
|
while changable_index >= 0 and token_count_index >= 0:
|
||||||
|
if packed_tokens + self.token_counts[token_count_index] > max_tokens:
|
||||||
|
break
|
||||||
|
|
||||||
|
changable_prompts.insert(0, self.prompt[changable_index])
|
||||||
|
changable_prompts.insert(0, self.prompt[changable_index - 1])
|
||||||
|
changable_counts.insert(0, self.token_counts[token_count_index])
|
||||||
|
packed_tokens += self.token_counts[token_count_index]
|
||||||
|
|
||||||
|
changable_index -= 2
|
||||||
|
token_count_index -= 1
|
||||||
|
|
||||||
|
# 将default_prompt和changable_prompts合并
|
||||||
|
result_prompt = self.default_prompt + changable_prompts
|
||||||
|
|
||||||
|
print(changable_prompts)
|
||||||
|
|
||||||
|
# 添加当前问题
|
||||||
|
result_prompt.append(
|
||||||
{
|
{
|
||||||
'role': 'user',
|
'role': 'user',
|
||||||
'content': msg
|
'content': msg
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
|
|
||||||
token_count = 0
|
logging.debug('cut_out: {}\nchangable section tokens: {}\npacked counts: {}\nsession counts: {}'.format(json.dumps(result_prompt, ensure_ascii=False, indent=4),
|
||||||
for item in temp_prompt:
|
packed_tokens,
|
||||||
token_count += len(item['content'])
|
changable_counts,
|
||||||
|
self.token_counts))
|
||||||
|
|
||||||
# 倒序遍历prompt
|
return result_prompt, changable_counts
|
||||||
for i in range(len(self.prompt) - 1, -1, -1):
|
|
||||||
if token_count >= max_tokens:
|
|
||||||
break
|
|
||||||
|
|
||||||
# 将prompt加到temp_prompt倒数第二个位置
|
|
||||||
temp_prompt.insert(len(self.default_prompt), self.prompt[i])
|
|
||||||
token_count += len(self.prompt[i]['content'])
|
|
||||||
|
|
||||||
logging.debug('cut_out: {}'.format(json.dumps(temp_prompt, ensure_ascii=False, indent=4)))
|
|
||||||
|
|
||||||
return temp_prompt
|
|
||||||
|
|
||||||
# 持久化session
|
# 持久化session
|
||||||
def persistence(self):
|
def persistence(self):
|
||||||
@@ -291,7 +334,7 @@ class Session:
|
|||||||
subject_number = int(name_spt[1])
|
subject_number = int(name_spt[1])
|
||||||
|
|
||||||
db_inst.persistence_session(subject_type, subject_number, self.create_timestamp, self.last_interact_timestamp,
|
db_inst.persistence_session(subject_type, subject_number, self.create_timestamp, self.last_interact_timestamp,
|
||||||
json.dumps(self.prompt), json.dumps(self.default_prompt))
|
json.dumps(self.prompt), json.dumps(self.default_prompt), json.dumps(self.token_counts))
|
||||||
|
|
||||||
# 重置session
|
# 重置session
|
||||||
def reset(self, explicit: bool = False, expired: bool = False, schedule_new: bool = True, use_prompt: str = None):
|
def reset(self, explicit: bool = False, expired: bool = False, schedule_new: bool = True, use_prompt: str = None):
|
||||||
@@ -314,6 +357,7 @@ class Session:
|
|||||||
|
|
||||||
self.default_prompt = self.get_default_prompt(use_prompt)
|
self.default_prompt = self.get_default_prompt(use_prompt)
|
||||||
self.prompt = []
|
self.prompt = []
|
||||||
|
self.token_counts = []
|
||||||
self.create_timestamp = int(time.time())
|
self.create_timestamp = int(time.time())
|
||||||
self.last_interact_timestamp = int(time.time())
|
self.last_interact_timestamp = int(time.time())
|
||||||
self.just_switched_to_exist_session = False
|
self.just_switched_to_exist_session = False
|
||||||
@@ -339,6 +383,7 @@ class Session:
|
|||||||
self.last_interact_timestamp = last_one['last_interact_timestamp']
|
self.last_interact_timestamp = last_one['last_interact_timestamp']
|
||||||
try:
|
try:
|
||||||
self.prompt = json.loads(last_one['prompt'])
|
self.prompt = json.loads(last_one['prompt'])
|
||||||
|
self.token_counts = json.loads(last_one['token_counts'])
|
||||||
except json.decoder.JSONDecodeError:
|
except json.decoder.JSONDecodeError:
|
||||||
self.prompt = reset_session_prompt(self.name, last_one['prompt'])
|
self.prompt = reset_session_prompt(self.name, last_one['prompt'])
|
||||||
self.persistence()
|
self.persistence()
|
||||||
@@ -359,6 +404,7 @@ class Session:
|
|||||||
self.last_interact_timestamp = next_one['last_interact_timestamp']
|
self.last_interact_timestamp = next_one['last_interact_timestamp']
|
||||||
try:
|
try:
|
||||||
self.prompt = json.loads(next_one['prompt'])
|
self.prompt = json.loads(next_one['prompt'])
|
||||||
|
self.token_counts = json.loads(next_one['token_counts'])
|
||||||
except json.decoder.JSONDecodeError:
|
except json.decoder.JSONDecodeError:
|
||||||
self.prompt = reset_session_prompt(self.name, next_one['prompt'])
|
self.prompt = reset_session_prompt(self.name, next_one['prompt'])
|
||||||
self.persistence()
|
self.persistence()
|
||||||
|
|||||||
Reference in New Issue
Block a user