Merge pull request #300 from RockChinQ/token-process

[Perf] Tokens相关处理逻辑优化
This commit is contained in:
Rock Chin
2023-03-19 16:35:25 +08:00
committed by GitHub
6 changed files with 132 additions and 44 deletions
+25 -2
View File
@@ -82,7 +82,30 @@ default_prompt = {
# 情景预设格式 # 情景预设格式
# 参考值:旧版本方式:default | 完整情景:full_scenario # 参考值:旧版本方式:default | 完整情景:full_scenario
# 旧版本的格式为上述default_prompt中的内容,或prompts目录下的文件名 # 旧版本的格式为上述default_prompt中的内容,或prompts目录下的文件名
# 完整情景预设的格式为JSON,在JSON文件中列出对话的每个回合,编写方法见scenario/default-template.json #
# 完整情景预设的格式为JSON,在scenario目录下的JSON文件中列出对话的每个回合,编写方法见scenario/default-template.json
# 编写方法例如:
# {
# "prompt": [
# {
# "role": "user",
# "content": "之后当我需要帮助时,请说“输入!help获取帮助”"
# },{
# "role": "assistant",
# "content": "好的,当你之后需要帮助时,我会说“输入!help获取帮助”"
# },{
# "role": "user",
# "content": "帮助"
# },{
# "role": "assistant",
# "content": "输入!help获取帮助"
# }
# ]
# }
#
# 您可以按照上述格式编写自己的情景预设,在prompt中列出对话的每个回合,
# role为user或assistant,分别表示用户和机器人的回复
# 每个JSON文件是一个情景预设,文件名即为情景预设的名称
preset_mode = "default" preset_mode = "default"
# 群内响应规则 # 群内响应规则
@@ -139,7 +162,7 @@ encourage_sponsor_at_start = True
# 每次向OpenAI接口发送对话记录上下文的字符数 # 每次向OpenAI接口发送对话记录上下文的字符数
# 最大不超过(4096 - max_tokens)个字符,max_tokens为下方completion_api_params中的max_tokens # 最大不超过(4096 - max_tokens)个字符,max_tokens为下方completion_api_params中的max_tokens
# 注意:较大的prompt_submit_length会导致OpenAI账户额度消耗更快 # 注意:较大的prompt_submit_length会导致OpenAI账户额度消耗更快
prompt_submit_length = 1024 prompt_submit_length = 2048
# OpenAI补全API的参数 # OpenAI补全API的参数
# 请在下方填写模型,程序自动选择接口 # 请在下方填写模型,程序自动选择接口
+31 -16
View File
@@ -54,20 +54,27 @@ class DatabaseManager:
`last_interact_timestamp` bigint not null, `last_interact_timestamp` bigint not null,
`status` varchar(255) not null default 'on_going', `status` varchar(255) not null default 'on_going',
`default_prompt` text not null default '', `default_prompt` text not null default '',
`prompt` text not null `prompt` text not null,
`token_counts` text not null default '[]'
) )
""") """)
# 检查sessions表是否存在`default_prompt`字段 # 检查sessions表是否存在`default_prompt`字段, 检查是否存在`token_counts`字段
self.__execute__("PRAGMA table_info('sessions')") self.__execute__("PRAGMA table_info('sessions')")
columns = self.cursor.fetchall() columns = self.cursor.fetchall()
has_default_prompt = False has_default_prompt = False
has_token_counts = False
for field in columns: for field in columns:
if field[1] == 'default_prompt': if field[1] == 'default_prompt':
has_default_prompt = True has_default_prompt = True
if field[1] == 'token_counts':
has_token_counts = True
if has_default_prompt and has_token_counts:
break break
if not has_default_prompt: if not has_default_prompt:
self.__execute__("alter table `sessions` add column `default_prompt` text not null default ''") self.__execute__("alter table `sessions` add column `default_prompt` text not null default ''")
if not has_token_counts:
self.__execute__("alter table `sessions` add column `token_counts` text not null default '[]'")
self.__execute__(""" self.__execute__("""
@@ -89,7 +96,7 @@ class DatabaseManager:
# session持久化 # session持久化
def persistence_session(self, subject_type: str, subject_number: int, create_timestamp: int, def persistence_session(self, subject_type: str, subject_number: int, create_timestamp: int,
last_interact_timestamp: int, prompt: str, default_prompt: str = ''): last_interact_timestamp: int, prompt: str, default_prompt: str = '', token_counts: str = ''):
"""持久化指定session""" """持久化指定session"""
# 检查是否已经有了此name和create_timestamp的session # 检查是否已经有了此name和create_timestamp的session
@@ -102,20 +109,20 @@ class DatabaseManager:
if count == 0: if count == 0:
sql = """ sql = """
insert into `sessions` (`name`, `type`, `number`, `create_timestamp`, `last_interact_timestamp`, `prompt`, `default_prompt`) insert into `sessions` (`name`, `type`, `number`, `create_timestamp`, `last_interact_timestamp`, `prompt`, `default_prompt`, `token_counts`)
values (?, ?, ?, ?, ?, ?, ?) values (?, ?, ?, ?, ?, ?, ?, ?)
""" """
self.__execute__(sql, self.__execute__(sql,
("{}_{}".format(subject_type, subject_number), subject_type, subject_number, create_timestamp, ("{}_{}".format(subject_type, subject_number), subject_type, subject_number, create_timestamp,
last_interact_timestamp, prompt, default_prompt)) last_interact_timestamp, prompt, default_prompt, token_counts))
else: else:
sql = """ sql = """
update `sessions` set `last_interact_timestamp` = ?, `prompt` = ? update `sessions` set `last_interact_timestamp` = ?, `prompt` = ?, `token_counts` = ?
where `type` = ? and `number` = ? and `create_timestamp` = ? where `type` = ? and `number` = ? and `create_timestamp` = ?
""" """
self.__execute__(sql, (last_interact_timestamp, prompt, subject_type, self.__execute__(sql, (last_interact_timestamp, prompt, token_counts, subject_type,
subject_number, create_timestamp)) subject_number, create_timestamp))
# 显式关闭一个session # 显式关闭一个session
@@ -140,7 +147,7 @@ class DatabaseManager:
# 从数据库中加载所有还没过期的session # 从数据库中加载所有还没过期的session
config = pkg.utils.context.get_config() config = pkg.utils.context.get_config()
self.__execute__(""" self.__execute__("""
select `name`, `type`, `number`, `create_timestamp`, `last_interact_timestamp`, `prompt`, `status`, `default_prompt` select `name`, `type`, `number`, `create_timestamp`, `last_interact_timestamp`, `prompt`, `status`, `default_prompt`, `token_counts`
from `sessions` where `last_interact_timestamp` > {} from `sessions` where `last_interact_timestamp` > {}
""".format(int(time.time()) - config.session_expire_time)) """.format(int(time.time()) - config.session_expire_time))
results = self.cursor.fetchall() results = self.cursor.fetchall()
@@ -154,6 +161,7 @@ class DatabaseManager:
prompt = result[5] prompt = result[5]
status = result[6] status = result[6]
default_prompt = result[7] default_prompt = result[7]
token_counts = result[8]
# 当且仅当最后一个该对象的会话是on_going状态时,才会被加载 # 当且仅当最后一个该对象的会话是on_going状态时,才会被加载
if status == 'on_going': if status == 'on_going':
@@ -163,7 +171,8 @@ class DatabaseManager:
'create_timestamp': create_timestamp, 'create_timestamp': create_timestamp,
'last_interact_timestamp': last_interact_timestamp, 'last_interact_timestamp': last_interact_timestamp,
'prompt': prompt, 'prompt': prompt,
'default_prompt': default_prompt 'default_prompt': default_prompt,
'token_counts': token_counts
} }
else: else:
if session_name in sessions: if session_name in sessions:
@@ -175,7 +184,7 @@ class DatabaseManager:
def last_session(self, session_name: str, cursor_timestamp: int): def last_session(self, session_name: str, cursor_timestamp: int):
self.__execute__(""" self.__execute__("""
select `name`, `type`, `number`, `create_timestamp`, `last_interact_timestamp`, `prompt`, `status`, `default_prompt` select `name`, `type`, `number`, `create_timestamp`, `last_interact_timestamp`, `prompt`, `status`, `default_prompt`, `token_counts`
from `sessions` where `name` = '{}' and `last_interact_timestamp` < {} order by `last_interact_timestamp` desc from `sessions` where `name` = '{}' and `last_interact_timestamp` < {} order by `last_interact_timestamp` desc
limit 1 limit 1
""".format(session_name, cursor_timestamp)) """.format(session_name, cursor_timestamp))
@@ -192,6 +201,7 @@ class DatabaseManager:
prompt = result[5] prompt = result[5]
status = result[6] status = result[6]
default_prompt = result[7] default_prompt = result[7]
token_counts = result[8]
return { return {
'subject_type': subject_type, 'subject_type': subject_type,
@@ -199,14 +209,15 @@ class DatabaseManager:
'create_timestamp': create_timestamp, 'create_timestamp': create_timestamp,
'last_interact_timestamp': last_interact_timestamp, 'last_interact_timestamp': last_interact_timestamp,
'prompt': prompt, 'prompt': prompt,
'default_prompt': default_prompt 'default_prompt': default_prompt,
'token_counts': token_counts
} }
# 获取此session_name后一个session的数据 # 获取此session_name后一个session的数据
def next_session(self, session_name: str, cursor_timestamp: int): def next_session(self, session_name: str, cursor_timestamp: int):
self.__execute__(""" self.__execute__("""
select `name`, `type`, `number`, `create_timestamp`, `last_interact_timestamp`, `prompt`, `status`, `default_prompt` select `name`, `type`, `number`, `create_timestamp`, `last_interact_timestamp`, `prompt`, `status`, `default_prompt`, `token_counts`
from `sessions` where `name` = '{}' and `last_interact_timestamp` > {} order by `last_interact_timestamp` asc from `sessions` where `name` = '{}' and `last_interact_timestamp` > {} order by `last_interact_timestamp` asc
limit 1 limit 1
""".format(session_name, cursor_timestamp)) """.format(session_name, cursor_timestamp))
@@ -223,6 +234,7 @@ class DatabaseManager:
prompt = result[5] prompt = result[5]
status = result[6] status = result[6]
default_prompt = result[7] default_prompt = result[7]
token_counts = result[8]
return { return {
'subject_type': subject_type, 'subject_type': subject_type,
@@ -230,13 +242,14 @@ class DatabaseManager:
'create_timestamp': create_timestamp, 'create_timestamp': create_timestamp,
'last_interact_timestamp': last_interact_timestamp, 'last_interact_timestamp': last_interact_timestamp,
'prompt': prompt, 'prompt': prompt,
'default_prompt': default_prompt 'default_prompt': default_prompt,
'token_counts': token_counts
} }
# 列出与某个对象的所有对话session # 列出与某个对象的所有对话session
def list_history(self, session_name: str, capacity: int, page: int): def list_history(self, session_name: str, capacity: int, page: int):
self.__execute__(""" self.__execute__("""
select `name`, `type`, `number`, `create_timestamp`, `last_interact_timestamp`, `prompt`, `status`, `default_prompt` select `name`, `type`, `number`, `create_timestamp`, `last_interact_timestamp`, `prompt`, `status`, `default_prompt`, `token_counts`
from `sessions` where `name` = '{}' order by `last_interact_timestamp` desc limit {} offset {} from `sessions` where `name` = '{}' order by `last_interact_timestamp` desc limit {} offset {}
""".format(session_name, capacity, capacity * page)) """.format(session_name, capacity, capacity * page))
results = self.cursor.fetchall() results = self.cursor.fetchall()
@@ -250,6 +263,7 @@ class DatabaseManager:
prompt = result[5] prompt = result[5]
status = result[6] status = result[6]
default_prompt = result[7] default_prompt = result[7]
token_counts = result[8]
sessions.append({ sessions.append({
'subject_type': subject_type, 'subject_type': subject_type,
@@ -257,7 +271,8 @@ class DatabaseManager:
'create_timestamp': create_timestamp, 'create_timestamp': create_timestamp,
'last_interact_timestamp': last_interact_timestamp, 'last_interact_timestamp': last_interact_timestamp,
'prompt': prompt, 'prompt': prompt,
'default_prompt': default_prompt 'default_prompt': default_prompt,
'token_counts': token_counts
}) })
return sessions return sessions
+6 -2
View File
@@ -34,7 +34,7 @@ class OpenAIInteract:
pkg.utils.context.set_openai_manager(self) pkg.utils.context.set_openai_manager(self)
# 请求OpenAI Completion # 请求OpenAI Completion
def request_completion(self, prompts) -> str: def request_completion(self, prompts) -> tuple[str, int]:
"""请求补全接口回复 """请求补全接口回复
Parameters: Parameters:
@@ -60,14 +60,18 @@ class OpenAIInteract:
logging.debug("OpenAI response: %s", response) logging.debug("OpenAI response: %s", response)
# 记录使用量
current_round_token = 0
if 'model' in config.completion_api_params: if 'model' in config.completion_api_params:
self.audit_mgr.report_text_model_usage(config.completion_api_params['model'], self.audit_mgr.report_text_model_usage(config.completion_api_params['model'],
ai.get_total_tokens()) ai.get_total_tokens())
current_round_token = ai.get_total_tokens()
elif 'engine' in config.completion_api_params: elif 'engine' in config.completion_api_params:
self.audit_mgr.report_text_model_usage(config.completion_api_params['engine'], self.audit_mgr.report_text_model_usage(config.completion_api_params['engine'],
response['usage']['total_tokens']) response['usage']['total_tokens'])
current_round_token = response['usage']['total_tokens']
return ai.get_message() return ai.get_message(), current_round_token
def request_image(self, prompt) -> dict: def request_image(self, prompt) -> dict:
"""请求图片接口回复 """请求图片接口回复
+70 -24
View File
@@ -72,6 +72,7 @@ def load_sessions():
temp_session.last_interact_timestamp = session_data[session_name]['last_interact_timestamp'] temp_session.last_interact_timestamp = session_data[session_name]['last_interact_timestamp']
try: try:
temp_session.prompt = json.loads(session_data[session_name]['prompt']) temp_session.prompt = json.loads(session_data[session_name]['prompt'])
temp_session.token_counts = json.loads(session_data[session_name]['token_counts'])
except Exception: except Exception:
temp_session.prompt = reset_session_prompt(session_name, session_data[session_name]['prompt']) temp_session.prompt = reset_session_prompt(session_name, session_data[session_name]['prompt'])
temp_session.persistence() temp_session.persistence()
@@ -106,6 +107,9 @@ class Session:
prompt = [] prompt = []
"""使用list来保存会话中的回合""" """使用list来保存会话中的回合"""
token_counts = []
"""每个回合的token数量"""
default_prompt = [] default_prompt = []
"""本session的默认prompt""" """本session的默认prompt"""
@@ -146,6 +150,8 @@ class Session:
self.name = name self.name = name
self.create_timestamp = int(time.time()) self.create_timestamp = int(time.time())
self.last_interact_timestamp = int(time.time()) self.last_interact_timestamp = int(time.time())
self.prompt = []
self.token_counts = []
self.schedule() self.schedule()
self.response_lock = threading.Lock() self.response_lock = threading.Lock()
@@ -209,9 +215,16 @@ class Session:
config = pkg.utils.context.get_config() config = pkg.utils.context.get_config()
max_length = config.prompt_submit_length if hasattr(config, "prompt_submit_length") else 1024 max_length = config.prompt_submit_length if hasattr(config, "prompt_submit_length") else 1024
prompts, counts = self.cut_out(text, max_length)
# 计算请求前的prompt数量
total_token_before_query = 0
for token_count in counts:
total_token_before_query += token_count
# 向API请求补全 # 向API请求补全
message = pkg.utils.context.get_openai_manager().request_completion( message, total_token = pkg.utils.context.get_openai_manager().request_completion(
self.cut_out(text, max_length), prompts,
) )
# 成功获取,处理回复 # 成功获取,处理回复
@@ -228,6 +241,10 @@ class Session:
self.prompt.append({'role': 'user', 'content': text}) self.prompt.append({'role': 'user', 'content': text})
self.prompt.append({'role': 'assistant', 'content': res_ans}) self.prompt.append({'role': 'assistant', 'content': res_ans})
# 向token_counts中添加本回合的token数量
self.token_counts.append(total_token-total_token_before_query)
logging.debug("本回合使用token: {}, session counts: {}".format(total_token-total_token_before_query, self.token_counts))
if self.just_switched_to_exist_session: if self.just_switched_to_exist_session:
self.just_switched_to_exist_session = False self.just_switched_to_exist_session = False
self.set_ongoing() self.set_ongoing()
@@ -244,39 +261,65 @@ class Session:
question = self.prompt[-2]['content'] question = self.prompt[-2]['content']
self.prompt = self.prompt[:-2] self.prompt = self.prompt[:-2]
self.token_counts = self.token_counts[:-1]
# 返回上一回合的问题 # 返回上一回合的问题
return question return question
# 构建对话体 # 构建对话体
def cut_out(self, msg: str, max_tokens: int) -> list: def cut_out(self, msg: str, max_tokens: int) -> tuple[list, list]:
"""将现有prompt进行切割处理,使得新的prompt长度不超过max_tokens""" """将现有prompt进行切割处理,使得新的prompt长度不超过max_tokens
# 如果用户消息长度超过max_tokens,直接返回
temp_prompt: list = [] :return: (新的prompt, 新的token_counts)
temp_prompt += self.default_prompt """
temp_prompt.append(
# 最终由三个部分组成
# - default_prompt 情景预设固定值
# - changable_prompts 可变部分, 此会话中的历史对话回合
# - current_question 当前问题
# 包装目前的对话回合内容
changable_prompts = []
changable_counts = []
# 倒着来, 遍历prompt的步长为2, 遍历tokens_counts的步长为1
changable_index = len(self.prompt) - 1
token_count_index = len(self.token_counts) - 1
packed_tokens = 0
print(self.prompt)
while changable_index >= 0 and token_count_index >= 0:
if packed_tokens + self.token_counts[token_count_index] > max_tokens:
break
changable_prompts.insert(0, self.prompt[changable_index])
changable_prompts.insert(0, self.prompt[changable_index - 1])
changable_counts.insert(0, self.token_counts[token_count_index])
packed_tokens += self.token_counts[token_count_index]
changable_index -= 2
token_count_index -= 1
# 将default_prompt和changable_prompts合并
result_prompt = self.default_prompt + changable_prompts
print(changable_prompts)
# 添加当前问题
result_prompt.append(
{ {
'role': 'user', 'role': 'user',
'content': msg 'content': msg
} }
) )
token_count = 0 logging.debug('cut_out: {}\nchangable section tokens: {}\npacked counts: {}\nsession counts: {}'.format(json.dumps(result_prompt, ensure_ascii=False, indent=4),
for item in temp_prompt: packed_tokens,
token_count += len(item['content']) changable_counts,
self.token_counts))
# 倒序遍历prompt return result_prompt, changable_counts
for i in range(len(self.prompt) - 1, -1, -1):
if token_count >= max_tokens:
break
# 将prompt加到temp_prompt倒数第二个位置
temp_prompt.insert(len(self.default_prompt), self.prompt[i])
token_count += len(self.prompt[i]['content'])
logging.debug('cut_out: {}'.format(json.dumps(temp_prompt, ensure_ascii=False, indent=4)))
return temp_prompt
# 持久化session # 持久化session
def persistence(self): def persistence(self):
@@ -291,7 +334,7 @@ class Session:
subject_number = int(name_spt[1]) subject_number = int(name_spt[1])
db_inst.persistence_session(subject_type, subject_number, self.create_timestamp, self.last_interact_timestamp, db_inst.persistence_session(subject_type, subject_number, self.create_timestamp, self.last_interact_timestamp,
json.dumps(self.prompt), json.dumps(self.default_prompt)) json.dumps(self.prompt), json.dumps(self.default_prompt), json.dumps(self.token_counts))
# 重置session # 重置session
def reset(self, explicit: bool = False, expired: bool = False, schedule_new: bool = True, use_prompt: str = None): def reset(self, explicit: bool = False, expired: bool = False, schedule_new: bool = True, use_prompt: str = None):
@@ -314,6 +357,7 @@ class Session:
self.default_prompt = self.get_default_prompt(use_prompt) self.default_prompt = self.get_default_prompt(use_prompt)
self.prompt = [] self.prompt = []
self.token_counts = []
self.create_timestamp = int(time.time()) self.create_timestamp = int(time.time())
self.last_interact_timestamp = int(time.time()) self.last_interact_timestamp = int(time.time())
self.just_switched_to_exist_session = False self.just_switched_to_exist_session = False
@@ -339,6 +383,7 @@ class Session:
self.last_interact_timestamp = last_one['last_interact_timestamp'] self.last_interact_timestamp = last_one['last_interact_timestamp']
try: try:
self.prompt = json.loads(last_one['prompt']) self.prompt = json.loads(last_one['prompt'])
self.token_counts = json.loads(last_one['token_counts'])
except json.decoder.JSONDecodeError: except json.decoder.JSONDecodeError:
self.prompt = reset_session_prompt(self.name, last_one['prompt']) self.prompt = reset_session_prompt(self.name, last_one['prompt'])
self.persistence() self.persistence()
@@ -359,6 +404,7 @@ class Session:
self.last_interact_timestamp = next_one['last_interact_timestamp'] self.last_interact_timestamp = next_one['last_interact_timestamp']
try: try:
self.prompt = json.loads(next_one['prompt']) self.prompt = json.loads(next_one['prompt'])
self.token_counts = json.loads(next_one['token_counts'])
except json.decoder.JSONDecodeError: except json.decoder.JSONDecodeError:
self.prompt = reset_session_prompt(self.name, next_one['prompt']) self.prompt = reset_session_prompt(self.name, next_one['prompt'])
self.persistence() self.persistence()
View File
View File