mirror of
https://github.com/langbot-app/LangBot.git
synced 2026-06-02 03:55:55 +00:00
feat: pipeline invoking
This commit is contained in:
@@ -57,6 +57,12 @@ class Query(pydantic.BaseModel):
|
||||
message_chain: platform_message.MessageChain
|
||||
"""消息链,platform收到的原始消息链"""
|
||||
|
||||
pipeline_uuid: typing.Optional[str] = None
|
||||
"""流水线UUID。"""
|
||||
|
||||
pipeline_config: typing.Optional[dict[str, typing.Any]] = None
|
||||
"""流水线配置,由 Pipeline 在运行开始时设置。"""
|
||||
|
||||
adapter: msadapter.MessagePlatformAdapter
|
||||
"""消息平台适配器对象,单个app中可能启用了多个消息平台适配器,此对象表明发起此query的适配器"""
|
||||
|
||||
|
||||
@@ -50,14 +50,19 @@ class Controller:
|
||||
continue
|
||||
|
||||
if selected_query:
|
||||
async def _process_query(selected_query):
|
||||
|
||||
async def _process_query(selected_query: entities.Query):
|
||||
async with self.semaphore: # 总并发上限
|
||||
await self.process_query(selected_query)
|
||||
# find pipeline
|
||||
pipeline = await self.ap.pipeline_mgr.get_pipeline_by_uuid(selected_query.pipeline_uuid)
|
||||
if pipeline:
|
||||
await pipeline.run(selected_query)
|
||||
|
||||
async with self.ap.query_pool:
|
||||
(await self.ap.sess_mgr.get_session(selected_query)).semaphore.release()
|
||||
# 通知其他协程,有新的请求可以处理了
|
||||
self.ap.query_pool.condition.notify_all()
|
||||
|
||||
self.ap.task_mgr.create_task(
|
||||
_process_query(selected_query),
|
||||
kind="query",
|
||||
@@ -70,127 +75,6 @@ class Controller:
|
||||
self.ap.logger.error(f"控制器循环出错: {e}")
|
||||
self.ap.logger.error(f"Traceback: {traceback.format_exc()}")
|
||||
|
||||
async def _check_output(self, query: entities.Query, result: pipeline_entities.StageProcessResult):
|
||||
"""检查输出
|
||||
"""
|
||||
if result.user_notice:
|
||||
# 处理str类型
|
||||
|
||||
if isinstance(result.user_notice, str):
|
||||
result.user_notice = platform_message.MessageChain(
|
||||
platform_message.Plain(result.user_notice)
|
||||
)
|
||||
elif isinstance(result.user_notice, list):
|
||||
result.user_notice = platform_message.MessageChain(
|
||||
*result.user_notice
|
||||
)
|
||||
|
||||
await self.ap.platform_mgr.send(
|
||||
query.message_event,
|
||||
result.user_notice,
|
||||
query.adapter
|
||||
)
|
||||
if result.debug_notice:
|
||||
self.ap.logger.debug(result.debug_notice)
|
||||
if result.console_notice:
|
||||
self.ap.logger.info(result.console_notice)
|
||||
if result.error_notice:
|
||||
self.ap.logger.error(result.error_notice)
|
||||
|
||||
async def _execute_from_stage(
|
||||
self,
|
||||
stage_index: int,
|
||||
query: entities.Query,
|
||||
):
|
||||
"""从指定阶段开始执行,实现了责任链模式和基于生成器的阶段分叉功能。
|
||||
|
||||
如何看懂这里为什么这么写?
|
||||
去问 GPT-4:
|
||||
Q1: 现在有一个责任链,其中有多个stage,query对象在其中传递,stage.process可能返回Result也有可能返回typing.AsyncGenerator[Result, None],
|
||||
如果返回的是生成器,需要挨个生成result,检查是否result中是否要求继续,如果要求继续就进行下一个stage。如果此次生成器产生的result处理完了,就继续生成下一个result,
|
||||
调用后续的stage,直到该生成器全部生成完。责任链中可能有多个stage会返回生成器
|
||||
Q2: 不是这样的,你可能理解有误。如果我们责任链上有这些Stage:
|
||||
|
||||
A B C D E F G
|
||||
|
||||
如果所有的stage都返回Result,且所有Result都要求继续,那么执行顺序是:
|
||||
|
||||
A B C D E F G
|
||||
|
||||
现在假设C返回的是AsyncGenerator,那么执行顺序是:
|
||||
|
||||
A B C D E F G C D E F G C D E F G ...
|
||||
Q3: 但是如果不止一个stage会返回生成器呢?
|
||||
"""
|
||||
i = stage_index
|
||||
|
||||
while i < len(self.ap.stage_mgr.stage_containers):
|
||||
stage_container = self.ap.stage_mgr.stage_containers[i]
|
||||
|
||||
query.current_stage = stage_container # 标记到 Query 对象里
|
||||
|
||||
result = stage_container.inst.process(query, stage_container.inst_name)
|
||||
|
||||
if isinstance(result, typing.Coroutine):
|
||||
result = await result
|
||||
|
||||
if isinstance(result, pipeline_entities.StageProcessResult): # 直接返回结果
|
||||
self.ap.logger.debug(f"Stage {stage_container.inst_name} processed query {query} res {result}")
|
||||
await self._check_output(query, result)
|
||||
|
||||
if result.result_type == pipeline_entities.ResultType.INTERRUPT:
|
||||
self.ap.logger.debug(f"Stage {stage_container.inst_name} interrupted query {query}")
|
||||
break
|
||||
elif result.result_type == pipeline_entities.ResultType.CONTINUE:
|
||||
query = result.new_query
|
||||
elif isinstance(result, typing.AsyncGenerator): # 生成器
|
||||
self.ap.logger.debug(f"Stage {stage_container.inst_name} processed query {query} gen")
|
||||
|
||||
async for sub_result in result:
|
||||
self.ap.logger.debug(f"Stage {stage_container.inst_name} processed query {query} res {sub_result}")
|
||||
await self._check_output(query, sub_result)
|
||||
|
||||
if sub_result.result_type == pipeline_entities.ResultType.INTERRUPT:
|
||||
self.ap.logger.debug(f"Stage {stage_container.inst_name} interrupted query {query}")
|
||||
break
|
||||
elif sub_result.result_type == pipeline_entities.ResultType.CONTINUE:
|
||||
query = sub_result.new_query
|
||||
await self._execute_from_stage(i + 1, query)
|
||||
break
|
||||
|
||||
i += 1
|
||||
|
||||
async def process_query(self, query: entities.Query):
|
||||
"""处理请求
|
||||
"""
|
||||
try:
|
||||
|
||||
# ======== 触发 MessageReceived 事件 ========
|
||||
event_type = events.PersonMessageReceived if query.launcher_type == entities.LauncherTypes.PERSON else events.GroupMessageReceived
|
||||
|
||||
event_ctx = await self.ap.plugin_mgr.emit_event(
|
||||
event=event_type(
|
||||
launcher_type=query.launcher_type.value,
|
||||
launcher_id=query.launcher_id,
|
||||
sender_id=query.sender_id,
|
||||
message_chain=query.message_chain,
|
||||
query=query
|
||||
)
|
||||
)
|
||||
|
||||
if event_ctx.is_prevented_default():
|
||||
return
|
||||
|
||||
self.ap.logger.debug(f"Processing query {query}")
|
||||
|
||||
await self._execute_from_stage(0, query)
|
||||
except Exception as e:
|
||||
inst_name = query.current_stage.inst_name if query.current_stage else 'unknown'
|
||||
self.ap.logger.error(f"处理请求时出错 query_id={query.query_id} stage={inst_name} : {e}")
|
||||
self.ap.logger.debug(f"Traceback: {traceback.format_exc()}")
|
||||
finally:
|
||||
self.ap.logger.debug(f"Query {query} processed")
|
||||
|
||||
async def run(self):
|
||||
"""运行控制器
|
||||
"""
|
||||
|
||||
@@ -1,13 +1,16 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import typing
|
||||
import traceback
|
||||
|
||||
import sqlalchemy
|
||||
|
||||
from ..core import app, entities
|
||||
from . import entities as pipeline_entities
|
||||
from ..entity.persistence import pipeline as persistence_pipeline
|
||||
from . import stagemgr, stage
|
||||
|
||||
from ..platform.types import message as platform_message, events as platform_events
|
||||
from ..plugin import events
|
||||
|
||||
class RuntimePipeline:
|
||||
"""运行时流水线"""
|
||||
@@ -25,8 +28,130 @@ class RuntimePipeline:
|
||||
self.pipeline_entity = pipeline_entity
|
||||
self.stage_containers = stage_containers
|
||||
|
||||
async def run(self):
|
||||
pass
|
||||
async def run(self, query: entities.Query):
|
||||
query.pipeline_config = self.pipeline_entity.config
|
||||
await self.process_query(query)
|
||||
|
||||
async def _check_output(self, query: entities.Query, result: pipeline_entities.StageProcessResult):
|
||||
"""检查输出
|
||||
"""
|
||||
if result.user_notice:
|
||||
# 处理str类型
|
||||
|
||||
if isinstance(result.user_notice, str):
|
||||
result.user_notice = platform_message.MessageChain(
|
||||
platform_message.Plain(result.user_notice)
|
||||
)
|
||||
elif isinstance(result.user_notice, list):
|
||||
result.user_notice = platform_message.MessageChain(
|
||||
*result.user_notice
|
||||
)
|
||||
|
||||
await self.ap.platform_mgr.send(
|
||||
query.message_event,
|
||||
result.user_notice,
|
||||
query.adapter
|
||||
)
|
||||
if result.debug_notice:
|
||||
self.ap.logger.debug(result.debug_notice)
|
||||
if result.console_notice:
|
||||
self.ap.logger.info(result.console_notice)
|
||||
if result.error_notice:
|
||||
self.ap.logger.error(result.error_notice)
|
||||
|
||||
async def _execute_from_stage(
|
||||
self,
|
||||
stage_index: int,
|
||||
query: entities.Query,
|
||||
):
|
||||
"""从指定阶段开始执行,实现了责任链模式和基于生成器的阶段分叉功能。
|
||||
|
||||
如何看懂这里为什么这么写?
|
||||
去问 GPT-4:
|
||||
Q1: 现在有一个责任链,其中有多个stage,query对象在其中传递,stage.process可能返回Result也有可能返回typing.AsyncGenerator[Result, None],
|
||||
如果返回的是生成器,需要挨个生成result,检查是否result中是否要求继续,如果要求继续就进行下一个stage。如果此次生成器产生的result处理完了,就继续生成下一个result,
|
||||
调用后续的stage,直到该生成器全部生成完。责任链中可能有多个stage会返回生成器
|
||||
Q2: 不是这样的,你可能理解有误。如果我们责任链上有这些Stage:
|
||||
|
||||
A B C D E F G
|
||||
|
||||
如果所有的stage都返回Result,且所有Result都要求继续,那么执行顺序是:
|
||||
|
||||
A B C D E F G
|
||||
|
||||
现在假设C返回的是AsyncGenerator,那么执行顺序是:
|
||||
|
||||
A B C D E F G C D E F G C D E F G ...
|
||||
Q3: 但是如果不止一个stage会返回生成器呢?
|
||||
"""
|
||||
i = stage_index
|
||||
|
||||
while i < len(self.stage_containers):
|
||||
stage_container = self.stage_containers[i]
|
||||
|
||||
query.current_stage = stage_container # 标记到 Query 对象里
|
||||
|
||||
result = stage_container.inst.process(query, stage_container.inst_name)
|
||||
|
||||
if isinstance(result, typing.Coroutine):
|
||||
result = await result
|
||||
|
||||
if isinstance(result, pipeline_entities.StageProcessResult): # 直接返回结果
|
||||
self.ap.logger.debug(f"Stage {stage_container.inst_name} processed query {query} res {result}")
|
||||
await self._check_output(query, result)
|
||||
|
||||
if result.result_type == pipeline_entities.ResultType.INTERRUPT:
|
||||
self.ap.logger.debug(f"Stage {stage_container.inst_name} interrupted query {query}")
|
||||
break
|
||||
elif result.result_type == pipeline_entities.ResultType.CONTINUE:
|
||||
query = result.new_query
|
||||
elif isinstance(result, typing.AsyncGenerator): # 生成器
|
||||
self.ap.logger.debug(f"Stage {stage_container.inst_name} processed query {query} gen")
|
||||
|
||||
async for sub_result in result:
|
||||
self.ap.logger.debug(f"Stage {stage_container.inst_name} processed query {query} res {sub_result}")
|
||||
await self._check_output(query, sub_result)
|
||||
|
||||
if sub_result.result_type == pipeline_entities.ResultType.INTERRUPT:
|
||||
self.ap.logger.debug(f"Stage {stage_container.inst_name} interrupted query {query}")
|
||||
break
|
||||
elif sub_result.result_type == pipeline_entities.ResultType.CONTINUE:
|
||||
query = sub_result.new_query
|
||||
await self._execute_from_stage(i + 1, query)
|
||||
break
|
||||
|
||||
i += 1
|
||||
|
||||
async def process_query(self, query: entities.Query):
|
||||
"""处理请求
|
||||
"""
|
||||
try:
|
||||
|
||||
# ======== 触发 MessageReceived 事件 ========
|
||||
event_type = events.PersonMessageReceived if query.launcher_type == entities.LauncherTypes.PERSON else events.GroupMessageReceived
|
||||
|
||||
event_ctx = await self.ap.plugin_mgr.emit_event(
|
||||
event=event_type(
|
||||
launcher_type=query.launcher_type.value,
|
||||
launcher_id=query.launcher_id,
|
||||
sender_id=query.sender_id,
|
||||
message_chain=query.message_chain,
|
||||
query=query
|
||||
)
|
||||
)
|
||||
|
||||
if event_ctx.is_prevented_default():
|
||||
return
|
||||
|
||||
self.ap.logger.debug(f"Processing query {query}")
|
||||
|
||||
await self._execute_from_stage(0, query)
|
||||
except Exception as e:
|
||||
inst_name = query.current_stage.inst_name if query.current_stage else 'unknown'
|
||||
self.ap.logger.error(f"处理请求时出错 query_id={query.query_id} stage={inst_name} : {e}")
|
||||
self.ap.logger.debug(f"Traceback: {traceback.format_exc()}")
|
||||
finally:
|
||||
self.ap.logger.debug(f"Query {query} processed")
|
||||
|
||||
|
||||
class PipelineManager:
|
||||
|
||||
@@ -33,7 +33,8 @@ class QueryPool:
|
||||
sender_id: typing.Union[int, str],
|
||||
message_event: platform_events.MessageEvent,
|
||||
message_chain: platform_message.MessageChain,
|
||||
adapter: msadapter.MessagePlatformAdapter
|
||||
adapter: msadapter.MessagePlatformAdapter,
|
||||
pipeline_uuid: str
|
||||
) -> entities.Query:
|
||||
async with self.condition:
|
||||
query = entities.Query(
|
||||
@@ -43,6 +44,7 @@ class QueryPool:
|
||||
sender_id=sender_id,
|
||||
message_event=message_event,
|
||||
message_chain=message_chain,
|
||||
pipeline_uuid=pipeline_uuid,
|
||||
resp_messages=[],
|
||||
resp_message_chain=[],
|
||||
adapter=adapter
|
||||
|
||||
@@ -50,6 +50,41 @@ class RuntimeBot:
|
||||
self.adapter = adapter
|
||||
self.task_context = taskmgr.TaskContext()
|
||||
|
||||
async def initialize(self):
|
||||
|
||||
async def on_friend_message(event: platform_events.FriendMessage, adapter: msadapter.MessagePlatformAdapter):
|
||||
|
||||
await self.ap.query_pool.add_query(
|
||||
launcher_type=core_entities.LauncherTypes.PERSON,
|
||||
launcher_id=event.sender.id,
|
||||
sender_id=event.sender.id,
|
||||
message_event=event,
|
||||
message_chain=event.message_chain,
|
||||
adapter=adapter,
|
||||
pipeline_uuid=self.bot_entity.use_pipeline_uuid
|
||||
)
|
||||
|
||||
async def on_group_message(event: platform_events.GroupMessage, adapter: msadapter.MessagePlatformAdapter):
|
||||
|
||||
await self.ap.query_pool.add_query(
|
||||
launcher_type=core_entities.LauncherTypes.GROUP,
|
||||
launcher_id=event.group.id,
|
||||
sender_id=event.sender.id,
|
||||
message_event=event,
|
||||
message_chain=event.message_chain,
|
||||
adapter=adapter,
|
||||
pipeline_uuid=self.bot_entity.use_pipeline_uuid
|
||||
)
|
||||
|
||||
self.adapter.register_listener(
|
||||
platform_events.FriendMessage,
|
||||
on_friend_message
|
||||
)
|
||||
self.adapter.register_listener(
|
||||
platform_events.GroupMessage,
|
||||
on_group_message
|
||||
)
|
||||
|
||||
async def run(self):
|
||||
|
||||
async def exception_wrapper():
|
||||
@@ -135,49 +170,20 @@ class PlatformManager:
|
||||
bot_entity = persistence_bot.Bot(**bot_entity._mapping)
|
||||
elif isinstance(bot_entity, dict):
|
||||
bot_entity = persistence_bot.Bot(**bot_entity)
|
||||
|
||||
async def on_friend_message(event: platform_events.FriendMessage, adapter: msadapter.MessagePlatformAdapter):
|
||||
|
||||
await self.ap.query_pool.add_query(
|
||||
launcher_type=core_entities.LauncherTypes.PERSON,
|
||||
launcher_id=event.sender.id,
|
||||
sender_id=event.sender.id,
|
||||
message_event=event,
|
||||
message_chain=event.message_chain,
|
||||
adapter=adapter
|
||||
)
|
||||
|
||||
async def on_group_message(event: platform_events.GroupMessage, adapter: msadapter.MessagePlatformAdapter):
|
||||
|
||||
await self.ap.query_pool.add_query(
|
||||
launcher_type=core_entities.LauncherTypes.GROUP,
|
||||
launcher_id=event.group.id,
|
||||
sender_id=event.sender.id,
|
||||
message_event=event,
|
||||
message_chain=event.message_chain,
|
||||
adapter=adapter
|
||||
)
|
||||
|
||||
adapter_inst = self.adapter_dict[bot_entity.adapter](
|
||||
bot_entity.adapter_config,
|
||||
self.ap
|
||||
)
|
||||
|
||||
adapter_inst.register_listener(
|
||||
platform_events.FriendMessage,
|
||||
on_friend_message
|
||||
)
|
||||
adapter_inst.register_listener(
|
||||
platform_events.GroupMessage,
|
||||
on_group_message
|
||||
)
|
||||
|
||||
runtime_bot = RuntimeBot(
|
||||
ap=self.ap,
|
||||
bot_entity=bot_entity,
|
||||
adapter=adapter_inst
|
||||
)
|
||||
|
||||
await runtime_bot.initialize()
|
||||
|
||||
self.bots.append(runtime_bot)
|
||||
|
||||
return runtime_bot
|
||||
|
||||
@@ -42,7 +42,7 @@ stages:
|
||||
zh_CN: 模型
|
||||
type: select
|
||||
required: true
|
||||
scope: llm-model
|
||||
scope: /provider/models/llm
|
||||
- name: max-round
|
||||
label:
|
||||
en_US: Max Round
|
||||
@@ -54,11 +54,9 @@ stages:
|
||||
label:
|
||||
en_US: Prompt
|
||||
zh_CN: 提示词
|
||||
type: array
|
||||
type: string
|
||||
required: true
|
||||
default: []
|
||||
items:
|
||||
type: string
|
||||
default: "You are a helpful assistant."
|
||||
- name: dify-service-api
|
||||
label:
|
||||
en_US: Dify Service API
|
||||
|
||||
@@ -46,7 +46,7 @@ stages:
|
||||
zh_CN: 窗口长度(秒)
|
||||
type: integer
|
||||
required: true
|
||||
default: 10
|
||||
default: 60
|
||||
- name: limitation
|
||||
label:
|
||||
en_US: Limitation
|
||||
|
||||
Reference in New Issue
Block a user