mirror of
https://github.com/langbot-app/LangBot.git
synced 2026-06-27 07:54:19 +00:00
feat: preliminarily implement pipeline invoking
This commit is contained in:
@@ -1,7 +1,7 @@
|
||||
from __future__ import annotations
|
||||
import re
|
||||
|
||||
from .. import stage, entities, stagemgr
|
||||
from .. import stage, entities
|
||||
from ...core import entities as core_entities
|
||||
from ...config import manager as cfg_mgr
|
||||
|
||||
@@ -13,7 +13,7 @@ class BanSessionCheckStage(stage.PipelineStage):
|
||||
仅检查query中群号或个人号是否在访问控制列表中。
|
||||
"""
|
||||
|
||||
async def initialize(self):
|
||||
async def initialize(self, pipeline_config: dict):
|
||||
pass
|
||||
|
||||
async def process(
|
||||
|
||||
@@ -2,7 +2,7 @@ from __future__ import annotations
|
||||
|
||||
from ...core import app
|
||||
|
||||
from .. import stage, entities, stagemgr
|
||||
from .. import stage, entities
|
||||
from ...core import entities as core_entities
|
||||
from ...config import manager as cfg_mgr
|
||||
from . import filter as filter_model, entities as filter_entities
|
||||
@@ -35,7 +35,7 @@ class ContentFilterStage(stage.PipelineStage):
|
||||
self.filter_chain = []
|
||||
super().__init__(ap)
|
||||
|
||||
async def initialize(self):
|
||||
async def initialize(self, pipeline_config: dict):
|
||||
|
||||
filters_required = [
|
||||
"content-ignore",
|
||||
|
||||
@@ -54,9 +54,13 @@ class Controller:
|
||||
async def _process_query(selected_query: entities.Query):
|
||||
async with self.semaphore: # 总并发上限
|
||||
# find pipeline
|
||||
pipeline = await self.ap.pipeline_mgr.get_pipeline_by_uuid(selected_query.pipeline_uuid)
|
||||
if pipeline:
|
||||
await pipeline.run(selected_query)
|
||||
# Here firstly find the bot, then find the pipeline, in case the bot adapter's config is not the latest one.
|
||||
# Like aiocqhttp, once a client is connected, even the adapter was updated and restarted, the existing client connection will not be affected.
|
||||
bot = await self.ap.platform_mgr.get_bot_by_uuid(selected_query.bot_uuid)
|
||||
if bot:
|
||||
pipeline = await self.ap.pipeline_mgr.get_pipeline_by_uuid(bot.bot_entity.use_pipeline_uuid)
|
||||
if pipeline:
|
||||
await pipeline.run(selected_query)
|
||||
|
||||
async with self.ap.query_pool:
|
||||
(await self.ap.sess_mgr.get_session(selected_query)).semaphore.release()
|
||||
|
||||
@@ -7,7 +7,7 @@ from PIL import Image, ImageDraw, ImageFont
|
||||
from ...core import app
|
||||
from . import strategy
|
||||
from .strategies import image, forward
|
||||
from .. import stage, entities, stagemgr
|
||||
from .. import stage, entities
|
||||
from ...core import entities as core_entities
|
||||
from ...config import manager as cfg_mgr
|
||||
from ...platform.types import message as platform_message
|
||||
@@ -23,8 +23,8 @@ class LongTextProcessStage(stage.PipelineStage):
|
||||
|
||||
strategy_impl: strategy.LongTextStrategy
|
||||
|
||||
async def initialize(self):
|
||||
config = self.ap.platform_cfg.data['long-text-process']
|
||||
async def initialize(self, pipeline_config: dict):
|
||||
config = pipeline_config['output']['long-text-processing']
|
||||
if config['strategy'] == 'image':
|
||||
use_font = config['font-path']
|
||||
try:
|
||||
@@ -42,12 +42,12 @@ class LongTextProcessStage(stage.PipelineStage):
|
||||
else:
|
||||
self.ap.logger.warn("未找到字体文件,且无法使用系统自带字体,更换为转发消息组件以发送长消息,您可以在配置文件中调整相关设置。")
|
||||
|
||||
self.ap.platform_cfg.data['long-text-process']['strategy'] = "forward"
|
||||
pipeline_config['output']['long-text-processing']['strategy'] = "forward"
|
||||
except:
|
||||
traceback.print_exc()
|
||||
self.ap.logger.error("加载字体文件失败({}),更换为转发消息组件以发送长消息,您可以在配置文件中调整相关设置。".format(use_font))
|
||||
|
||||
self.ap.platform_cfg.data['long-text-process']['strategy'] = "forward"
|
||||
pipeline_config['output']['long-text-processing']['strategy'] = "forward"
|
||||
|
||||
for strategy_cls in strategy.preregistered_strategies:
|
||||
if strategy_cls.name == config['strategy']:
|
||||
@@ -69,7 +69,7 @@ class LongTextProcessStage(stage.PipelineStage):
|
||||
|
||||
if contains_non_plain:
|
||||
self.ap.logger.debug("消息中包含非 Plain 组件,跳过长消息处理。")
|
||||
elif len(str(query.resp_message_chain[-1])) > self.ap.platform_cfg.data['long-text-process']['threshold']:
|
||||
elif len(str(query.resp_message_chain[-1])) > query.pipeline_config['output']['long-text-processing']['threshold']:
|
||||
query.resp_message_chain[-1] = platform_message.MessageChain(await self.strategy_impl.process(str(query.resp_message_chain[-1]), query))
|
||||
|
||||
return entities.StageProcessResult(
|
||||
|
||||
@@ -8,6 +8,7 @@ import re
|
||||
|
||||
from PIL import Image, ImageDraw, ImageFont
|
||||
|
||||
import functools
|
||||
from ....platform.types import message as platform_message
|
||||
|
||||
from .. import strategy as strategy_model
|
||||
@@ -17,15 +18,18 @@ from ....core import entities as core_entities
|
||||
@strategy_model.strategy_class("image")
|
||||
class Text2ImageStrategy(strategy_model.LongTextStrategy):
|
||||
|
||||
text_render_font: ImageFont.FreeTypeFont
|
||||
|
||||
async def initialize(self):
|
||||
self.text_render_font = ImageFont.truetype(self.ap.platform_cfg.data['long-text-process']['font-path'], 32, encoding="utf-8")
|
||||
pass
|
||||
|
||||
@functools.lru_cache(maxsize=16)
|
||||
def get_font(self, query: core_entities.Query):
|
||||
return ImageFont.truetype(query.pipeline_config['output']['long-text-processing']['font-path'], 32, encoding="utf-8")
|
||||
|
||||
async def process(self, message: str, query: core_entities.Query) -> list[platform_message.MessageComponent]:
|
||||
img_path = self.text_to_image(
|
||||
text_str=message,
|
||||
save_as='temp/{}.png'.format(int(time.time()))
|
||||
save_as='temp/{}.png'.format(int(time.time())),
|
||||
query=query
|
||||
)
|
||||
|
||||
compressed_path, size = self.compress_image(
|
||||
@@ -127,7 +131,7 @@ class Text2ImageStrategy(strategy_model.LongTextStrategy):
|
||||
return outfile, self.get_size(outfile)
|
||||
|
||||
|
||||
def text_to_image(self, text_str: str, save_as="temp.png", width=800):
|
||||
def text_to_image(self, text_str: str, save_as="temp.png", width=800, query: core_entities.Query = None):
|
||||
|
||||
text_str = text_str.replace("\t", " ")
|
||||
|
||||
@@ -142,7 +146,7 @@ class Text2ImageStrategy(strategy_model.LongTextStrategy):
|
||||
self.ap.logger.debug("lines: {}, text_width: {}".format(lines, text_width))
|
||||
for line in lines:
|
||||
# 如果长了就分割
|
||||
line_width = self.text_render_font.getlength(line)
|
||||
line_width = self.get_font(query).getlength(line)
|
||||
self.ap.logger.debug("line_width: {}".format(line_width))
|
||||
if line_width < text_width:
|
||||
final_lines.append(line)
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from .. import stage, entities, stagemgr
|
||||
from .. import stage, entities
|
||||
from ...core import entities as core_entities
|
||||
from . import truncator
|
||||
from .truncators import round
|
||||
@@ -14,7 +14,7 @@ class ConversationMessageTruncator(stage.PipelineStage):
|
||||
"""
|
||||
trun: truncator.Truncator
|
||||
|
||||
async def initialize(self):
|
||||
async def initialize(self, pipeline_config: dict):
|
||||
use_method = self.ap.pipeline_cfg.data['msg-truncate']['method']
|
||||
|
||||
for trun in truncator.preregistered_truncators:
|
||||
|
||||
@@ -8,10 +8,35 @@ import sqlalchemy
|
||||
from ..core import app, entities
|
||||
from . import entities as pipeline_entities
|
||||
from ..entity.persistence import pipeline as persistence_pipeline
|
||||
from . import stagemgr, stage
|
||||
from . import stage
|
||||
from ..platform.types import message as platform_message, events as platform_events
|
||||
from ..plugin import events
|
||||
|
||||
from .resprule import resprule
|
||||
from .bansess import bansess
|
||||
from .cntfilter import cntfilter
|
||||
from .process import process
|
||||
from .longtext import longtext
|
||||
from .respback import respback
|
||||
from .wrapper import wrapper
|
||||
from .preproc import preproc
|
||||
from .ratelimit import ratelimit
|
||||
from .msgtrun import msgtrun
|
||||
|
||||
|
||||
class StageInstContainer():
|
||||
"""阶段实例容器
|
||||
"""
|
||||
|
||||
inst_name: str
|
||||
|
||||
inst: stage.PipelineStage
|
||||
|
||||
def __init__(self, inst_name: str, inst: stage.PipelineStage):
|
||||
self.inst_name = inst_name
|
||||
self.inst = inst
|
||||
|
||||
|
||||
class RuntimePipeline:
|
||||
"""运行时流水线"""
|
||||
|
||||
@@ -20,10 +45,10 @@ class RuntimePipeline:
|
||||
pipeline_entity: persistence_pipeline.LegacyPipeline
|
||||
"""流水线实体"""
|
||||
|
||||
stage_containers: list[stagemgr.StageInstContainer]
|
||||
stage_containers: list[StageInstContainer]
|
||||
"""阶段实例容器"""
|
||||
|
||||
def __init__(self, ap: app.Application, pipeline_entity: persistence_pipeline.LegacyPipeline, stage_containers: list[stagemgr.StageInstContainer]):
|
||||
def __init__(self, ap: app.Application, pipeline_entity: persistence_pipeline.LegacyPipeline, stage_containers: list[StageInstContainer]):
|
||||
self.ap = ap
|
||||
self.pipeline_entity = pipeline_entity
|
||||
self.stage_containers = stage_containers
|
||||
@@ -47,10 +72,18 @@ class RuntimePipeline:
|
||||
*result.user_notice
|
||||
)
|
||||
|
||||
await self.ap.platform_mgr.send(
|
||||
query.message_event,
|
||||
result.user_notice,
|
||||
query.adapter
|
||||
if query.pipeline_config['output']['misc']['at-sender'] and isinstance(query.message_event, platform_events.GroupMessage):
|
||||
result.user_notice.insert(
|
||||
0,
|
||||
platform_message.At(
|
||||
query.message_event.sender.id
|
||||
)
|
||||
)
|
||||
|
||||
await query.adapter.reply_message(
|
||||
message_source=query.message_event,
|
||||
message=result.user_notice,
|
||||
quote_origin=query.pipeline_config['output']['misc']['quote-origin']
|
||||
)
|
||||
if result.debug_notice:
|
||||
self.ap.logger.debug(result.debug_notice)
|
||||
@@ -195,12 +228,15 @@ class PipelineManager:
|
||||
pipeline_entity = persistence_pipeline.LegacyPipeline(**pipeline_entity)
|
||||
|
||||
# initialize stage containers according to pipeline_entity.stages
|
||||
stage_containers = []
|
||||
stage_containers: list[StageInstContainer] = []
|
||||
for stage_name in pipeline_entity.stages:
|
||||
stage_containers.append(stagemgr.StageInstContainer(
|
||||
stage_containers.append(StageInstContainer(
|
||||
inst_name=stage_name,
|
||||
inst=self.stage_dict[stage_name](self.ap)
|
||||
))
|
||||
|
||||
for stage_container in stage_containers:
|
||||
await stage_container.inst.initialize(pipeline_entity.config)
|
||||
|
||||
runtime_pipeline = RuntimePipeline(self.ap, pipeline_entity, stage_containers)
|
||||
self.pipelines.append(runtime_pipeline)
|
||||
|
||||
@@ -28,23 +28,23 @@ class QueryPool:
|
||||
|
||||
async def add_query(
|
||||
self,
|
||||
bot_uuid: str,
|
||||
launcher_type: entities.LauncherTypes,
|
||||
launcher_id: typing.Union[int, str],
|
||||
sender_id: typing.Union[int, str],
|
||||
message_event: platform_events.MessageEvent,
|
||||
message_chain: platform_message.MessageChain,
|
||||
adapter: msadapter.MessagePlatformAdapter,
|
||||
pipeline_uuid: str
|
||||
) -> entities.Query:
|
||||
async with self.condition:
|
||||
query = entities.Query(
|
||||
bot_uuid=bot_uuid,
|
||||
query_id=self.query_id_counter,
|
||||
launcher_type=launcher_type,
|
||||
launcher_id=launcher_id,
|
||||
sender_id=sender_id,
|
||||
message_event=message_event,
|
||||
message_chain=message_chain,
|
||||
pipeline_uuid=pipeline_uuid,
|
||||
resp_messages=[],
|
||||
resp_message_chain=[],
|
||||
adapter=adapter
|
||||
|
||||
@@ -2,7 +2,7 @@ from __future__ import annotations
|
||||
|
||||
import datetime
|
||||
|
||||
from .. import stage, entities, stagemgr
|
||||
from .. import stage, entities
|
||||
from ...core import entities as core_entities
|
||||
from ...provider import entities as llm_entities
|
||||
from ...plugin import events
|
||||
@@ -33,16 +33,16 @@ class PreProcessor(stage.PipelineStage):
|
||||
"""
|
||||
session = await self.ap.sess_mgr.get_session(query)
|
||||
|
||||
conversation = await self.ap.sess_mgr.get_conversation(session)
|
||||
conversation = await self.ap.sess_mgr.get_conversation(query, session)
|
||||
|
||||
# 设置query
|
||||
query.session = session
|
||||
query.prompt = conversation.prompt.copy()
|
||||
query.messages = conversation.messages.copy()
|
||||
|
||||
query.use_model = conversation.use_model
|
||||
query.use_llm_model = conversation.use_llm_model
|
||||
|
||||
query.use_funcs = conversation.use_funcs if query.use_model.tool_call_supported else None
|
||||
query.use_funcs = conversation.use_funcs if query.use_llm_model.model_entity.abilities.__contains__('tool_call') else None
|
||||
|
||||
query.variables = {
|
||||
"session_id": f"{query.session.launcher_type.value}_{query.session.launcher_id}",
|
||||
@@ -50,8 +50,9 @@ class PreProcessor(stage.PipelineStage):
|
||||
"msg_create_time": int(query.message_event.time) if query.message_event.time else int(datetime.datetime.now().timestamp()),
|
||||
}
|
||||
|
||||
# 检查vision是否启用,没启用就删除所有图片
|
||||
if not self.ap.provider_cfg.data['enable-vision'] or (self.ap.provider_cfg.data['runner'] == 'local-agent' and not query.use_model.vision_supported):
|
||||
# Check if this model supports vision, if not, remove all images
|
||||
# TODO this checking should be performed in runner, and in this stage, the image should be reserved
|
||||
if query.pipeline_config['ai']['runner']['runner'] == 'local-agent' and not query.use_llm_model.model_entity.abilities.__contains__('vision'):
|
||||
for msg in query.messages:
|
||||
if isinstance(msg.content, list):
|
||||
for me in msg.content:
|
||||
@@ -69,7 +70,7 @@ class PreProcessor(stage.PipelineStage):
|
||||
)
|
||||
plain_text += me.text
|
||||
elif isinstance(me, platform_message.Image):
|
||||
if self.ap.provider_cfg.data['enable-vision'] and (self.ap.provider_cfg.data['runner'] != 'local-agent' or query.use_model.vision_supported):
|
||||
if query.pipeline_config['ai']['runner']['runner'] != 'local-agent' or query.use_llm_model.model_entity.abilities.__contains__('vision'):
|
||||
if me.base64 is not None:
|
||||
content_list.append(
|
||||
llm_entities.ContentElement.from_image_base64(me.base64)
|
||||
|
||||
@@ -9,7 +9,9 @@ import json
|
||||
from .. import handler
|
||||
from ... import entities
|
||||
from ....core import entities as core_entities
|
||||
from ....provider import entities as llm_entities, runnermgr
|
||||
from ....provider import entities as llm_entities
|
||||
from ....provider import runner as runner_module
|
||||
from ....provider.runners import localagent, difysvapi, dashscopeapi
|
||||
from ....plugin import events
|
||||
|
||||
from ....platform.types import message as platform_message
|
||||
@@ -56,12 +58,6 @@ class ChatMessageHandler(handler.MessageHandler):
|
||||
)
|
||||
else:
|
||||
|
||||
if not self.ap.provider_cfg.data['enable-chat']:
|
||||
yield entities.StageProcessResult(
|
||||
result_type=entities.ResultType.INTERRUPT,
|
||||
new_query=query,
|
||||
)
|
||||
|
||||
if event_ctx.event.alter is not None:
|
||||
# if isinstance(event_ctx.event, str): # 现在暂时不考虑多模态alter
|
||||
query.user_message.content = event_ctx.event.alter
|
||||
@@ -72,7 +68,12 @@ class ChatMessageHandler(handler.MessageHandler):
|
||||
|
||||
try:
|
||||
|
||||
runner = self.ap.runner_mgr.get_runner()
|
||||
for r in runner_module.preregistered_runners:
|
||||
if r.name == query.pipeline_config["ai"]["runner"]["runner"]:
|
||||
runner = r(self.ap, query.pipeline_config)
|
||||
break
|
||||
else:
|
||||
raise ValueError(f"未找到请求运行器: {query.pipeline_config['ai']['runner']['runner']}")
|
||||
|
||||
async for result in runner.run(query):
|
||||
query.resp_messages.append(result)
|
||||
@@ -93,10 +94,12 @@ class ChatMessageHandler(handler.MessageHandler):
|
||||
|
||||
self.ap.logger.error(f'对话({query.query_id})请求失败: {type(e).__name__} {str(e)}')
|
||||
|
||||
hide_exception_info = query.pipeline_config['output']['misc']['hide-exception']
|
||||
|
||||
yield entities.StageProcessResult(
|
||||
result_type=entities.ResultType.INTERRUPT,
|
||||
new_query=query,
|
||||
user_notice='请求失败' if self.ap.platform_cfg.data['hide-exception-info'] else f'{e}',
|
||||
user_notice='请求失败' if hide_exception_info else f'{e}',
|
||||
error_notice=f'{e}',
|
||||
debug_notice=traceback.format_exc()
|
||||
)
|
||||
|
||||
@@ -4,7 +4,7 @@ from ...core import app, entities as core_entities
|
||||
from . import handler
|
||||
from .handlers import chat, command
|
||||
from .. import entities
|
||||
from .. import stage, entities, stagemgr
|
||||
from .. import stage, entities
|
||||
from ...core import entities as core_entities
|
||||
from ...config import manager as cfg_mgr
|
||||
|
||||
@@ -23,7 +23,7 @@ class Processor(stage.PipelineStage):
|
||||
|
||||
chat_handler: handler.MessageHandler
|
||||
|
||||
async def initialize(self):
|
||||
async def initialize(self, pipeline_config: dict):
|
||||
self.cmd_handler = command.CommandHandler(self.ap)
|
||||
self.chat_handler = chat.ChatMessageHandler(self.ap)
|
||||
|
||||
|
||||
@@ -2,7 +2,7 @@ from __future__ import annotations
|
||||
|
||||
import typing
|
||||
|
||||
from .. import entities, stagemgr, stage
|
||||
from .. import entities, stage
|
||||
from . import algo
|
||||
from .algos import fixedwin
|
||||
from ...core import entities as core_entities
|
||||
@@ -18,7 +18,7 @@ class RateLimit(stage.PipelineStage):
|
||||
|
||||
algo: algo.ReteLimitAlgo
|
||||
|
||||
async def initialize(self):
|
||||
async def initialize(self, pipeline_config: dict):
|
||||
|
||||
algo_name = self.ap.pipeline_cfg.data['rate-limit']['algo']
|
||||
|
||||
|
||||
@@ -5,8 +5,10 @@ import asyncio
|
||||
|
||||
|
||||
from ...core import app
|
||||
from ...platform.types import events as platform_events
|
||||
from ...platform.types import message as platform_message
|
||||
|
||||
from .. import stage, entities, stagemgr
|
||||
from .. import stage, entities
|
||||
from ...core import entities as core_entities
|
||||
from ...config import manager as cfg_mgr
|
||||
|
||||
@@ -19,8 +21,8 @@ class SendResponseBackStage(stage.PipelineStage):
|
||||
async def process(self, query: core_entities.Query, stage_inst_name: str) -> entities.StageProcessResult:
|
||||
"""处理
|
||||
"""
|
||||
|
||||
random_range = (self.ap.platform_cfg.data['force-delay']['min'], self.ap.platform_cfg.data['force-delay']['max'])
|
||||
|
||||
random_range = (query.pipeline_config['output']['force-delay']['min'], query.pipeline_config['output']['force-delay']['max'])
|
||||
|
||||
random_delay = random.uniform(*random_range)
|
||||
|
||||
@@ -31,10 +33,20 @@ class SendResponseBackStage(stage.PipelineStage):
|
||||
|
||||
await asyncio.sleep(random_delay)
|
||||
|
||||
await self.ap.platform_mgr.send(
|
||||
query.message_event,
|
||||
query.resp_message_chain[-1],
|
||||
adapter=query.adapter
|
||||
if query.pipeline_config['output']['misc']['at-sender'] and isinstance(query.message_event, platform_events.GroupMessage):
|
||||
query.resp_message_chain[-1].insert(
|
||||
0,
|
||||
platform_message.At(
|
||||
query.message_event.sender.id
|
||||
)
|
||||
)
|
||||
|
||||
quote_origin = query.pipeline_config['output']['misc']['quote-origin']
|
||||
|
||||
await query.adapter.reply_message(
|
||||
message_source=query.message_event,
|
||||
message=query.resp_message_chain[-1],
|
||||
quote_origin=quote_origin
|
||||
)
|
||||
|
||||
return entities.StageProcessResult(
|
||||
|
||||
@@ -5,7 +5,7 @@ from ...core import app
|
||||
from . import entities as rule_entities, rule
|
||||
from .rules import atbot, prefix, regexp, random
|
||||
|
||||
from .. import stage, entities, stagemgr
|
||||
from .. import stage, entities
|
||||
from ...core import entities as core_entities
|
||||
from ...config import manager as cfg_mgr
|
||||
|
||||
@@ -20,7 +20,7 @@ class GroupRespondRuleCheckStage(stage.PipelineStage):
|
||||
rule_matchers: list[rule.GroupRespondRule]
|
||||
"""检查器实例"""
|
||||
|
||||
async def initialize(self):
|
||||
async def initialize(self, pipeline_config: dict):
|
||||
"""初始化检查器
|
||||
"""
|
||||
|
||||
|
||||
@@ -28,7 +28,7 @@ class PipelineStage(metaclass=abc.ABCMeta):
|
||||
def __init__(self, ap: app.Application):
|
||||
self.ap = ap
|
||||
|
||||
async def initialize(self):
|
||||
async def initialize(self, pipeline_config: dict):
|
||||
"""初始化
|
||||
"""
|
||||
pass
|
||||
|
||||
@@ -1,71 +0,0 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from ..core import app
|
||||
from . import stage
|
||||
from .resprule import resprule
|
||||
from .bansess import bansess
|
||||
from .cntfilter import cntfilter
|
||||
from .process import process
|
||||
from .longtext import longtext
|
||||
from .respback import respback
|
||||
from .wrapper import wrapper
|
||||
from .preproc import preproc
|
||||
from .ratelimit import ratelimit
|
||||
from .msgtrun import msgtrun
|
||||
|
||||
|
||||
# 请求处理阶段顺序
|
||||
stage_order = [
|
||||
"GroupRespondRuleCheckStage", # 群响应规则检查
|
||||
"BanSessionCheckStage", # 封禁会话检查
|
||||
"PreContentFilterStage", # 内容过滤前置阶段
|
||||
"PreProcessor", # 预处理器
|
||||
"ConversationMessageTruncator", # 会话消息截断器
|
||||
"RequireRateLimitOccupancy", # 请求速率限制占用
|
||||
"MessageProcessor", # 处理器
|
||||
"ReleaseRateLimitOccupancy", # 释放速率限制占用
|
||||
"PostContentFilterStage", # 内容过滤后置阶段
|
||||
"ResponseWrapper", # 响应包装器
|
||||
"LongTextProcessStage", # 长文本处理
|
||||
"SendResponseBackStage", # 发送响应
|
||||
]
|
||||
|
||||
|
||||
class StageInstContainer():
|
||||
"""阶段实例容器
|
||||
"""
|
||||
|
||||
inst_name: str
|
||||
|
||||
inst: stage.PipelineStage
|
||||
|
||||
def __init__(self, inst_name: str, inst: stage.PipelineStage):
|
||||
self.inst_name = inst_name
|
||||
self.inst = inst
|
||||
|
||||
|
||||
class StageManager:
|
||||
ap: app.Application
|
||||
|
||||
stage_containers: list[StageInstContainer]
|
||||
|
||||
def __init__(self, ap: app.Application):
|
||||
self.ap = ap
|
||||
|
||||
self.stage_containers = []
|
||||
|
||||
async def initialize(self):
|
||||
"""初始化
|
||||
"""
|
||||
|
||||
for name, cls in stage.preregistered_stages.items():
|
||||
self.stage_containers.append(StageInstContainer(
|
||||
inst_name=name,
|
||||
inst=cls(self.ap)
|
||||
))
|
||||
|
||||
for stage_containers in self.stage_containers:
|
||||
await stage_containers.inst.initialize()
|
||||
|
||||
# 按照 stage_order 排序
|
||||
self.stage_containers.sort(key=lambda x: stage_order.index(x.inst_name))
|
||||
@@ -5,7 +5,7 @@ import typing
|
||||
|
||||
from ...core import app, entities as core_entities
|
||||
from .. import entities
|
||||
from .. import stage, entities, stagemgr
|
||||
from .. import stage, entities
|
||||
from ...core import entities as core_entities
|
||||
from ...config import manager as cfg_mgr
|
||||
from ...plugin import events
|
||||
@@ -22,7 +22,7 @@ class ResponseWrapper(stage.PipelineStage):
|
||||
- resp_message_chain
|
||||
"""
|
||||
|
||||
async def initialize(self):
|
||||
async def initialize(self, pipeline_config: dict):
|
||||
pass
|
||||
|
||||
async def process(
|
||||
@@ -110,7 +110,7 @@ class ResponseWrapper(stage.PipelineStage):
|
||||
|
||||
query.resp_message_chain.append(platform_message.MessageChain([platform_message.Plain(reply_text)]))
|
||||
|
||||
if self.ap.platform_cfg.data['track-function-calls']:
|
||||
if query.pipeline_config['output']['misc']['track-function-calls']:
|
||||
|
||||
event_ctx = await self.ap.plugin_mgr.emit_event(
|
||||
event=events.NormalMessageResponded(
|
||||
|
||||
Reference in New Issue
Block a user