diff --git a/pkg/core/controller.py b/pkg/core/controller.py index 40f496a5..4c135208 100644 --- a/pkg/core/controller.py +++ b/pkg/core/controller.py @@ -66,11 +66,12 @@ class Controller: self.ap.logger.error(f"控制器循环出错: {e}") self.ap.logger.debug(f"Traceback: {traceback.format_exc()}") - async def _check_output(self, result: pipeline_entities.StageProcessResult): + async def _check_output(self, query: entities.Query, result: pipeline_entities.StageProcessResult): """检查输出 """ if result.user_notice: await self.ap.im_mgr.send( + query.message_event, result.user_notice ) if result.debug_notice: @@ -108,12 +109,14 @@ class Controller: while i < len(self.ap.stage_mgr.stage_containers): stage_container = self.ap.stage_mgr.stage_containers[i] - result = await stage_container.inst.process(query, stage_container.inst_name) + result = stage_container.inst.process(query, stage_container.inst_name) + if isinstance(result, typing.Coroutine): + result = await result if isinstance(result, pipeline_entities.StageProcessResult): # 直接返回结果 self.ap.logger.debug(f"Stage {stage_container.inst_name} processed query {query} res {result}") - await self._check_output(result) + await self._check_output(query, result) if result.result_type == pipeline_entities.ResultType.INTERRUPT: self.ap.logger.debug(f"Stage {stage_container.inst_name} interrupted query {query}") @@ -125,7 +128,7 @@ class Controller: async for sub_result in result: self.ap.logger.debug(f"Stage {stage_container.inst_name} processed query {query} res {sub_result}") - await self._check_output(sub_result) + await self._check_output(query, sub_result) if sub_result.result_type == pipeline_entities.ResultType.INTERRUPT: self.ap.logger.debug(f"Stage {stage_container.inst_name} interrupted query {query}") diff --git a/pkg/core/entities.py b/pkg/core/entities.py index b7dccf30..8e25750a 100644 --- a/pkg/core/entities.py +++ b/pkg/core/entities.py @@ -27,25 +27,30 @@ class Query(pydantic.BaseModel): """一次请求的信息封装""" query_id: int - """请求ID""" + """请求ID,添加进请求池时生成""" launcher_type: LauncherTypes - """会话类型""" + """会话类型,platform设置""" launcher_id: int - """会话ID""" + """会话ID,platform设置""" sender_id: int - """发送者ID""" + """发送者ID,platform设置""" message_event: mirai.MessageEvent - """事件""" + """事件,platform收到的事件""" message_chain: mirai.MessageChain - """消息链""" + """消息链,platform收到的消息链""" + + session: typing.Optional[Session] = None + + resp_messages: typing.Optional[list[llm_entities.Message]] = [] + """由provider生成的回复消息对象列表""" resp_message_chain: typing.Optional[mirai.MessageChain] = None - """回复消息链""" + """回复消息链,从resp_messages包装而得""" class Conversation(pydantic.BaseModel): diff --git a/pkg/core/pool.py b/pkg/core/pool.py index 3d949292..a5a26423 100644 --- a/pkg/core/pool.py +++ b/pkg/core/pool.py @@ -38,7 +38,9 @@ class QueryPool: launcher_id=launcher_id, sender_id=sender_id, message_event=message_event, - message_chain=message_chain + message_chain=message_chain, + resp_messages=[], + resp_message_chain=None ) self.queries.append(query) self.query_id_counter += 1 diff --git a/pkg/pipeline/cntfilter/cntfilter.py b/pkg/pipeline/cntfilter/cntfilter.py index 0025b00a..78412458 100644 --- a/pkg/pipeline/cntfilter/cntfilter.py +++ b/pkg/pipeline/cntfilter/cntfilter.py @@ -81,31 +81,36 @@ class ContentFilterStage(stage.PipelineStage): """请求llm后处理响应 只要是 PASS 或者 MASKED 的就通过此 filter,将其 replacement 设置为message,进入下一个 filter """ - for filter in self.filter_chain: - if filter_entities.EnableStage.POST in filter.enable_stages: - result = await filter.process(message) + if message is None: + return entities.StageProcessResult( + result_type=entities.ResultType.CONTINUE, + new_query=query + ) + else: + message = message.strip() + for filter in self.filter_chain: + if filter_entities.EnableStage.POST in filter.enable_stages: + result = await filter.process(message) - if result.level == filter_entities.ResultLevel.BLOCK: - return entities.StageProcessResult( - result_type=entities.ResultType.INTERRUPT, - new_query=query, - user_notice=result.user_notice, - console_notice=result.console_notice - ) - elif result.level in [ - filter_entities.ResultLevel.PASS, - filter_entities.ResultLevel.MASKED - ]: - message = result.replacement + if result.level == filter_entities.ResultLevel.BLOCK: + return entities.StageProcessResult( + result_type=entities.ResultType.INTERRUPT, + new_query=query, + user_notice=result.user_notice, + console_notice=result.console_notice + ) + elif result.level in [ + filter_entities.ResultLevel.PASS, + filter_entities.ResultLevel.MASKED + ]: + message = result.replacement - query.message_chain = mirai.MessageChain( - mirai.Plain(message) - ) + query.resp_messages[-1].content = message - return entities.StageProcessResult( - result_type=entities.ResultType.CONTINUE, - new_query=query - ) + return entities.StageProcessResult( + result_type=entities.ResultType.CONTINUE, + new_query=query + ) async def process( self, @@ -121,7 +126,7 @@ class ContentFilterStage(stage.PipelineStage): ) elif stage_inst_name == 'PostContentFilterStage': return await self._post_process( - str(query.message_chain).strip(), + query.resp_messages[-1].content, query ) else: diff --git a/pkg/pipeline/process/handlers/chat.py b/pkg/pipeline/process/handlers/chat.py index 13b33a6f..603d1f1d 100644 --- a/pkg/pipeline/process/handlers/chat.py +++ b/pkg/pipeline/process/handlers/chat.py @@ -83,104 +83,21 @@ class ChatMessageHandler(handler.MessageHandler): ) ) - called_functions = [] - text_length = 0 start_time = time.time() async for result in conversation.use_model.requester.request(query, conversation): - conversation.messages.append(result) + query.resp_messages.append(result) if result.content is not None: text_length += len(result.content) - # 转换成可读消息 - if result.role == 'assistant': + yield entities.StageProcessResult( + result_type=entities.ResultType.CONTINUE, + new_query=query + ) - reply_text = '' - - if result.content is not None: # 有内容 - reply_text = result.content - - # ============= 触发插件事件 =============== - event_ctx = await self.ap.plugin_mgr.emit_event( - event=events.NormalMessageResponded( - launcher_type=query.launcher_type.value, - launcher_id=query.launcher_id, - sender_id=query.sender_id, - session=session, - prefix='', - response_text=reply_text, - finish_reason='stop', - funcs_called=called_functions, - query=query - ) - ) - if event_ctx.is_prevented_default(): - yield entities.StageProcessResult( - result_type=entities.ResultType.INTERRUPT, - new_query=query - ) - else: - if event_ctx.event.reply is not None: - - query.resp_message_chain = mirai.MessageChain(event_ctx.event.reply) - - else: - - query.resp_message_chain = mirai.MessageChain([mirai.Plain(reply_text)]) - - yield entities.StageProcessResult( - result_type=entities.ResultType.CONTINUE, - new_query=query - ) - - if result.tool_calls is not None: # 有函数调用 - - function_names = [tc.function.name for tc in result.tool_calls] - - reply_text = f'调用函数 {".".join(function_names)}...' - - called_functions.extend(function_names) - - query.resp_message_chain = mirai.MessageChain([mirai.Plain(reply_text)]) - - if self.ap.cfg_mgr.data['trace_function_calls']: - - event_ctx = await self.ap.plugin_mgr.emit_event( - event=events.NormalMessageResponded( - launcher_type=query.launcher_type.value, - launcher_id=query.launcher_id, - sender_id=query.sender_id, - session=session, - prefix='', - response_text=reply_text, - finish_reason='stop', - funcs_called=called_functions, - query=query - ) - ) - - if event_ctx.is_prevented_default(): - yield entities.StageProcessResult( - result_type=entities.ResultType.INTERRUPT, - new_query=query - ) - else: - if event_ctx.event.reply is not None: - - query.resp_message_chain = mirai.MessageChain(event_ctx.event.reply) - - else: - - query.resp_message_chain = mirai.MessageChain([mirai.Plain(reply_text)]) - - yield entities.StageProcessResult( - result_type=entities.ResultType.CONTINUE, - new_query=query - ) - await self.ap.ctr_mgr.usage.post_query_record( session_type=session.launcher_type.value, session_id=str(session.launcher_id), diff --git a/pkg/pipeline/process/handlers/command.py b/pkg/pipeline/process/handlers/command.py index 60543dc3..50600a36 100644 --- a/pkg/pipeline/process/handlers/command.py +++ b/pkg/pipeline/process/handlers/command.py @@ -6,6 +6,7 @@ import mirai from .. import handler from ... import entities from ....core import entities as core_entities +from ....provider import entities as llm_entities from ....plugin import events @@ -44,7 +45,14 @@ class CommandHandler(handler.MessageHandler): if event_ctx.is_prevented_default(): if event_ctx.event.reply is not None: - query.resp_message_chain = mirai.MessageChain(event_ctx.event.reply) + mc = mirai.MessageChain(event_ctx.event.reply) + + query.resp_messages.append( + llm_entities.Message( + role='command', + content=str(mc), + ) + ) yield entities.StageProcessResult( result_type=entities.ResultType.CONTINUE, @@ -73,18 +81,30 @@ class CommandHandler(handler.MessageHandler): session=session ): if ret.error is not None: - query.resp_message_chain = mirai.MessageChain([ - mirai.Plain(str(ret.error)) - ]) + # query.resp_message_chain = mirai.MessageChain([ + # mirai.Plain(str(ret.error)) + # ]) + query.resp_messages.append( + llm_entities.Message( + role='command', + content=str(ret.error), + ) + ) yield entities.StageProcessResult( result_type=entities.ResultType.CONTINUE, new_query=query ) elif ret.text is not None: - query.resp_message_chain = mirai.MessageChain([ - mirai.Plain(ret.text) - ]) + # query.resp_message_chain = mirai.MessageChain([ + # mirai.Plain(ret.text) + # ]) + query.resp_messages.append( + llm_entities.Message( + role='command', + content=ret.text, + ) + ) yield entities.StageProcessResult( result_type=entities.ResultType.CONTINUE, diff --git a/pkg/pipeline/stagemgr.py b/pkg/pipeline/stagemgr.py index 1ff36329..24cb20ff 100644 --- a/pkg/pipeline/stagemgr.py +++ b/pkg/pipeline/stagemgr.py @@ -10,6 +10,7 @@ from .cntfilter import cntfilter from .process import process from .longtext import longtext from .respback import respback +from .wrapper import wrapper stage_order = [ @@ -18,6 +19,7 @@ stage_order = [ "PreContentFilterStage", "MessageProcessor", "PostContentFilterStage", + "ResponseWrapper", "LongTextProcessStage", "SendResponseBackStage", ] diff --git a/pkg/pipeline/wrapper/__init__.py b/pkg/pipeline/wrapper/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/pkg/pipeline/wrapper/wrapper.py b/pkg/pipeline/wrapper/wrapper.py new file mode 100644 index 00000000..e333f006 --- /dev/null +++ b/pkg/pipeline/wrapper/wrapper.py @@ -0,0 +1,119 @@ +from __future__ import annotations + +import typing + +import mirai + +from ...core import app, entities as core_entities +from .. import entities +from .. import stage, entities, stagemgr +from ...core import entities as core_entities +from ...config import manager as cfg_mgr +from ...plugin import events + + +@stage.stage_class("ResponseWrapper") +class ResponseWrapper(stage.PipelineStage): + + async def initialize(self): + pass + + async def process( + self, + query: core_entities.Query, + stage_inst_name: str, + ) -> typing.AsyncGenerator[entities.StageProcessResult, None]: + """处理 + """ + + if query.resp_messages[-1].role == 'command': + query.resp_message_chain = mirai.MessageChain("[bot] "+query.resp_messages[-1].content) + + yield entities.StageProcessResult( + result_type=entities.ResultType.CONTINUE, + new_query=query + ) + elif query.resp_messages[-1].role == 'assistant': + result = query.resp_messages[-1] + session = await self.ap.sess_mgr.get_session(query) + + reply_text = '' + + if result.content is not None: # 有内容 + reply_text = result.content + + # ============= 触发插件事件 =============== + event_ctx = await self.ap.plugin_mgr.emit_event( + event=events.NormalMessageResponded( + launcher_type=query.launcher_type.value, + launcher_id=query.launcher_id, + sender_id=query.sender_id, + session=session, + prefix='', + response_text=reply_text, + finish_reason='stop', + funcs_called=[fc.function.name for fc in result.tool_calls] if result.tool_calls is not None else [], + query=query + ) + ) + if event_ctx.is_prevented_default(): + yield entities.StageProcessResult( + result_type=entities.ResultType.INTERRUPT, + new_query=query + ) + else: + if event_ctx.event.reply is not None: + + query.resp_message_chain = mirai.MessageChain(event_ctx.event.reply) + + else: + + query.resp_message_chain = mirai.MessageChain([mirai.Plain(reply_text)]) + + yield entities.StageProcessResult( + result_type=entities.ResultType.CONTINUE, + new_query=query + ) + + if result.tool_calls is not None: # 有函数调用 + + function_names = [tc.function.name for tc in result.tool_calls] + + reply_text = f'调用函数 {".".join(function_names)}...' + + query.resp_message_chain = mirai.MessageChain([mirai.Plain(reply_text)]) + + if self.ap.cfg_mgr.data['trace_function_calls']: + + event_ctx = await self.ap.plugin_mgr.emit_event( + event=events.NormalMessageResponded( + launcher_type=query.launcher_type.value, + launcher_id=query.launcher_id, + sender_id=query.sender_id, + session=session, + prefix='', + response_text=reply_text, + finish_reason='stop', + funcs_called=[fc.function.name for fc in result.tool_calls] if result.tool_calls is not None else [], + query=query + ) + ) + + if event_ctx.is_prevented_default(): + yield entities.StageProcessResult( + result_type=entities.ResultType.INTERRUPT, + new_query=query + ) + else: + if event_ctx.event.reply is not None: + + query.resp_message_chain = mirai.MessageChain(event_ctx.event.reply) + + else: + + query.resp_message_chain = mirai.MessageChain([mirai.Plain(reply_text)]) + + yield entities.StageProcessResult( + result_type=entities.ResultType.CONTINUE, + new_query=query + ) \ No newline at end of file diff --git a/pkg/platform/manager.py b/pkg/platform/manager.py index 741d3a23..384432f9 100644 --- a/pkg/platform/manager.py +++ b/pkg/platform/manager.py @@ -105,7 +105,7 @@ class PlatformManager: launcher_id=event.sender.id, sender_id=event.sender.id, message_event=event, - message_chain=event.message_chain + message_chain=event.message_chain, ) # nakuru不区分好友和陌生人,故仅为yirimirai注册陌生人事件 diff --git a/pkg/provider/entities.py b/pkg/provider/entities.py index 2dd5804b..44866e2e 100644 --- a/pkg/provider/entities.py +++ b/pkg/provider/entities.py @@ -20,7 +20,7 @@ class ToolCall(pydantic.BaseModel): class Message(pydantic.BaseModel): - role: str + role: str # user, system, assistant, tool, command name: typing.Optional[str] = None