mirror of
https://github.com/langbot-app/LangBot.git
synced 2026-06-04 12:56:02 +00:00
refactor: 分隔LLM请求过程和消息封装过程
This commit is contained in:
@@ -66,11 +66,12 @@ class Controller:
|
||||
self.ap.logger.error(f"控制器循环出错: {e}")
|
||||
self.ap.logger.debug(f"Traceback: {traceback.format_exc()}")
|
||||
|
||||
async def _check_output(self, result: pipeline_entities.StageProcessResult):
|
||||
async def _check_output(self, query: entities.Query, result: pipeline_entities.StageProcessResult):
|
||||
"""检查输出
|
||||
"""
|
||||
if result.user_notice:
|
||||
await self.ap.im_mgr.send(
|
||||
query.message_event,
|
||||
result.user_notice
|
||||
)
|
||||
if result.debug_notice:
|
||||
@@ -108,12 +109,14 @@ class Controller:
|
||||
while i < len(self.ap.stage_mgr.stage_containers):
|
||||
stage_container = self.ap.stage_mgr.stage_containers[i]
|
||||
|
||||
result = await stage_container.inst.process(query, stage_container.inst_name)
|
||||
result = stage_container.inst.process(query, stage_container.inst_name)
|
||||
|
||||
if isinstance(result, typing.Coroutine):
|
||||
result = await result
|
||||
|
||||
if isinstance(result, pipeline_entities.StageProcessResult): # 直接返回结果
|
||||
self.ap.logger.debug(f"Stage {stage_container.inst_name} processed query {query} res {result}")
|
||||
await self._check_output(result)
|
||||
await self._check_output(query, result)
|
||||
|
||||
if result.result_type == pipeline_entities.ResultType.INTERRUPT:
|
||||
self.ap.logger.debug(f"Stage {stage_container.inst_name} interrupted query {query}")
|
||||
@@ -125,7 +128,7 @@ class Controller:
|
||||
|
||||
async for sub_result in result:
|
||||
self.ap.logger.debug(f"Stage {stage_container.inst_name} processed query {query} res {sub_result}")
|
||||
await self._check_output(sub_result)
|
||||
await self._check_output(query, sub_result)
|
||||
|
||||
if sub_result.result_type == pipeline_entities.ResultType.INTERRUPT:
|
||||
self.ap.logger.debug(f"Stage {stage_container.inst_name} interrupted query {query}")
|
||||
|
||||
@@ -27,25 +27,30 @@ class Query(pydantic.BaseModel):
|
||||
"""一次请求的信息封装"""
|
||||
|
||||
query_id: int
|
||||
"""请求ID"""
|
||||
"""请求ID,添加进请求池时生成"""
|
||||
|
||||
launcher_type: LauncherTypes
|
||||
"""会话类型"""
|
||||
"""会话类型,platform设置"""
|
||||
|
||||
launcher_id: int
|
||||
"""会话ID"""
|
||||
"""会话ID,platform设置"""
|
||||
|
||||
sender_id: int
|
||||
"""发送者ID"""
|
||||
"""发送者ID,platform设置"""
|
||||
|
||||
message_event: mirai.MessageEvent
|
||||
"""事件"""
|
||||
"""事件,platform收到的事件"""
|
||||
|
||||
message_chain: mirai.MessageChain
|
||||
"""消息链"""
|
||||
"""消息链,platform收到的消息链"""
|
||||
|
||||
session: typing.Optional[Session] = None
|
||||
|
||||
resp_messages: typing.Optional[list[llm_entities.Message]] = []
|
||||
"""由provider生成的回复消息对象列表"""
|
||||
|
||||
resp_message_chain: typing.Optional[mirai.MessageChain] = None
|
||||
"""回复消息链"""
|
||||
"""回复消息链,从resp_messages包装而得"""
|
||||
|
||||
|
||||
class Conversation(pydantic.BaseModel):
|
||||
|
||||
@@ -38,7 +38,9 @@ class QueryPool:
|
||||
launcher_id=launcher_id,
|
||||
sender_id=sender_id,
|
||||
message_event=message_event,
|
||||
message_chain=message_chain
|
||||
message_chain=message_chain,
|
||||
resp_messages=[],
|
||||
resp_message_chain=None
|
||||
)
|
||||
self.queries.append(query)
|
||||
self.query_id_counter += 1
|
||||
|
||||
@@ -81,31 +81,36 @@ class ContentFilterStage(stage.PipelineStage):
|
||||
"""请求llm后处理响应
|
||||
只要是 PASS 或者 MASKED 的就通过此 filter,将其 replacement 设置为message,进入下一个 filter
|
||||
"""
|
||||
for filter in self.filter_chain:
|
||||
if filter_entities.EnableStage.POST in filter.enable_stages:
|
||||
result = await filter.process(message)
|
||||
if message is None:
|
||||
return entities.StageProcessResult(
|
||||
result_type=entities.ResultType.CONTINUE,
|
||||
new_query=query
|
||||
)
|
||||
else:
|
||||
message = message.strip()
|
||||
for filter in self.filter_chain:
|
||||
if filter_entities.EnableStage.POST in filter.enable_stages:
|
||||
result = await filter.process(message)
|
||||
|
||||
if result.level == filter_entities.ResultLevel.BLOCK:
|
||||
return entities.StageProcessResult(
|
||||
result_type=entities.ResultType.INTERRUPT,
|
||||
new_query=query,
|
||||
user_notice=result.user_notice,
|
||||
console_notice=result.console_notice
|
||||
)
|
||||
elif result.level in [
|
||||
filter_entities.ResultLevel.PASS,
|
||||
filter_entities.ResultLevel.MASKED
|
||||
]:
|
||||
message = result.replacement
|
||||
if result.level == filter_entities.ResultLevel.BLOCK:
|
||||
return entities.StageProcessResult(
|
||||
result_type=entities.ResultType.INTERRUPT,
|
||||
new_query=query,
|
||||
user_notice=result.user_notice,
|
||||
console_notice=result.console_notice
|
||||
)
|
||||
elif result.level in [
|
||||
filter_entities.ResultLevel.PASS,
|
||||
filter_entities.ResultLevel.MASKED
|
||||
]:
|
||||
message = result.replacement
|
||||
|
||||
query.message_chain = mirai.MessageChain(
|
||||
mirai.Plain(message)
|
||||
)
|
||||
query.resp_messages[-1].content = message
|
||||
|
||||
return entities.StageProcessResult(
|
||||
result_type=entities.ResultType.CONTINUE,
|
||||
new_query=query
|
||||
)
|
||||
return entities.StageProcessResult(
|
||||
result_type=entities.ResultType.CONTINUE,
|
||||
new_query=query
|
||||
)
|
||||
|
||||
async def process(
|
||||
self,
|
||||
@@ -121,7 +126,7 @@ class ContentFilterStage(stage.PipelineStage):
|
||||
)
|
||||
elif stage_inst_name == 'PostContentFilterStage':
|
||||
return await self._post_process(
|
||||
str(query.message_chain).strip(),
|
||||
query.resp_messages[-1].content,
|
||||
query
|
||||
)
|
||||
else:
|
||||
|
||||
@@ -83,104 +83,21 @@ class ChatMessageHandler(handler.MessageHandler):
|
||||
)
|
||||
)
|
||||
|
||||
called_functions = []
|
||||
|
||||
text_length = 0
|
||||
|
||||
start_time = time.time()
|
||||
|
||||
async for result in conversation.use_model.requester.request(query, conversation):
|
||||
conversation.messages.append(result)
|
||||
query.resp_messages.append(result)
|
||||
|
||||
if result.content is not None:
|
||||
text_length += len(result.content)
|
||||
|
||||
# 转换成可读消息
|
||||
if result.role == 'assistant':
|
||||
yield entities.StageProcessResult(
|
||||
result_type=entities.ResultType.CONTINUE,
|
||||
new_query=query
|
||||
)
|
||||
|
||||
reply_text = ''
|
||||
|
||||
if result.content is not None: # 有内容
|
||||
reply_text = result.content
|
||||
|
||||
# ============= 触发插件事件 ===============
|
||||
event_ctx = await self.ap.plugin_mgr.emit_event(
|
||||
event=events.NormalMessageResponded(
|
||||
launcher_type=query.launcher_type.value,
|
||||
launcher_id=query.launcher_id,
|
||||
sender_id=query.sender_id,
|
||||
session=session,
|
||||
prefix='',
|
||||
response_text=reply_text,
|
||||
finish_reason='stop',
|
||||
funcs_called=called_functions,
|
||||
query=query
|
||||
)
|
||||
)
|
||||
if event_ctx.is_prevented_default():
|
||||
yield entities.StageProcessResult(
|
||||
result_type=entities.ResultType.INTERRUPT,
|
||||
new_query=query
|
||||
)
|
||||
else:
|
||||
if event_ctx.event.reply is not None:
|
||||
|
||||
query.resp_message_chain = mirai.MessageChain(event_ctx.event.reply)
|
||||
|
||||
else:
|
||||
|
||||
query.resp_message_chain = mirai.MessageChain([mirai.Plain(reply_text)])
|
||||
|
||||
yield entities.StageProcessResult(
|
||||
result_type=entities.ResultType.CONTINUE,
|
||||
new_query=query
|
||||
)
|
||||
|
||||
if result.tool_calls is not None: # 有函数调用
|
||||
|
||||
function_names = [tc.function.name for tc in result.tool_calls]
|
||||
|
||||
reply_text = f'调用函数 {".".join(function_names)}...'
|
||||
|
||||
called_functions.extend(function_names)
|
||||
|
||||
query.resp_message_chain = mirai.MessageChain([mirai.Plain(reply_text)])
|
||||
|
||||
if self.ap.cfg_mgr.data['trace_function_calls']:
|
||||
|
||||
event_ctx = await self.ap.plugin_mgr.emit_event(
|
||||
event=events.NormalMessageResponded(
|
||||
launcher_type=query.launcher_type.value,
|
||||
launcher_id=query.launcher_id,
|
||||
sender_id=query.sender_id,
|
||||
session=session,
|
||||
prefix='',
|
||||
response_text=reply_text,
|
||||
finish_reason='stop',
|
||||
funcs_called=called_functions,
|
||||
query=query
|
||||
)
|
||||
)
|
||||
|
||||
if event_ctx.is_prevented_default():
|
||||
yield entities.StageProcessResult(
|
||||
result_type=entities.ResultType.INTERRUPT,
|
||||
new_query=query
|
||||
)
|
||||
else:
|
||||
if event_ctx.event.reply is not None:
|
||||
|
||||
query.resp_message_chain = mirai.MessageChain(event_ctx.event.reply)
|
||||
|
||||
else:
|
||||
|
||||
query.resp_message_chain = mirai.MessageChain([mirai.Plain(reply_text)])
|
||||
|
||||
yield entities.StageProcessResult(
|
||||
result_type=entities.ResultType.CONTINUE,
|
||||
new_query=query
|
||||
)
|
||||
|
||||
await self.ap.ctr_mgr.usage.post_query_record(
|
||||
session_type=session.launcher_type.value,
|
||||
session_id=str(session.launcher_id),
|
||||
|
||||
@@ -6,6 +6,7 @@ import mirai
|
||||
from .. import handler
|
||||
from ... import entities
|
||||
from ....core import entities as core_entities
|
||||
from ....provider import entities as llm_entities
|
||||
from ....plugin import events
|
||||
|
||||
|
||||
@@ -44,7 +45,14 @@ class CommandHandler(handler.MessageHandler):
|
||||
if event_ctx.is_prevented_default():
|
||||
|
||||
if event_ctx.event.reply is not None:
|
||||
query.resp_message_chain = mirai.MessageChain(event_ctx.event.reply)
|
||||
mc = mirai.MessageChain(event_ctx.event.reply)
|
||||
|
||||
query.resp_messages.append(
|
||||
llm_entities.Message(
|
||||
role='command',
|
||||
content=str(mc),
|
||||
)
|
||||
)
|
||||
|
||||
yield entities.StageProcessResult(
|
||||
result_type=entities.ResultType.CONTINUE,
|
||||
@@ -73,18 +81,30 @@ class CommandHandler(handler.MessageHandler):
|
||||
session=session
|
||||
):
|
||||
if ret.error is not None:
|
||||
query.resp_message_chain = mirai.MessageChain([
|
||||
mirai.Plain(str(ret.error))
|
||||
])
|
||||
# query.resp_message_chain = mirai.MessageChain([
|
||||
# mirai.Plain(str(ret.error))
|
||||
# ])
|
||||
query.resp_messages.append(
|
||||
llm_entities.Message(
|
||||
role='command',
|
||||
content=str(ret.error),
|
||||
)
|
||||
)
|
||||
|
||||
yield entities.StageProcessResult(
|
||||
result_type=entities.ResultType.CONTINUE,
|
||||
new_query=query
|
||||
)
|
||||
elif ret.text is not None:
|
||||
query.resp_message_chain = mirai.MessageChain([
|
||||
mirai.Plain(ret.text)
|
||||
])
|
||||
# query.resp_message_chain = mirai.MessageChain([
|
||||
# mirai.Plain(ret.text)
|
||||
# ])
|
||||
query.resp_messages.append(
|
||||
llm_entities.Message(
|
||||
role='command',
|
||||
content=ret.text,
|
||||
)
|
||||
)
|
||||
|
||||
yield entities.StageProcessResult(
|
||||
result_type=entities.ResultType.CONTINUE,
|
||||
|
||||
@@ -10,6 +10,7 @@ from .cntfilter import cntfilter
|
||||
from .process import process
|
||||
from .longtext import longtext
|
||||
from .respback import respback
|
||||
from .wrapper import wrapper
|
||||
|
||||
|
||||
stage_order = [
|
||||
@@ -18,6 +19,7 @@ stage_order = [
|
||||
"PreContentFilterStage",
|
||||
"MessageProcessor",
|
||||
"PostContentFilterStage",
|
||||
"ResponseWrapper",
|
||||
"LongTextProcessStage",
|
||||
"SendResponseBackStage",
|
||||
]
|
||||
|
||||
0
pkg/pipeline/wrapper/__init__.py
Normal file
0
pkg/pipeline/wrapper/__init__.py
Normal file
119
pkg/pipeline/wrapper/wrapper.py
Normal file
119
pkg/pipeline/wrapper/wrapper.py
Normal file
@@ -0,0 +1,119 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import typing
|
||||
|
||||
import mirai
|
||||
|
||||
from ...core import app, entities as core_entities
|
||||
from .. import entities
|
||||
from .. import stage, entities, stagemgr
|
||||
from ...core import entities as core_entities
|
||||
from ...config import manager as cfg_mgr
|
||||
from ...plugin import events
|
||||
|
||||
|
||||
@stage.stage_class("ResponseWrapper")
|
||||
class ResponseWrapper(stage.PipelineStage):
|
||||
|
||||
async def initialize(self):
|
||||
pass
|
||||
|
||||
async def process(
|
||||
self,
|
||||
query: core_entities.Query,
|
||||
stage_inst_name: str,
|
||||
) -> typing.AsyncGenerator[entities.StageProcessResult, None]:
|
||||
"""处理
|
||||
"""
|
||||
|
||||
if query.resp_messages[-1].role == 'command':
|
||||
query.resp_message_chain = mirai.MessageChain("[bot] "+query.resp_messages[-1].content)
|
||||
|
||||
yield entities.StageProcessResult(
|
||||
result_type=entities.ResultType.CONTINUE,
|
||||
new_query=query
|
||||
)
|
||||
elif query.resp_messages[-1].role == 'assistant':
|
||||
result = query.resp_messages[-1]
|
||||
session = await self.ap.sess_mgr.get_session(query)
|
||||
|
||||
reply_text = ''
|
||||
|
||||
if result.content is not None: # 有内容
|
||||
reply_text = result.content
|
||||
|
||||
# ============= 触发插件事件 ===============
|
||||
event_ctx = await self.ap.plugin_mgr.emit_event(
|
||||
event=events.NormalMessageResponded(
|
||||
launcher_type=query.launcher_type.value,
|
||||
launcher_id=query.launcher_id,
|
||||
sender_id=query.sender_id,
|
||||
session=session,
|
||||
prefix='',
|
||||
response_text=reply_text,
|
||||
finish_reason='stop',
|
||||
funcs_called=[fc.function.name for fc in result.tool_calls] if result.tool_calls is not None else [],
|
||||
query=query
|
||||
)
|
||||
)
|
||||
if event_ctx.is_prevented_default():
|
||||
yield entities.StageProcessResult(
|
||||
result_type=entities.ResultType.INTERRUPT,
|
||||
new_query=query
|
||||
)
|
||||
else:
|
||||
if event_ctx.event.reply is not None:
|
||||
|
||||
query.resp_message_chain = mirai.MessageChain(event_ctx.event.reply)
|
||||
|
||||
else:
|
||||
|
||||
query.resp_message_chain = mirai.MessageChain([mirai.Plain(reply_text)])
|
||||
|
||||
yield entities.StageProcessResult(
|
||||
result_type=entities.ResultType.CONTINUE,
|
||||
new_query=query
|
||||
)
|
||||
|
||||
if result.tool_calls is not None: # 有函数调用
|
||||
|
||||
function_names = [tc.function.name for tc in result.tool_calls]
|
||||
|
||||
reply_text = f'调用函数 {".".join(function_names)}...'
|
||||
|
||||
query.resp_message_chain = mirai.MessageChain([mirai.Plain(reply_text)])
|
||||
|
||||
if self.ap.cfg_mgr.data['trace_function_calls']:
|
||||
|
||||
event_ctx = await self.ap.plugin_mgr.emit_event(
|
||||
event=events.NormalMessageResponded(
|
||||
launcher_type=query.launcher_type.value,
|
||||
launcher_id=query.launcher_id,
|
||||
sender_id=query.sender_id,
|
||||
session=session,
|
||||
prefix='',
|
||||
response_text=reply_text,
|
||||
finish_reason='stop',
|
||||
funcs_called=[fc.function.name for fc in result.tool_calls] if result.tool_calls is not None else [],
|
||||
query=query
|
||||
)
|
||||
)
|
||||
|
||||
if event_ctx.is_prevented_default():
|
||||
yield entities.StageProcessResult(
|
||||
result_type=entities.ResultType.INTERRUPT,
|
||||
new_query=query
|
||||
)
|
||||
else:
|
||||
if event_ctx.event.reply is not None:
|
||||
|
||||
query.resp_message_chain = mirai.MessageChain(event_ctx.event.reply)
|
||||
|
||||
else:
|
||||
|
||||
query.resp_message_chain = mirai.MessageChain([mirai.Plain(reply_text)])
|
||||
|
||||
yield entities.StageProcessResult(
|
||||
result_type=entities.ResultType.CONTINUE,
|
||||
new_query=query
|
||||
)
|
||||
@@ -105,7 +105,7 @@ class PlatformManager:
|
||||
launcher_id=event.sender.id,
|
||||
sender_id=event.sender.id,
|
||||
message_event=event,
|
||||
message_chain=event.message_chain
|
||||
message_chain=event.message_chain,
|
||||
)
|
||||
|
||||
# nakuru不区分好友和陌生人,故仅为yirimirai注册陌生人事件
|
||||
|
||||
@@ -20,7 +20,7 @@ class ToolCall(pydantic.BaseModel):
|
||||
|
||||
|
||||
class Message(pydantic.BaseModel):
|
||||
role: str
|
||||
role: str # user, system, assistant, tool, command
|
||||
|
||||
name: typing.Optional[str] = None
|
||||
|
||||
|
||||
Reference in New Issue
Block a user