feat: preliminarily implement pipeline invoking

This commit is contained in:
Junyan Qin
2025-03-29 17:50:45 +08:00
parent d01eadc70f
commit 9f15ab5000
57 changed files with 384 additions and 421 deletions
+2 -2
View File
@@ -1,7 +1,7 @@
from __future__ import annotations
import re
from .. import stage, entities, stagemgr
from .. import stage, entities
from ...core import entities as core_entities
from ...config import manager as cfg_mgr
@@ -13,7 +13,7 @@ class BanSessionCheckStage(stage.PipelineStage):
仅检查query中群号或个人号是否在访问控制列表中。
"""
async def initialize(self):
async def initialize(self, pipeline_config: dict):
pass
async def process(
+2 -2
View File
@@ -2,7 +2,7 @@ from __future__ import annotations
from ...core import app
from .. import stage, entities, stagemgr
from .. import stage, entities
from ...core import entities as core_entities
from ...config import manager as cfg_mgr
from . import filter as filter_model, entities as filter_entities
@@ -35,7 +35,7 @@ class ContentFilterStage(stage.PipelineStage):
self.filter_chain = []
super().__init__(ap)
async def initialize(self):
async def initialize(self, pipeline_config: dict):
filters_required = [
"content-ignore",
+7 -3
View File
@@ -54,9 +54,13 @@ class Controller:
async def _process_query(selected_query: entities.Query):
async with self.semaphore: # 总并发上限
# find pipeline
pipeline = await self.ap.pipeline_mgr.get_pipeline_by_uuid(selected_query.pipeline_uuid)
if pipeline:
await pipeline.run(selected_query)
# Here firstly find the bot, then find the pipeline, in case the bot adapter's config is not the latest one.
# Like aiocqhttp, once a client is connected, even the adapter was updated and restarted, the existing client connection will not be affected.
bot = await self.ap.platform_mgr.get_bot_by_uuid(selected_query.bot_uuid)
if bot:
pipeline = await self.ap.pipeline_mgr.get_pipeline_by_uuid(bot.bot_entity.use_pipeline_uuid)
if pipeline:
await pipeline.run(selected_query)
async with self.ap.query_pool:
(await self.ap.sess_mgr.get_session(selected_query)).semaphore.release()
+6 -6
View File
@@ -7,7 +7,7 @@ from PIL import Image, ImageDraw, ImageFont
from ...core import app
from . import strategy
from .strategies import image, forward
from .. import stage, entities, stagemgr
from .. import stage, entities
from ...core import entities as core_entities
from ...config import manager as cfg_mgr
from ...platform.types import message as platform_message
@@ -23,8 +23,8 @@ class LongTextProcessStage(stage.PipelineStage):
strategy_impl: strategy.LongTextStrategy
async def initialize(self):
config = self.ap.platform_cfg.data['long-text-process']
async def initialize(self, pipeline_config: dict):
config = pipeline_config['output']['long-text-processing']
if config['strategy'] == 'image':
use_font = config['font-path']
try:
@@ -42,12 +42,12 @@ class LongTextProcessStage(stage.PipelineStage):
else:
self.ap.logger.warn("未找到字体文件,且无法使用系统自带字体,更换为转发消息组件以发送长消息,您可以在配置文件中调整相关设置。")
self.ap.platform_cfg.data['long-text-process']['strategy'] = "forward"
pipeline_config['output']['long-text-processing']['strategy'] = "forward"
except:
traceback.print_exc()
self.ap.logger.error("加载字体文件失败({}),更换为转发消息组件以发送长消息,您可以在配置文件中调整相关设置。".format(use_font))
self.ap.platform_cfg.data['long-text-process']['strategy'] = "forward"
pipeline_config['output']['long-text-processing']['strategy'] = "forward"
for strategy_cls in strategy.preregistered_strategies:
if strategy_cls.name == config['strategy']:
@@ -69,7 +69,7 @@ class LongTextProcessStage(stage.PipelineStage):
if contains_non_plain:
self.ap.logger.debug("消息中包含非 Plain 组件,跳过长消息处理。")
elif len(str(query.resp_message_chain[-1])) > self.ap.platform_cfg.data['long-text-process']['threshold']:
elif len(str(query.resp_message_chain[-1])) > query.pipeline_config['output']['long-text-processing']['threshold']:
query.resp_message_chain[-1] = platform_message.MessageChain(await self.strategy_impl.process(str(query.resp_message_chain[-1]), query))
return entities.StageProcessResult(
+10 -6
View File
@@ -8,6 +8,7 @@ import re
from PIL import Image, ImageDraw, ImageFont
import functools
from ....platform.types import message as platform_message
from .. import strategy as strategy_model
@@ -17,15 +18,18 @@ from ....core import entities as core_entities
@strategy_model.strategy_class("image")
class Text2ImageStrategy(strategy_model.LongTextStrategy):
text_render_font: ImageFont.FreeTypeFont
async def initialize(self):
self.text_render_font = ImageFont.truetype(self.ap.platform_cfg.data['long-text-process']['font-path'], 32, encoding="utf-8")
pass
@functools.lru_cache(maxsize=16)
def get_font(self, query: core_entities.Query):
return ImageFont.truetype(query.pipeline_config['output']['long-text-processing']['font-path'], 32, encoding="utf-8")
async def process(self, message: str, query: core_entities.Query) -> list[platform_message.MessageComponent]:
img_path = self.text_to_image(
text_str=message,
save_as='temp/{}.png'.format(int(time.time()))
save_as='temp/{}.png'.format(int(time.time())),
query=query
)
compressed_path, size = self.compress_image(
@@ -127,7 +131,7 @@ class Text2ImageStrategy(strategy_model.LongTextStrategy):
return outfile, self.get_size(outfile)
def text_to_image(self, text_str: str, save_as="temp.png", width=800):
def text_to_image(self, text_str: str, save_as="temp.png", width=800, query: core_entities.Query = None):
text_str = text_str.replace("\t", " ")
@@ -142,7 +146,7 @@ class Text2ImageStrategy(strategy_model.LongTextStrategy):
self.ap.logger.debug("lines: {}, text_width: {}".format(lines, text_width))
for line in lines:
# 如果长了就分割
line_width = self.text_render_font.getlength(line)
line_width = self.get_font(query).getlength(line)
self.ap.logger.debug("line_width: {}".format(line_width))
if line_width < text_width:
final_lines.append(line)
+2 -2
View File
@@ -1,6 +1,6 @@
from __future__ import annotations
from .. import stage, entities, stagemgr
from .. import stage, entities
from ...core import entities as core_entities
from . import truncator
from .truncators import round
@@ -14,7 +14,7 @@ class ConversationMessageTruncator(stage.PipelineStage):
"""
trun: truncator.Truncator
async def initialize(self):
async def initialize(self, pipeline_config: dict):
use_method = self.ap.pipeline_cfg.data['msg-truncate']['method']
for trun in truncator.preregistered_truncators:
+45 -9
View File
@@ -8,10 +8,35 @@ import sqlalchemy
from ..core import app, entities
from . import entities as pipeline_entities
from ..entity.persistence import pipeline as persistence_pipeline
from . import stagemgr, stage
from . import stage
from ..platform.types import message as platform_message, events as platform_events
from ..plugin import events
from .resprule import resprule
from .bansess import bansess
from .cntfilter import cntfilter
from .process import process
from .longtext import longtext
from .respback import respback
from .wrapper import wrapper
from .preproc import preproc
from .ratelimit import ratelimit
from .msgtrun import msgtrun
class StageInstContainer():
"""阶段实例容器
"""
inst_name: str
inst: stage.PipelineStage
def __init__(self, inst_name: str, inst: stage.PipelineStage):
self.inst_name = inst_name
self.inst = inst
class RuntimePipeline:
"""运行时流水线"""
@@ -20,10 +45,10 @@ class RuntimePipeline:
pipeline_entity: persistence_pipeline.LegacyPipeline
"""流水线实体"""
stage_containers: list[stagemgr.StageInstContainer]
stage_containers: list[StageInstContainer]
"""阶段实例容器"""
def __init__(self, ap: app.Application, pipeline_entity: persistence_pipeline.LegacyPipeline, stage_containers: list[stagemgr.StageInstContainer]):
def __init__(self, ap: app.Application, pipeline_entity: persistence_pipeline.LegacyPipeline, stage_containers: list[StageInstContainer]):
self.ap = ap
self.pipeline_entity = pipeline_entity
self.stage_containers = stage_containers
@@ -47,10 +72,18 @@ class RuntimePipeline:
*result.user_notice
)
await self.ap.platform_mgr.send(
query.message_event,
result.user_notice,
query.adapter
if query.pipeline_config['output']['misc']['at-sender'] and isinstance(query.message_event, platform_events.GroupMessage):
result.user_notice.insert(
0,
platform_message.At(
query.message_event.sender.id
)
)
await query.adapter.reply_message(
message_source=query.message_event,
message=result.user_notice,
quote_origin=query.pipeline_config['output']['misc']['quote-origin']
)
if result.debug_notice:
self.ap.logger.debug(result.debug_notice)
@@ -195,12 +228,15 @@ class PipelineManager:
pipeline_entity = persistence_pipeline.LegacyPipeline(**pipeline_entity)
# initialize stage containers according to pipeline_entity.stages
stage_containers = []
stage_containers: list[StageInstContainer] = []
for stage_name in pipeline_entity.stages:
stage_containers.append(stagemgr.StageInstContainer(
stage_containers.append(StageInstContainer(
inst_name=stage_name,
inst=self.stage_dict[stage_name](self.ap)
))
for stage_container in stage_containers:
await stage_container.inst.initialize(pipeline_entity.config)
runtime_pipeline = RuntimePipeline(self.ap, pipeline_entity, stage_containers)
self.pipelines.append(runtime_pipeline)
+2 -2
View File
@@ -28,23 +28,23 @@ class QueryPool:
async def add_query(
self,
bot_uuid: str,
launcher_type: entities.LauncherTypes,
launcher_id: typing.Union[int, str],
sender_id: typing.Union[int, str],
message_event: platform_events.MessageEvent,
message_chain: platform_message.MessageChain,
adapter: msadapter.MessagePlatformAdapter,
pipeline_uuid: str
) -> entities.Query:
async with self.condition:
query = entities.Query(
bot_uuid=bot_uuid,
query_id=self.query_id_counter,
launcher_type=launcher_type,
launcher_id=launcher_id,
sender_id=sender_id,
message_event=message_event,
message_chain=message_chain,
pipeline_uuid=pipeline_uuid,
resp_messages=[],
resp_message_chain=[],
adapter=adapter
+8 -7
View File
@@ -2,7 +2,7 @@ from __future__ import annotations
import datetime
from .. import stage, entities, stagemgr
from .. import stage, entities
from ...core import entities as core_entities
from ...provider import entities as llm_entities
from ...plugin import events
@@ -33,16 +33,16 @@ class PreProcessor(stage.PipelineStage):
"""
session = await self.ap.sess_mgr.get_session(query)
conversation = await self.ap.sess_mgr.get_conversation(session)
conversation = await self.ap.sess_mgr.get_conversation(query, session)
# 设置query
query.session = session
query.prompt = conversation.prompt.copy()
query.messages = conversation.messages.copy()
query.use_model = conversation.use_model
query.use_llm_model = conversation.use_llm_model
query.use_funcs = conversation.use_funcs if query.use_model.tool_call_supported else None
query.use_funcs = conversation.use_funcs if query.use_llm_model.model_entity.abilities.__contains__('tool_call') else None
query.variables = {
"session_id": f"{query.session.launcher_type.value}_{query.session.launcher_id}",
@@ -50,8 +50,9 @@ class PreProcessor(stage.PipelineStage):
"msg_create_time": int(query.message_event.time) if query.message_event.time else int(datetime.datetime.now().timestamp()),
}
# 检查vision是否启用,没启用就删除所有图片
if not self.ap.provider_cfg.data['enable-vision'] or (self.ap.provider_cfg.data['runner'] == 'local-agent' and not query.use_model.vision_supported):
# Check if this model supports vision, if not, remove all images
# TODO this checking should be performed in runner, and in this stage, the image should be reserved
if query.pipeline_config['ai']['runner']['runner'] == 'local-agent' and not query.use_llm_model.model_entity.abilities.__contains__('vision'):
for msg in query.messages:
if isinstance(msg.content, list):
for me in msg.content:
@@ -69,7 +70,7 @@ class PreProcessor(stage.PipelineStage):
)
plain_text += me.text
elif isinstance(me, platform_message.Image):
if self.ap.provider_cfg.data['enable-vision'] and (self.ap.provider_cfg.data['runner'] != 'local-agent' or query.use_model.vision_supported):
if query.pipeline_config['ai']['runner']['runner'] != 'local-agent' or query.use_llm_model.model_entity.abilities.__contains__('vision'):
if me.base64 is not None:
content_list.append(
llm_entities.ContentElement.from_image_base64(me.base64)
+12 -9
View File
@@ -9,7 +9,9 @@ import json
from .. import handler
from ... import entities
from ....core import entities as core_entities
from ....provider import entities as llm_entities, runnermgr
from ....provider import entities as llm_entities
from ....provider import runner as runner_module
from ....provider.runners import localagent, difysvapi, dashscopeapi
from ....plugin import events
from ....platform.types import message as platform_message
@@ -56,12 +58,6 @@ class ChatMessageHandler(handler.MessageHandler):
)
else:
if not self.ap.provider_cfg.data['enable-chat']:
yield entities.StageProcessResult(
result_type=entities.ResultType.INTERRUPT,
new_query=query,
)
if event_ctx.event.alter is not None:
# if isinstance(event_ctx.event, str): # 现在暂时不考虑多模态alter
query.user_message.content = event_ctx.event.alter
@@ -72,7 +68,12 @@ class ChatMessageHandler(handler.MessageHandler):
try:
runner = self.ap.runner_mgr.get_runner()
for r in runner_module.preregistered_runners:
if r.name == query.pipeline_config["ai"]["runner"]["runner"]:
runner = r(self.ap, query.pipeline_config)
break
else:
raise ValueError(f"未找到请求运行器: {query.pipeline_config['ai']['runner']['runner']}")
async for result in runner.run(query):
query.resp_messages.append(result)
@@ -93,10 +94,12 @@ class ChatMessageHandler(handler.MessageHandler):
self.ap.logger.error(f'对话({query.query_id})请求失败: {type(e).__name__} {str(e)}')
hide_exception_info = query.pipeline_config['output']['misc']['hide-exception']
yield entities.StageProcessResult(
result_type=entities.ResultType.INTERRUPT,
new_query=query,
user_notice='请求失败' if self.ap.platform_cfg.data['hide-exception-info'] else f'{e}',
user_notice='请求失败' if hide_exception_info else f'{e}',
error_notice=f'{e}',
debug_notice=traceback.format_exc()
)
+2 -2
View File
@@ -4,7 +4,7 @@ from ...core import app, entities as core_entities
from . import handler
from .handlers import chat, command
from .. import entities
from .. import stage, entities, stagemgr
from .. import stage, entities
from ...core import entities as core_entities
from ...config import manager as cfg_mgr
@@ -23,7 +23,7 @@ class Processor(stage.PipelineStage):
chat_handler: handler.MessageHandler
async def initialize(self):
async def initialize(self, pipeline_config: dict):
self.cmd_handler = command.CommandHandler(self.ap)
self.chat_handler = chat.ChatMessageHandler(self.ap)
+2 -2
View File
@@ -2,7 +2,7 @@ from __future__ import annotations
import typing
from .. import entities, stagemgr, stage
from .. import entities, stage
from . import algo
from .algos import fixedwin
from ...core import entities as core_entities
@@ -18,7 +18,7 @@ class RateLimit(stage.PipelineStage):
algo: algo.ReteLimitAlgo
async def initialize(self):
async def initialize(self, pipeline_config: dict):
algo_name = self.ap.pipeline_cfg.data['rate-limit']['algo']
+19 -7
View File
@@ -5,8 +5,10 @@ import asyncio
from ...core import app
from ...platform.types import events as platform_events
from ...platform.types import message as platform_message
from .. import stage, entities, stagemgr
from .. import stage, entities
from ...core import entities as core_entities
from ...config import manager as cfg_mgr
@@ -19,8 +21,8 @@ class SendResponseBackStage(stage.PipelineStage):
async def process(self, query: core_entities.Query, stage_inst_name: str) -> entities.StageProcessResult:
"""处理
"""
random_range = (self.ap.platform_cfg.data['force-delay']['min'], self.ap.platform_cfg.data['force-delay']['max'])
random_range = (query.pipeline_config['output']['force-delay']['min'], query.pipeline_config['output']['force-delay']['max'])
random_delay = random.uniform(*random_range)
@@ -31,10 +33,20 @@ class SendResponseBackStage(stage.PipelineStage):
await asyncio.sleep(random_delay)
await self.ap.platform_mgr.send(
query.message_event,
query.resp_message_chain[-1],
adapter=query.adapter
if query.pipeline_config['output']['misc']['at-sender'] and isinstance(query.message_event, platform_events.GroupMessage):
query.resp_message_chain[-1].insert(
0,
platform_message.At(
query.message_event.sender.id
)
)
quote_origin = query.pipeline_config['output']['misc']['quote-origin']
await query.adapter.reply_message(
message_source=query.message_event,
message=query.resp_message_chain[-1],
quote_origin=quote_origin
)
return entities.StageProcessResult(
+2 -2
View File
@@ -5,7 +5,7 @@ from ...core import app
from . import entities as rule_entities, rule
from .rules import atbot, prefix, regexp, random
from .. import stage, entities, stagemgr
from .. import stage, entities
from ...core import entities as core_entities
from ...config import manager as cfg_mgr
@@ -20,7 +20,7 @@ class GroupRespondRuleCheckStage(stage.PipelineStage):
rule_matchers: list[rule.GroupRespondRule]
"""检查器实例"""
async def initialize(self):
async def initialize(self, pipeline_config: dict):
"""初始化检查器
"""
+1 -1
View File
@@ -28,7 +28,7 @@ class PipelineStage(metaclass=abc.ABCMeta):
def __init__(self, ap: app.Application):
self.ap = ap
async def initialize(self):
async def initialize(self, pipeline_config: dict):
"""初始化
"""
pass
-71
View File
@@ -1,71 +0,0 @@
from __future__ import annotations
from ..core import app
from . import stage
from .resprule import resprule
from .bansess import bansess
from .cntfilter import cntfilter
from .process import process
from .longtext import longtext
from .respback import respback
from .wrapper import wrapper
from .preproc import preproc
from .ratelimit import ratelimit
from .msgtrun import msgtrun
# 请求处理阶段顺序
stage_order = [
"GroupRespondRuleCheckStage", # 群响应规则检查
"BanSessionCheckStage", # 封禁会话检查
"PreContentFilterStage", # 内容过滤前置阶段
"PreProcessor", # 预处理器
"ConversationMessageTruncator", # 会话消息截断器
"RequireRateLimitOccupancy", # 请求速率限制占用
"MessageProcessor", # 处理器
"ReleaseRateLimitOccupancy", # 释放速率限制占用
"PostContentFilterStage", # 内容过滤后置阶段
"ResponseWrapper", # 响应包装器
"LongTextProcessStage", # 长文本处理
"SendResponseBackStage", # 发送响应
]
class StageInstContainer():
"""阶段实例容器
"""
inst_name: str
inst: stage.PipelineStage
def __init__(self, inst_name: str, inst: stage.PipelineStage):
self.inst_name = inst_name
self.inst = inst
class StageManager:
ap: app.Application
stage_containers: list[StageInstContainer]
def __init__(self, ap: app.Application):
self.ap = ap
self.stage_containers = []
async def initialize(self):
"""初始化
"""
for name, cls in stage.preregistered_stages.items():
self.stage_containers.append(StageInstContainer(
inst_name=name,
inst=cls(self.ap)
))
for stage_containers in self.stage_containers:
await stage_containers.inst.initialize()
# 按照 stage_order 排序
self.stage_containers.sort(key=lambda x: stage_order.index(x.inst_name))
+3 -3
View File
@@ -5,7 +5,7 @@ import typing
from ...core import app, entities as core_entities
from .. import entities
from .. import stage, entities, stagemgr
from .. import stage, entities
from ...core import entities as core_entities
from ...config import manager as cfg_mgr
from ...plugin import events
@@ -22,7 +22,7 @@ class ResponseWrapper(stage.PipelineStage):
- resp_message_chain
"""
async def initialize(self):
async def initialize(self, pipeline_config: dict):
pass
async def process(
@@ -110,7 +110,7 @@ class ResponseWrapper(stage.PipelineStage):
query.resp_message_chain.append(platform_message.MessageChain([platform_message.Plain(reply_text)]))
if self.ap.platform_cfg.data['track-function-calls']:
if query.pipeline_config['output']['misc']['track-function-calls']:
event_ctx = await self.ap.plugin_mgr.emit_event(
event=events.NormalMessageResponded(