refactor: 分隔LLM请求过程和消息封装过程

This commit is contained in:
RockChinQ
2024-02-01 15:48:26 +08:00
parent 32162afa65
commit 976a9de39c
11 changed files with 205 additions and 132 deletions

View File

@@ -66,11 +66,12 @@ class Controller:
self.ap.logger.error(f"控制器循环出错: {e}")
self.ap.logger.debug(f"Traceback: {traceback.format_exc()}")
async def _check_output(self, result: pipeline_entities.StageProcessResult):
async def _check_output(self, query: entities.Query, result: pipeline_entities.StageProcessResult):
"""检查输出
"""
if result.user_notice:
await self.ap.im_mgr.send(
query.message_event,
result.user_notice
)
if result.debug_notice:
@@ -108,12 +109,14 @@ class Controller:
while i < len(self.ap.stage_mgr.stage_containers):
stage_container = self.ap.stage_mgr.stage_containers[i]
result = await stage_container.inst.process(query, stage_container.inst_name)
result = stage_container.inst.process(query, stage_container.inst_name)
if isinstance(result, typing.Coroutine):
result = await result
if isinstance(result, pipeline_entities.StageProcessResult): # 直接返回结果
self.ap.logger.debug(f"Stage {stage_container.inst_name} processed query {query} res {result}")
await self._check_output(result)
await self._check_output(query, result)
if result.result_type == pipeline_entities.ResultType.INTERRUPT:
self.ap.logger.debug(f"Stage {stage_container.inst_name} interrupted query {query}")
@@ -125,7 +128,7 @@ class Controller:
async for sub_result in result:
self.ap.logger.debug(f"Stage {stage_container.inst_name} processed query {query} res {sub_result}")
await self._check_output(sub_result)
await self._check_output(query, sub_result)
if sub_result.result_type == pipeline_entities.ResultType.INTERRUPT:
self.ap.logger.debug(f"Stage {stage_container.inst_name} interrupted query {query}")

View File

@@ -27,25 +27,30 @@ class Query(pydantic.BaseModel):
"""一次请求的信息封装"""
query_id: int
"""请求ID"""
"""请求ID,添加进请求池时生成"""
launcher_type: LauncherTypes
"""会话类型"""
"""会话类型platform设置"""
launcher_id: int
"""会话ID"""
"""会话IDplatform设置"""
sender_id: int
"""发送者ID"""
"""发送者IDplatform设置"""
message_event: mirai.MessageEvent
"""事件"""
"""事件platform收到的事件"""
message_chain: mirai.MessageChain
"""消息链"""
"""消息链platform收到的消息链"""
session: typing.Optional[Session] = None
resp_messages: typing.Optional[list[llm_entities.Message]] = []
"""由provider生成的回复消息对象列表"""
resp_message_chain: typing.Optional[mirai.MessageChain] = None
"""回复消息链"""
"""回复消息链从resp_messages包装而得"""
class Conversation(pydantic.BaseModel):

View File

@@ -38,7 +38,9 @@ class QueryPool:
launcher_id=launcher_id,
sender_id=sender_id,
message_event=message_event,
message_chain=message_chain
message_chain=message_chain,
resp_messages=[],
resp_message_chain=None
)
self.queries.append(query)
self.query_id_counter += 1

View File

@@ -81,31 +81,36 @@ class ContentFilterStage(stage.PipelineStage):
"""请求llm后处理响应
只要是 PASS 或者 MASKED 的就通过此 filter将其 replacement 设置为message进入下一个 filter
"""
for filter in self.filter_chain:
if filter_entities.EnableStage.POST in filter.enable_stages:
result = await filter.process(message)
if message is None:
return entities.StageProcessResult(
result_type=entities.ResultType.CONTINUE,
new_query=query
)
else:
message = message.strip()
for filter in self.filter_chain:
if filter_entities.EnableStage.POST in filter.enable_stages:
result = await filter.process(message)
if result.level == filter_entities.ResultLevel.BLOCK:
return entities.StageProcessResult(
result_type=entities.ResultType.INTERRUPT,
new_query=query,
user_notice=result.user_notice,
console_notice=result.console_notice
)
elif result.level in [
filter_entities.ResultLevel.PASS,
filter_entities.ResultLevel.MASKED
]:
message = result.replacement
if result.level == filter_entities.ResultLevel.BLOCK:
return entities.StageProcessResult(
result_type=entities.ResultType.INTERRUPT,
new_query=query,
user_notice=result.user_notice,
console_notice=result.console_notice
)
elif result.level in [
filter_entities.ResultLevel.PASS,
filter_entities.ResultLevel.MASKED
]:
message = result.replacement
query.message_chain = mirai.MessageChain(
mirai.Plain(message)
)
query.resp_messages[-1].content = message
return entities.StageProcessResult(
result_type=entities.ResultType.CONTINUE,
new_query=query
)
return entities.StageProcessResult(
result_type=entities.ResultType.CONTINUE,
new_query=query
)
async def process(
self,
@@ -121,7 +126,7 @@ class ContentFilterStage(stage.PipelineStage):
)
elif stage_inst_name == 'PostContentFilterStage':
return await self._post_process(
str(query.message_chain).strip(),
query.resp_messages[-1].content,
query
)
else:

View File

@@ -83,104 +83,21 @@ class ChatMessageHandler(handler.MessageHandler):
)
)
called_functions = []
text_length = 0
start_time = time.time()
async for result in conversation.use_model.requester.request(query, conversation):
conversation.messages.append(result)
query.resp_messages.append(result)
if result.content is not None:
text_length += len(result.content)
# 转换成可读消息
if result.role == 'assistant':
yield entities.StageProcessResult(
result_type=entities.ResultType.CONTINUE,
new_query=query
)
reply_text = ''
if result.content is not None: # 有内容
reply_text = result.content
# ============= 触发插件事件 ===============
event_ctx = await self.ap.plugin_mgr.emit_event(
event=events.NormalMessageResponded(
launcher_type=query.launcher_type.value,
launcher_id=query.launcher_id,
sender_id=query.sender_id,
session=session,
prefix='',
response_text=reply_text,
finish_reason='stop',
funcs_called=called_functions,
query=query
)
)
if event_ctx.is_prevented_default():
yield entities.StageProcessResult(
result_type=entities.ResultType.INTERRUPT,
new_query=query
)
else:
if event_ctx.event.reply is not None:
query.resp_message_chain = mirai.MessageChain(event_ctx.event.reply)
else:
query.resp_message_chain = mirai.MessageChain([mirai.Plain(reply_text)])
yield entities.StageProcessResult(
result_type=entities.ResultType.CONTINUE,
new_query=query
)
if result.tool_calls is not None: # 有函数调用
function_names = [tc.function.name for tc in result.tool_calls]
reply_text = f'调用函数 {".".join(function_names)}...'
called_functions.extend(function_names)
query.resp_message_chain = mirai.MessageChain([mirai.Plain(reply_text)])
if self.ap.cfg_mgr.data['trace_function_calls']:
event_ctx = await self.ap.plugin_mgr.emit_event(
event=events.NormalMessageResponded(
launcher_type=query.launcher_type.value,
launcher_id=query.launcher_id,
sender_id=query.sender_id,
session=session,
prefix='',
response_text=reply_text,
finish_reason='stop',
funcs_called=called_functions,
query=query
)
)
if event_ctx.is_prevented_default():
yield entities.StageProcessResult(
result_type=entities.ResultType.INTERRUPT,
new_query=query
)
else:
if event_ctx.event.reply is not None:
query.resp_message_chain = mirai.MessageChain(event_ctx.event.reply)
else:
query.resp_message_chain = mirai.MessageChain([mirai.Plain(reply_text)])
yield entities.StageProcessResult(
result_type=entities.ResultType.CONTINUE,
new_query=query
)
await self.ap.ctr_mgr.usage.post_query_record(
session_type=session.launcher_type.value,
session_id=str(session.launcher_id),

View File

@@ -6,6 +6,7 @@ import mirai
from .. import handler
from ... import entities
from ....core import entities as core_entities
from ....provider import entities as llm_entities
from ....plugin import events
@@ -44,7 +45,14 @@ class CommandHandler(handler.MessageHandler):
if event_ctx.is_prevented_default():
if event_ctx.event.reply is not None:
query.resp_message_chain = mirai.MessageChain(event_ctx.event.reply)
mc = mirai.MessageChain(event_ctx.event.reply)
query.resp_messages.append(
llm_entities.Message(
role='command',
content=str(mc),
)
)
yield entities.StageProcessResult(
result_type=entities.ResultType.CONTINUE,
@@ -73,18 +81,30 @@ class CommandHandler(handler.MessageHandler):
session=session
):
if ret.error is not None:
query.resp_message_chain = mirai.MessageChain([
mirai.Plain(str(ret.error))
])
# query.resp_message_chain = mirai.MessageChain([
# mirai.Plain(str(ret.error))
# ])
query.resp_messages.append(
llm_entities.Message(
role='command',
content=str(ret.error),
)
)
yield entities.StageProcessResult(
result_type=entities.ResultType.CONTINUE,
new_query=query
)
elif ret.text is not None:
query.resp_message_chain = mirai.MessageChain([
mirai.Plain(ret.text)
])
# query.resp_message_chain = mirai.MessageChain([
# mirai.Plain(ret.text)
# ])
query.resp_messages.append(
llm_entities.Message(
role='command',
content=ret.text,
)
)
yield entities.StageProcessResult(
result_type=entities.ResultType.CONTINUE,

View File

@@ -10,6 +10,7 @@ from .cntfilter import cntfilter
from .process import process
from .longtext import longtext
from .respback import respback
from .wrapper import wrapper
stage_order = [
@@ -18,6 +19,7 @@ stage_order = [
"PreContentFilterStage",
"MessageProcessor",
"PostContentFilterStage",
"ResponseWrapper",
"LongTextProcessStage",
"SendResponseBackStage",
]

View File

View File

@@ -0,0 +1,119 @@
from __future__ import annotations
import typing
import mirai
from ...core import app, entities as core_entities
from .. import entities
from .. import stage, entities, stagemgr
from ...core import entities as core_entities
from ...config import manager as cfg_mgr
from ...plugin import events
@stage.stage_class("ResponseWrapper")
class ResponseWrapper(stage.PipelineStage):
async def initialize(self):
pass
async def process(
self,
query: core_entities.Query,
stage_inst_name: str,
) -> typing.AsyncGenerator[entities.StageProcessResult, None]:
"""处理
"""
if query.resp_messages[-1].role == 'command':
query.resp_message_chain = mirai.MessageChain("[bot] "+query.resp_messages[-1].content)
yield entities.StageProcessResult(
result_type=entities.ResultType.CONTINUE,
new_query=query
)
elif query.resp_messages[-1].role == 'assistant':
result = query.resp_messages[-1]
session = await self.ap.sess_mgr.get_session(query)
reply_text = ''
if result.content is not None: # 有内容
reply_text = result.content
# ============= 触发插件事件 ===============
event_ctx = await self.ap.plugin_mgr.emit_event(
event=events.NormalMessageResponded(
launcher_type=query.launcher_type.value,
launcher_id=query.launcher_id,
sender_id=query.sender_id,
session=session,
prefix='',
response_text=reply_text,
finish_reason='stop',
funcs_called=[fc.function.name for fc in result.tool_calls] if result.tool_calls is not None else [],
query=query
)
)
if event_ctx.is_prevented_default():
yield entities.StageProcessResult(
result_type=entities.ResultType.INTERRUPT,
new_query=query
)
else:
if event_ctx.event.reply is not None:
query.resp_message_chain = mirai.MessageChain(event_ctx.event.reply)
else:
query.resp_message_chain = mirai.MessageChain([mirai.Plain(reply_text)])
yield entities.StageProcessResult(
result_type=entities.ResultType.CONTINUE,
new_query=query
)
if result.tool_calls is not None: # 有函数调用
function_names = [tc.function.name for tc in result.tool_calls]
reply_text = f'调用函数 {".".join(function_names)}...'
query.resp_message_chain = mirai.MessageChain([mirai.Plain(reply_text)])
if self.ap.cfg_mgr.data['trace_function_calls']:
event_ctx = await self.ap.plugin_mgr.emit_event(
event=events.NormalMessageResponded(
launcher_type=query.launcher_type.value,
launcher_id=query.launcher_id,
sender_id=query.sender_id,
session=session,
prefix='',
response_text=reply_text,
finish_reason='stop',
funcs_called=[fc.function.name for fc in result.tool_calls] if result.tool_calls is not None else [],
query=query
)
)
if event_ctx.is_prevented_default():
yield entities.StageProcessResult(
result_type=entities.ResultType.INTERRUPT,
new_query=query
)
else:
if event_ctx.event.reply is not None:
query.resp_message_chain = mirai.MessageChain(event_ctx.event.reply)
else:
query.resp_message_chain = mirai.MessageChain([mirai.Plain(reply_text)])
yield entities.StageProcessResult(
result_type=entities.ResultType.CONTINUE,
new_query=query
)

View File

@@ -105,7 +105,7 @@ class PlatformManager:
launcher_id=event.sender.id,
sender_id=event.sender.id,
message_event=event,
message_chain=event.message_chain
message_chain=event.message_chain,
)
# nakuru不区分好友和陌生人故仅为yirimirai注册陌生人事件

View File

@@ -20,7 +20,7 @@ class ToolCall(pydantic.BaseModel):
class Message(pydantic.BaseModel):
role: str
role: str # user, system, assistant, tool, command
name: typing.Optional[str] = None