From 18ae2299a738ec66c14b0ab51312581afb497a94 Mon Sep 17 00:00:00 2001 From: RockChinQ <1010553892@qq.com> Date: Sat, 18 May 2024 20:08:48 +0800 Subject: [PATCH] =?UTF-8?q?fix:=20=E4=BF=AE=E5=A4=8D=20query.resp=5Fmessag?= =?UTF-8?q?es=20=E5=AF=B9=E6=8F=92=E4=BB=B6reply=E7=9A=84=E5=85=BC?= =?UTF-8?q?=E5=AE=B9=E6=80=A7?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- pkg/core/entities.py | 2 +- pkg/pipeline/cntfilter/cntfilter.py | 4 +- pkg/pipeline/process/handlers/chat.py | 7 +- pkg/pipeline/process/handlers/command.py | 7 +- pkg/pipeline/wrapper/wrapper.py | 137 ++++++++++++----------- pkg/provider/entities.py | 11 +- 6 files changed, 89 insertions(+), 79 deletions(-) diff --git a/pkg/core/entities.py b/pkg/core/entities.py index 30b983ad..7f7088c4 100644 --- a/pkg/core/entities.py +++ b/pkg/core/entities.py @@ -67,7 +67,7 @@ class Query(pydantic.BaseModel): use_funcs: typing.Optional[list[tools_entities.LLMFunction]] = None """使用的函数,由前置处理器阶段设置""" - resp_messages: typing.Optional[list[llm_entities.Message]] = [] + resp_messages: typing.Optional[list[llm_entities.Message]] | typing.Optional[list[mirai.MessageChain]] = [] """由Process阶段生成的回复消息对象列表""" resp_message_chain: typing.Optional[list[mirai.MessageChain]] = None diff --git a/pkg/pipeline/cntfilter/cntfilter.py b/pkg/pipeline/cntfilter/cntfilter.py index a669e310..a982a550 100644 --- a/pkg/pipeline/cntfilter/cntfilter.py +++ b/pkg/pipeline/cntfilter/cntfilter.py @@ -163,13 +163,13 @@ class ContentFilterStage(stage.PipelineStage): ) elif stage_inst_name == 'PostContentFilterStage': # 仅处理 query.resp_messages[-1].content 是 str 的情况 - if isinstance(query.resp_messages[-1].content, str): + if isinstance(query.resp_messages[-1], llm_entities.Message) and isinstance(query.resp_messages[-1].content, str): return await self._post_process( query.resp_messages[-1].content, query ) else: - self.ap.logger.debug(f"resp_messages[-1] 不是 str 类型,跳过内容过滤器检查。") + self.ap.logger.debug(f"resp_messages[-1] 不是 Message 类型或 query.resp_messages[-1].content 不是 str 类型,跳过内容过滤器检查。") return entities.StageProcessResult( result_type=entities.ResultType.CONTINUE, new_query=query diff --git a/pkg/pipeline/process/handlers/chat.py b/pkg/pipeline/process/handlers/chat.py index 26f73b61..deadb44e 100644 --- a/pkg/pipeline/process/handlers/chat.py +++ b/pkg/pipeline/process/handlers/chat.py @@ -42,12 +42,7 @@ class ChatMessageHandler(handler.MessageHandler): if event_ctx.event.reply is not None: mc = mirai.MessageChain(event_ctx.event.reply) - query.resp_messages.append( - llm_entities.Message( - role='plugin', - content=mc, - ) - ) + query.resp_messages.append(mc) yield entities.StageProcessResult( result_type=entities.ResultType.CONTINUE, diff --git a/pkg/pipeline/process/handlers/command.py b/pkg/pipeline/process/handlers/command.py index 02ff269a..8c8fb8ba 100644 --- a/pkg/pipeline/process/handlers/command.py +++ b/pkg/pipeline/process/handlers/command.py @@ -48,12 +48,7 @@ class CommandHandler(handler.MessageHandler): if event_ctx.event.reply is not None: mc = mirai.MessageChain(event_ctx.event.reply) - query.resp_messages.append( - llm_entities.Message( - role='command', - content=str(mc), - ) - ) + query.resp_messages.append(mc) yield entities.StageProcessResult( result_type=entities.ResultType.CONTINUE, diff --git a/pkg/pipeline/wrapper/wrapper.py b/pkg/pipeline/wrapper/wrapper.py index acf0549d..3e52beb3 100644 --- a/pkg/pipeline/wrapper/wrapper.py +++ b/pkg/pipeline/wrapper/wrapper.py @@ -32,80 +32,49 @@ class ResponseWrapper(stage.PipelineStage): ) -> typing.AsyncGenerator[entities.StageProcessResult, None]: """处理 """ - - if query.resp_messages[-1].role == 'command': - # query.resp_message_chain.append(mirai.MessageChain("[bot] "+query.resp_messages[-1].content)) - query.resp_message_chain.append(query.resp_messages[-1].get_content_mirai_message_chain(prefix_text='[bot] ')) + + # 如果 resp_messages[-1] 已经是 MessageChain 了 + if isinstance(query.resp_messages[-1], mirai.MessageChain): + query.resp_message_chain.append(query.resp_messages[-1]) yield entities.StageProcessResult( result_type=entities.ResultType.CONTINUE, new_query=query ) - elif query.resp_messages[-1].role == 'plugin': - # if not isinstance(query.resp_messages[-1].content, mirai.MessageChain): - # query.resp_message_chain.append(mirai.MessageChain(query.resp_messages[-1].content)) - # else: - # query.resp_message_chain.append(query.resp_messages[-1].content) - query.resp_message_chain.append(query.resp_messages[-1].get_content_mirai_message_chain()) - yield entities.StageProcessResult( - result_type=entities.ResultType.CONTINUE, - new_query=query - ) else: + + if query.resp_messages[-1].role == 'command': + # query.resp_message_chain.append(mirai.MessageChain("[bot] "+query.resp_messages[-1].content)) + query.resp_message_chain.append(query.resp_messages[-1].get_content_mirai_message_chain(prefix_text='[bot] ')) - if query.resp_messages[-1].role == 'assistant': - result = query.resp_messages[-1] - session = await self.ap.sess_mgr.get_session(query) + yield entities.StageProcessResult( + result_type=entities.ResultType.CONTINUE, + new_query=query + ) + elif query.resp_messages[-1].role == 'plugin': + # if not isinstance(query.resp_messages[-1].content, mirai.MessageChain): + # query.resp_message_chain.append(mirai.MessageChain(query.resp_messages[-1].content)) + # else: + # query.resp_message_chain.append(query.resp_messages[-1].content) + query.resp_message_chain.append(query.resp_messages[-1].get_content_mirai_message_chain()) - reply_text = '' + yield entities.StageProcessResult( + result_type=entities.ResultType.CONTINUE, + new_query=query + ) + else: - if result.content is not None: # 有内容 - reply_text = str(result.get_content_mirai_message_chain()) + if query.resp_messages[-1].role == 'assistant': + result = query.resp_messages[-1] + session = await self.ap.sess_mgr.get_session(query) - # ============= 触发插件事件 =============== - event_ctx = await self.ap.plugin_mgr.emit_event( - event=events.NormalMessageResponded( - launcher_type=query.launcher_type.value, - launcher_id=query.launcher_id, - sender_id=query.sender_id, - session=session, - prefix='', - response_text=reply_text, - finish_reason='stop', - funcs_called=[fc.function.name for fc in result.tool_calls] if result.tool_calls is not None else [], - query=query - ) - ) - if event_ctx.is_prevented_default(): - yield entities.StageProcessResult( - result_type=entities.ResultType.INTERRUPT, - new_query=query - ) - else: - if event_ctx.event.reply is not None: - - query.resp_message_chain.append(mirai.MessageChain(event_ctx.event.reply)) + reply_text = '' - else: + if result.content is not None: # 有内容 + reply_text = str(result.get_content_mirai_message_chain()) - query.resp_message_chain.append(result.get_content_mirai_message_chain()) - - yield entities.StageProcessResult( - result_type=entities.ResultType.CONTINUE, - new_query=query - ) - - if result.tool_calls is not None: # 有函数调用 - - function_names = [tc.function.name for tc in result.tool_calls] - - reply_text = f'调用函数 {".".join(function_names)}...' - - query.resp_message_chain.append(mirai.MessageChain([mirai.Plain(reply_text)])) - - if self.ap.platform_cfg.data['track-function-calls']: - + # ============= 触发插件事件 =============== event_ctx = await self.ap.plugin_mgr.emit_event( event=events.NormalMessageResponded( launcher_type=query.launcher_type.value, @@ -119,7 +88,6 @@ class ResponseWrapper(stage.PipelineStage): query=query ) ) - if event_ctx.is_prevented_default(): yield entities.StageProcessResult( result_type=entities.ResultType.INTERRUPT, @@ -132,9 +100,52 @@ class ResponseWrapper(stage.PipelineStage): else: - query.resp_message_chain.append(mirai.MessageChain([mirai.Plain(reply_text)])) + query.resp_message_chain.append(result.get_content_mirai_message_chain()) yield entities.StageProcessResult( result_type=entities.ResultType.CONTINUE, new_query=query ) + + if result.tool_calls is not None: # 有函数调用 + + function_names = [tc.function.name for tc in result.tool_calls] + + reply_text = f'调用函数 {".".join(function_names)}...' + + query.resp_message_chain.append(mirai.MessageChain([mirai.Plain(reply_text)])) + + if self.ap.platform_cfg.data['track-function-calls']: + + event_ctx = await self.ap.plugin_mgr.emit_event( + event=events.NormalMessageResponded( + launcher_type=query.launcher_type.value, + launcher_id=query.launcher_id, + sender_id=query.sender_id, + session=session, + prefix='', + response_text=reply_text, + finish_reason='stop', + funcs_called=[fc.function.name for fc in result.tool_calls] if result.tool_calls is not None else [], + query=query + ) + ) + + if event_ctx.is_prevented_default(): + yield entities.StageProcessResult( + result_type=entities.ResultType.INTERRUPT, + new_query=query + ) + else: + if event_ctx.event.reply is not None: + + query.resp_message_chain.append(mirai.MessageChain(event_ctx.event.reply)) + + else: + + query.resp_message_chain.append(mirai.MessageChain([mirai.Plain(reply_text)])) + + yield entities.StageProcessResult( + result_type=entities.ResultType.CONTINUE, + new_query=query + ) diff --git a/pkg/provider/entities.py b/pkg/provider/entities.py index 3b87f5cb..82c68adb 100644 --- a/pkg/provider/entities.py +++ b/pkg/provider/entities.py @@ -96,7 +96,16 @@ class Message(pydantic.BaseModel): if ce.type == 'text': mc.append(mirai.Plain(ce.text)) elif ce.type == 'image': - mc.append(mirai.Image(url=ce.image_url)) + if ce.image_url.url.startswith("http"): + mc.append(mirai.Image(url=ce.image_url.url)) + else: # base64 + + b64_str = ce.image_url.url + + if b64_str.startswith("data:"): + b64_str = b64_str.split(",")[1] + + mc.append(mirai.Image(base64=b64_str)) # 找第一个文字组件 if prefix_text: