fix: 修复 query.resp_messages 对插件reply的兼容性

This commit is contained in:
RockChinQ
2024-05-18 20:08:48 +08:00
parent 7463e0aab9
commit 18ae2299a7
6 changed files with 89 additions and 79 deletions
+74 -63
View File
@@ -32,80 +32,49 @@ class ResponseWrapper(stage.PipelineStage):
) -> typing.AsyncGenerator[entities.StageProcessResult, None]:
"""处理
"""
if query.resp_messages[-1].role == 'command':
# query.resp_message_chain.append(mirai.MessageChain("[bot] "+query.resp_messages[-1].content))
query.resp_message_chain.append(query.resp_messages[-1].get_content_mirai_message_chain(prefix_text='[bot] '))
# 如果 resp_messages[-1] 已经是 MessageChain 了
if isinstance(query.resp_messages[-1], mirai.MessageChain):
query.resp_message_chain.append(query.resp_messages[-1])
yield entities.StageProcessResult(
result_type=entities.ResultType.CONTINUE,
new_query=query
)
elif query.resp_messages[-1].role == 'plugin':
# if not isinstance(query.resp_messages[-1].content, mirai.MessageChain):
# query.resp_message_chain.append(mirai.MessageChain(query.resp_messages[-1].content))
# else:
# query.resp_message_chain.append(query.resp_messages[-1].content)
query.resp_message_chain.append(query.resp_messages[-1].get_content_mirai_message_chain())
yield entities.StageProcessResult(
result_type=entities.ResultType.CONTINUE,
new_query=query
)
else:
if query.resp_messages[-1].role == 'command':
# query.resp_message_chain.append(mirai.MessageChain("[bot] "+query.resp_messages[-1].content))
query.resp_message_chain.append(query.resp_messages[-1].get_content_mirai_message_chain(prefix_text='[bot] '))
if query.resp_messages[-1].role == 'assistant':
result = query.resp_messages[-1]
session = await self.ap.sess_mgr.get_session(query)
yield entities.StageProcessResult(
result_type=entities.ResultType.CONTINUE,
new_query=query
)
elif query.resp_messages[-1].role == 'plugin':
# if not isinstance(query.resp_messages[-1].content, mirai.MessageChain):
# query.resp_message_chain.append(mirai.MessageChain(query.resp_messages[-1].content))
# else:
# query.resp_message_chain.append(query.resp_messages[-1].content)
query.resp_message_chain.append(query.resp_messages[-1].get_content_mirai_message_chain())
reply_text = ''
yield entities.StageProcessResult(
result_type=entities.ResultType.CONTINUE,
new_query=query
)
else:
if result.content is not None: # 有内容
reply_text = str(result.get_content_mirai_message_chain())
if query.resp_messages[-1].role == 'assistant':
result = query.resp_messages[-1]
session = await self.ap.sess_mgr.get_session(query)
# ============= 触发插件事件 ===============
event_ctx = await self.ap.plugin_mgr.emit_event(
event=events.NormalMessageResponded(
launcher_type=query.launcher_type.value,
launcher_id=query.launcher_id,
sender_id=query.sender_id,
session=session,
prefix='',
response_text=reply_text,
finish_reason='stop',
funcs_called=[fc.function.name for fc in result.tool_calls] if result.tool_calls is not None else [],
query=query
)
)
if event_ctx.is_prevented_default():
yield entities.StageProcessResult(
result_type=entities.ResultType.INTERRUPT,
new_query=query
)
else:
if event_ctx.event.reply is not None:
query.resp_message_chain.append(mirai.MessageChain(event_ctx.event.reply))
reply_text = ''
else:
if result.content is not None: # 有内容
reply_text = str(result.get_content_mirai_message_chain())
query.resp_message_chain.append(result.get_content_mirai_message_chain())
yield entities.StageProcessResult(
result_type=entities.ResultType.CONTINUE,
new_query=query
)
if result.tool_calls is not None: # 有函数调用
function_names = [tc.function.name for tc in result.tool_calls]
reply_text = f'调用函数 {".".join(function_names)}...'
query.resp_message_chain.append(mirai.MessageChain([mirai.Plain(reply_text)]))
if self.ap.platform_cfg.data['track-function-calls']:
# ============= 触发插件事件 ===============
event_ctx = await self.ap.plugin_mgr.emit_event(
event=events.NormalMessageResponded(
launcher_type=query.launcher_type.value,
@@ -119,7 +88,6 @@ class ResponseWrapper(stage.PipelineStage):
query=query
)
)
if event_ctx.is_prevented_default():
yield entities.StageProcessResult(
result_type=entities.ResultType.INTERRUPT,
@@ -132,9 +100,52 @@ class ResponseWrapper(stage.PipelineStage):
else:
query.resp_message_chain.append(mirai.MessageChain([mirai.Plain(reply_text)]))
query.resp_message_chain.append(result.get_content_mirai_message_chain())
yield entities.StageProcessResult(
result_type=entities.ResultType.CONTINUE,
new_query=query
)
if result.tool_calls is not None: # 有函数调用
function_names = [tc.function.name for tc in result.tool_calls]
reply_text = f'调用函数 {".".join(function_names)}...'
query.resp_message_chain.append(mirai.MessageChain([mirai.Plain(reply_text)]))
if self.ap.platform_cfg.data['track-function-calls']:
event_ctx = await self.ap.plugin_mgr.emit_event(
event=events.NormalMessageResponded(
launcher_type=query.launcher_type.value,
launcher_id=query.launcher_id,
sender_id=query.sender_id,
session=session,
prefix='',
response_text=reply_text,
finish_reason='stop',
funcs_called=[fc.function.name for fc in result.tool_calls] if result.tool_calls is not None else [],
query=query
)
)
if event_ctx.is_prevented_default():
yield entities.StageProcessResult(
result_type=entities.ResultType.INTERRUPT,
new_query=query
)
else:
if event_ctx.event.reply is not None:
query.resp_message_chain.append(mirai.MessageChain(event_ctx.event.reply))
else:
query.resp_message_chain.append(mirai.MessageChain([mirai.Plain(reply_text)]))
yield entities.StageProcessResult(
result_type=entities.ResultType.CONTINUE,
new_query=query
)