feat: chat前的前文剪裁逻辑

This commit is contained in:
RockChinQ
2024-02-01 17:42:51 +08:00
parent 6f2d7d96d0
commit a9d92115f8
4 changed files with 33 additions and 14 deletions

View File

@@ -149,7 +149,8 @@ class Controller:
await self._execute_from_stage(0, query)
except Exception as e:
self.ap.logger.error(f"处理请求时出错 {query}: {e}")
self.ap.logger.debug(f"处理请求时出错 {query}: {e}", exc_info=True)
# self.ap.logger.debug(f"处理请求时出错 {query}: {e}", exc_info=True)
traceback.print_exc()
finally:
self.ap.logger.debug(f"Query {query} processed")

View File

@@ -52,8 +52,28 @@ class PreProcessor(stage.PipelineStage):
query.messages = event_ctx.event.prompt
# 根据模型max_tokens剪裁
max_tokens = min(query.use_model.max_tokens, self.ap.cfg_mgr.data['prompt_submit_length'])
test_messages = query.prompt.messages + query.messages + [query.user_message]
while await query.use_model.tokenizer.count_token(test_messages, query.use_model) > max_tokens:
# 前文都pop完了还是大于max_tokens由于prompt和user_messages不能删减报错
if len(query.prompt.messages) == 0:
return entities.StageProcessResult(
result_type=entities.ResultType.INTERRUPT,
new_query=query,
user_notice='输入内容过长,请减少情景预设或者输入内容长度',
console_notice='输入内容过长,请减少情景预设或者输入内容长度,或者增大配置文件中的 prompt_submit_length 项但不能超过所用模型最大tokens数'
)
query.messages.pop(0) # pop第一个肯定是role=user的
# 继续pop到第二个role=user前一个
while len(query.messages) > 0 and query.messages[0].role != 'user':
query.messages.pop(0)
test_messages = query.prompt.messages + query.messages + [query.user_message]
return entities.StageProcessResult(
result_type=entities.ResultType.CONTINUE,
new_query=query
)
)

View File

@@ -82,14 +82,12 @@ class ChatMessageHandler(handler.MessageHandler):
query.session.using_conversation.messages.append(query.user_message)
query.session.using_conversation.messages.extend(query.resp_messages)
print(query.session.using_conversation.messages)
await self.ap.ctr_mgr.usage.post_query_record(
session_type=query.session.launcher_type.value,
session_id=str(query.session.launcher_id),
query_ability_provider="QChatGPT.Chat",
usage=text_length,
model_name=query.use_model.name,
response_seconds=int(time.time() - start_time),
retry_times=-1,
)
await self.ap.ctr_mgr.usage.post_query_record(
session_type=query.session.launcher_type.value,
session_id=str(query.session.launcher_id),
query_ability_provider="QChatGPT.Chat",
usage=text_length,
model_name=query.use_model.name,
response_seconds=int(time.time() - start_time),
retry_times=-1,
)

View File

@@ -23,6 +23,6 @@ class Tiktoken(tokenizer.LLMTokenizer):
num_tokens = 0
for message in messages:
num_tokens += len(encoding.encode(message.role))
num_tokens += len(encoding.encode(message.content))
num_tokens += len(encoding.encode(message.content if message.content is not None else ''))
num_tokens += 3 # every reply is primed with <|start|>assistant<|message|>
return num_tokens