From a9d92115f84692d207d7dfad05d544526d8dcbfc Mon Sep 17 00:00:00 2001 From: RockChinQ <1010553892@qq.com> Date: Thu, 1 Feb 2024 17:42:51 +0800 Subject: [PATCH] =?UTF-8?q?feat:=20chat=E5=89=8D=E7=9A=84=E5=89=8D?= =?UTF-8?q?=E6=96=87=E5=89=AA=E8=A3=81=E9=80=BB=E8=BE=91?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- pkg/core/controller.py | 3 ++- pkg/pipeline/preproc/preproc.py | 22 ++++++++++++++++++- pkg/pipeline/process/handlers/chat.py | 20 ++++++++--------- pkg/provider/requester/tokenizers/tiktoken.py | 2 +- 4 files changed, 33 insertions(+), 14 deletions(-) diff --git a/pkg/core/controller.py b/pkg/core/controller.py index 4c135208..236dcde2 100644 --- a/pkg/core/controller.py +++ b/pkg/core/controller.py @@ -149,7 +149,8 @@ class Controller: await self._execute_from_stage(0, query) except Exception as e: self.ap.logger.error(f"处理请求时出错 {query}: {e}") - self.ap.logger.debug(f"处理请求时出错 {query}: {e}", exc_info=True) + # self.ap.logger.debug(f"处理请求时出错 {query}: {e}", exc_info=True) + traceback.print_exc() finally: self.ap.logger.debug(f"Query {query} processed") diff --git a/pkg/pipeline/preproc/preproc.py b/pkg/pipeline/preproc/preproc.py index 578c7433..1df5daf9 100644 --- a/pkg/pipeline/preproc/preproc.py +++ b/pkg/pipeline/preproc/preproc.py @@ -52,8 +52,28 @@ class PreProcessor(stage.PipelineStage): query.messages = event_ctx.event.prompt # 根据模型max_tokens剪裁 + max_tokens = min(query.use_model.max_tokens, self.ap.cfg_mgr.data['prompt_submit_length']) + + test_messages = query.prompt.messages + query.messages + [query.user_message] + + while await query.use_model.tokenizer.count_token(test_messages, query.use_model) > max_tokens: + # 前文都pop完了,还是大于max_tokens,由于prompt和user_messages不能删减,报错 + if len(query.prompt.messages) == 0: + return entities.StageProcessResult( + result_type=entities.ResultType.INTERRUPT, + new_query=query, + user_notice='输入内容过长,请减少情景预设或者输入内容长度', + console_notice='输入内容过长,请减少情景预设或者输入内容长度,或者增大配置文件中的 prompt_submit_length 项(但不能超过所用模型最大tokens数)' + ) + + query.messages.pop(0) # pop第一个肯定是role=user的 + # 继续pop到第二个role=user前一个 + while len(query.messages) > 0 and query.messages[0].role != 'user': + query.messages.pop(0) + + test_messages = query.prompt.messages + query.messages + [query.user_message] return entities.StageProcessResult( result_type=entities.ResultType.CONTINUE, new_query=query - ) \ No newline at end of file + ) diff --git a/pkg/pipeline/process/handlers/chat.py b/pkg/pipeline/process/handlers/chat.py index b57de98f..3e7673b1 100644 --- a/pkg/pipeline/process/handlers/chat.py +++ b/pkg/pipeline/process/handlers/chat.py @@ -82,14 +82,12 @@ class ChatMessageHandler(handler.MessageHandler): query.session.using_conversation.messages.append(query.user_message) query.session.using_conversation.messages.extend(query.resp_messages) - print(query.session.using_conversation.messages) - - await self.ap.ctr_mgr.usage.post_query_record( - session_type=query.session.launcher_type.value, - session_id=str(query.session.launcher_id), - query_ability_provider="QChatGPT.Chat", - usage=text_length, - model_name=query.use_model.name, - response_seconds=int(time.time() - start_time), - retry_times=-1, - ) \ No newline at end of file + await self.ap.ctr_mgr.usage.post_query_record( + session_type=query.session.launcher_type.value, + session_id=str(query.session.launcher_id), + query_ability_provider="QChatGPT.Chat", + usage=text_length, + model_name=query.use_model.name, + response_seconds=int(time.time() - start_time), + retry_times=-1, + ) \ No newline at end of file diff --git a/pkg/provider/requester/tokenizers/tiktoken.py b/pkg/provider/requester/tokenizers/tiktoken.py index 3b83a144..14a456c0 100644 --- a/pkg/provider/requester/tokenizers/tiktoken.py +++ b/pkg/provider/requester/tokenizers/tiktoken.py @@ -23,6 +23,6 @@ class Tiktoken(tokenizer.LLMTokenizer): num_tokens = 0 for message in messages: num_tokens += len(encoding.encode(message.role)) - num_tokens += len(encoding.encode(message.content)) + num_tokens += len(encoding.encode(message.content if message.content is not None else '')) num_tokens += 3 # every reply is primed with <|start|>assistant<|message|> return num_tokens