chore: adjust dir structure

This commit is contained in:
Junyan Qin
2025-11-16 16:28:04 +08:00
parent c5aa5be4d8
commit 75edeb7a01
451 changed files with 299 additions and 525 deletions

View File

View File

@@ -0,0 +1,51 @@
from __future__ import annotations
from .. import stage, entities
import langbot_plugin.api.entities.builtin.pipeline.query as pipeline_query
@stage.stage_class('BanSessionCheckStage')
class BanSessionCheckStage(stage.PipelineStage):
"""Access control processing stage
Only check if the group or personal number in the query is in the access control list.
"""
async def initialize(self, pipeline_config: dict):
pass
async def process(self, query: pipeline_query.Query, stage_inst_name: str) -> entities.StageProcessResult:
found = False
mode = query.pipeline_config['trigger']['access-control']['mode']
sess_list = query.pipeline_config['trigger']['access-control'][mode]
if (query.launcher_type.value == 'group' and 'group_*' in sess_list) or (
query.launcher_type.value == 'person' and 'person_*' in sess_list
):
found = True
else:
for sess in sess_list:
if sess == f'{query.launcher_type.value}_{query.launcher_id}':
found = True
break
# 使用 *_id 来表示加白/拉黑某用户的私聊和群聊场景
if sess.startswith('*_') and (sess[2:] == query.launcher_id or sess[2:] == query.sender_id):
found = True
break
ctn = False
if mode == 'whitelist':
ctn = found
else:
ctn = not found
return entities.StageProcessResult(
result_type=entities.ResultType.CONTINUE if ctn else entities.ResultType.INTERRUPT,
new_query=query,
console_notice=f'Ignore message according to access control: {query.launcher_type.value}_{query.launcher_id}'
if not ctn
else '',
)

View File

@@ -0,0 +1,154 @@
from __future__ import annotations
from ...core import app
from .. import stage, entities
from . import filter as filter_model, entities as filter_entities
from langbot_plugin.api.entities.builtin.provider import message as provider_message
import langbot_plugin.api.entities.builtin.platform.message as platform_message
from ...utils import importutil
import langbot_plugin.api.entities.builtin.pipeline.query as pipeline_query
from . import filters
importutil.import_modules_in_pkg(filters)
@stage.stage_class('PostContentFilterStage')
@stage.stage_class('PreContentFilterStage')
class ContentFilterStage(stage.PipelineStage):
"""内容过滤阶段
前置:
检查消息是否符合规则,不符合则拦截。
改写:
message_chain
后置:
检查AI回复消息是否符合规则可能进行改写不符合则拦截。
改写:
query.resp_messages
"""
filter_chain: list[filter_model.ContentFilter]
def __init__(self, ap: app.Application):
self.filter_chain = []
super().__init__(ap)
async def initialize(self, pipeline_config: dict):
filters_required = [
'content-ignore',
]
if pipeline_config['safety']['content-filter']['check-sensitive-words']:
filters_required.append('ban-word-filter')
# TODO revert it
# if self.ap.pipeline_cfg.data['baidu-cloud-examine']['enable']:
# filters_required.append("baidu-cloud-examine")
for filter in filter_model.preregistered_filters:
if filter.name in filters_required:
self.filter_chain.append(filter(self.ap))
for filter in self.filter_chain:
await filter.initialize()
async def _pre_process(
self,
message: str,
query: pipeline_query.Query,
) -> entities.StageProcessResult:
"""请求llm前处理消息
只要有一个不通过就不放行,只放行 PASS 的消息
"""
if query.pipeline_config['safety']['content-filter']['scope'] == 'output-msg':
return entities.StageProcessResult(result_type=entities.ResultType.CONTINUE, new_query=query)
if not message.strip():
return entities.StageProcessResult(result_type=entities.ResultType.CONTINUE, new_query=query)
else:
for filter in self.filter_chain:
if filter_entities.EnableStage.PRE in filter.enable_stages:
result = await filter.process(query, message)
if result.level in [
filter_entities.ResultLevel.BLOCK,
filter_entities.ResultLevel.MASKED,
]:
return entities.StageProcessResult(
result_type=entities.ResultType.INTERRUPT,
new_query=query,
user_notice=result.user_notice,
console_notice=result.console_notice,
)
elif result.level == filter_entities.ResultLevel.PASS: # 传到下一个
message = result.replacement
query.message_chain = platform_message.MessageChain([platform_message.Plain(text=message)])
return entities.StageProcessResult(result_type=entities.ResultType.CONTINUE, new_query=query)
async def _post_process(
self,
message: str,
query: pipeline_query.Query,
) -> entities.StageProcessResult:
"""请求llm后处理响应
只要是 PASS 或者 MASKED 的就通过此 filter将其 replacement 设置为message进入下一个 filter
"""
if query.pipeline_config['safety']['content-filter']['scope'] == 'income-msg':
return entities.StageProcessResult(result_type=entities.ResultType.CONTINUE, new_query=query)
else:
message = message.strip()
for filter in self.filter_chain:
if filter_entities.EnableStage.POST in filter.enable_stages:
result = await filter.process(query, message)
if result.level == filter_entities.ResultLevel.BLOCK:
return entities.StageProcessResult(
result_type=entities.ResultType.INTERRUPT,
new_query=query,
user_notice=result.user_notice,
console_notice=result.console_notice,
)
elif result.level in [
filter_entities.ResultLevel.PASS,
filter_entities.ResultLevel.MASKED,
]:
message = result.replacement
query.resp_messages[-1].content = message
return entities.StageProcessResult(result_type=entities.ResultType.CONTINUE, new_query=query)
async def process(self, query: pipeline_query.Query, stage_inst_name: str) -> entities.StageProcessResult:
"""处理"""
if stage_inst_name == 'PreContentFilterStage':
contain_non_text = False
text_components = [platform_message.Plain, platform_message.Source]
for me in query.message_chain:
if type(me) not in text_components:
contain_non_text = True
break
if contain_non_text:
self.ap.logger.debug('消息中包含非文本消息,跳过内容过滤器检查。')
return entities.StageProcessResult(result_type=entities.ResultType.CONTINUE, new_query=query)
return await self._pre_process(str(query.message_chain).strip(), query)
elif stage_inst_name == 'PostContentFilterStage':
# 仅处理 query.resp_messages[-1].content 是 str 的情况
if isinstance(query.resp_messages[-1], provider_message.Message) and isinstance(
query.resp_messages[-1].content, str
):
return await self._post_process(query.resp_messages[-1].content, query)
else:
self.ap.logger.debug(
'resp_messages[-1] 不是 Message 类型或 query.resp_messages[-1].content 不是 str 类型,跳过内容过滤器检查。'
)
return entities.StageProcessResult(result_type=entities.ResultType.CONTINUE, new_query=query)
else:
raise ValueError(f'未知的 stage_inst_name: {stage_inst_name}')

View File

@@ -0,0 +1,74 @@
import enum
import pydantic
class ResultLevel(enum.Enum):
"""结果等级"""
PASS = enum.auto()
"""通过"""
WARN = enum.auto()
"""警告"""
MASKED = enum.auto()
"""已掩去"""
BLOCK = enum.auto()
"""阻止"""
class EnableStage(enum.Enum):
"""启用阶段"""
PRE = enum.auto()
"""预处理"""
POST = enum.auto()
"""后处理"""
class FilterResult(pydantic.BaseModel):
level: ResultLevel
"""结果等级
对于前置处理阶段,只要有任意一个返回 非PASS 的内容过滤器结果,就会中断处理。
对于后置处理阶段,当且内容过滤器返回 BLOCK 时,会中断处理。
"""
replacement: str
"""替换后的文本消息
内容过滤器可以进行一些遮掩处理,然后把遮掩后的消息返回。
若没有修改内容,也需要返回原消息。
"""
user_notice: str
"""不通过时,若此值不为空,将对用户提示消息"""
console_notice: str
"""不通过时,若此值不为空,将在控制台提示消息"""
class ManagerResultLevel(enum.Enum):
"""处理器结果等级"""
CONTINUE = enum.auto()
"""继续"""
INTERRUPT = enum.auto()
"""中断"""
class FilterManagerResult(pydantic.BaseModel):
level: ManagerResultLevel
replacement: str
"""替换后的消息"""
user_notice: str
"""用户提示消息"""
console_notice: str
"""控制台提示消息"""

View File

@@ -0,0 +1,76 @@
# 内容过滤器的抽象类
from __future__ import annotations
import abc
import typing
from ...core import app
from . import entities
import langbot_plugin.api.entities.builtin.pipeline.query as pipeline_query
preregistered_filters: list[typing.Type[ContentFilter]] = []
def filter_class(
name: str,
) -> typing.Callable[[typing.Type[ContentFilter]], typing.Type[ContentFilter]]:
"""Content filter class decorator
Args:
name (str): Filter name
Returns:
typing.Callable[[typing.Type[ContentFilter]], typing.Type[ContentFilter]]: Decorator
"""
def decorator(cls: typing.Type[ContentFilter]) -> typing.Type[ContentFilter]:
assert issubclass(cls, ContentFilter)
cls.name = name
preregistered_filters.append(cls)
return cls
return decorator
class ContentFilter(metaclass=abc.ABCMeta):
"""Content filter abstract class"""
name: str
ap: app.Application
def __init__(self, ap: app.Application):
self.ap = ap
@property
def enable_stages(self):
"""Enabled stages
Default is the two stages before and after the message request to AI.
entity.EnableStage.PRE: Before message request to AI, the content to check is the user's input message.
entity.EnableStage.POST: After message request to AI, the content to check is the AI's reply message.
"""
return [entities.EnableStage.PRE, entities.EnableStage.POST]
async def initialize(self):
"""Initialize filter"""
pass
@abc.abstractmethod
async def process(self, query: pipeline_query.Query, message: str = None, image_url=None) -> entities.FilterResult:
"""处理消息
It is divided into two stages, depending on the value of enable_stages.
For content filters, you do not need to consider the stage of the message, you only need to check the message content.
Args:
message (str): Content to check
image_url (str): URL of the image to check
Returns:
entities.FilterResult: Filter result, please refer to the documentation of entities.FilterResult class
"""
raise NotImplementedError

View File

@@ -0,0 +1,64 @@
from __future__ import annotations
import aiohttp
from .. import entities
from .. import filter as filter_model
import langbot_plugin.api.entities.builtin.pipeline.query as pipeline_query
BAIDU_EXAMINE_URL = 'https://aip.baidubce.com/rest/2.0/solution/v1/text_censor/v2/user_defined?access_token={}'
BAIDU_EXAMINE_TOKEN_URL = 'https://aip.baidubce.com/oauth/2.0/token'
@filter_model.filter_class('baidu-cloud-examine')
class BaiduCloudExamine(filter_model.ContentFilter):
"""百度云内容审核"""
async def _get_token(self) -> str:
async with aiohttp.ClientSession() as session:
async with session.post(
BAIDU_EXAMINE_TOKEN_URL,
params={
'grant_type': 'client_credentials',
'client_id': self.ap.pipeline_cfg.data['baidu-cloud-examine']['api-key'],
'client_secret': self.ap.pipeline_cfg.data['baidu-cloud-examine']['api-secret'],
},
) as resp:
return (await resp.json())['access_token']
async def process(self, query: pipeline_query.Query, message: str) -> entities.FilterResult:
async with aiohttp.ClientSession() as session:
async with session.post(
BAIDU_EXAMINE_URL.format(await self._get_token()),
headers={
'Content-Type': 'application/x-www-form-urlencoded',
'Accept': 'application/json',
},
data=f'text={message}'.encode('utf-8'),
) as resp:
result = await resp.json()
if 'error_code' in result:
return entities.FilterResult(
level=entities.ResultLevel.BLOCK,
replacement=message,
user_notice='',
console_notice=f'百度云判定出错,错误信息:{result["error_msg"]}',
)
else:
conclusion = result['conclusion']
if conclusion in ('合规'):
return entities.FilterResult(
level=entities.ResultLevel.PASS,
replacement=message,
user_notice='',
console_notice=f'百度云判定结果:{conclusion}',
)
else:
return entities.FilterResult(
level=entities.ResultLevel.BLOCK,
replacement=message,
user_notice='消息中存在不合适的内容, 请修改',
console_notice=f'百度云判定结果:{conclusion}',
)

View File

@@ -0,0 +1,39 @@
from __future__ import annotations
import re
from .. import filter as filter_model
from .. import entities
import langbot_plugin.api.entities.builtin.pipeline.query as pipeline_query
@filter_model.filter_class('ban-word-filter')
class BanWordFilter(filter_model.ContentFilter):
"""Filter content"""
async def initialize(self):
pass
async def process(self, query: pipeline_query.Query, message: str) -> entities.FilterResult:
found = False
for word in self.ap.sensitive_meta.data['words']:
match = re.findall(word, message)
if len(match) > 0:
found = True
for i in range(len(match)):
if self.ap.sensitive_meta.data['mask_word'] == '':
message = message.replace(
match[i],
self.ap.sensitive_meta.data['mask'] * len(match[i]),
)
else:
message = message.replace(match[i], self.ap.sensitive_meta.data['mask_word'])
return entities.FilterResult(
level=entities.ResultLevel.MASKED if found else entities.ResultLevel.PASS,
replacement=message,
user_notice='消息中存在不合适的内容, 请修改' if found else '',
console_notice='',
)

View File

@@ -0,0 +1,45 @@
from __future__ import annotations
import re
from .. import entities
from .. import filter as filter_model
import langbot_plugin.api.entities.builtin.pipeline.query as pipeline_query
@filter_model.filter_class('content-ignore')
class ContentIgnore(filter_model.ContentFilter):
"""Ignore message according to content"""
@property
def enable_stages(self):
return [
entities.EnableStage.PRE,
]
async def process(self, query: pipeline_query.Query, message: str) -> entities.FilterResult:
if 'prefix' in query.pipeline_config['trigger']['ignore-rules']:
for rule in query.pipeline_config['trigger']['ignore-rules']['prefix']:
if message.startswith(rule):
return entities.FilterResult(
level=entities.ResultLevel.BLOCK,
replacement='',
user_notice='',
console_notice='Ignore message according to prefix rule in ignore_rules',
)
if 'regexp' in query.pipeline_config['trigger']['ignore-rules']:
for rule in query.pipeline_config['trigger']['ignore-rules']['regexp']:
if re.search(rule, message):
return entities.FilterResult(
level=entities.ResultLevel.BLOCK,
replacement='',
user_notice='',
console_notice='Ignore message according to regexp rule in ignore_rules',
)
return entities.FilterResult(
level=entities.ResultLevel.PASS,
replacement=message,
user_notice='',
console_notice='',
)

View File

@@ -0,0 +1,86 @@
from __future__ import annotations
import asyncio
import traceback
from ..core import app
from ..core import entities as core_entities
import langbot_plugin.api.entities.builtin.pipeline.query as pipeline_query
class Controller:
"""总控制器"""
ap: app.Application
semaphore: asyncio.Semaphore = None
"""请求并发控制信号量"""
def __init__(self, ap: app.Application):
self.ap = ap
self.semaphore = asyncio.Semaphore(self.ap.instance_config.data['concurrency']['pipeline'])
async def consumer(self):
"""事件处理循环"""
try:
while True:
selected_query: pipeline_query.Query = None
# 取请求
async with self.ap.query_pool:
queries: list[pipeline_query.Query] = self.ap.query_pool.queries
for query in queries:
session = await self.ap.sess_mgr.get_session(query)
self.ap.logger.debug(f'Checking query {query} session {session}')
if not session._semaphore.locked():
selected_query = query
await session._semaphore.acquire()
break
if selected_query: # 找到了
queries.remove(selected_query)
else: # 没找到 说明:没有请求 或者 所有query对应的session都已达到并发上限
await self.ap.query_pool.condition.wait()
continue
if selected_query:
async def _process_query(selected_query: pipeline_query.Query):
async with self.semaphore: # 总并发上限
# find pipeline
# Here firstly find the bot, then find the pipeline, in case the bot adapter's config is not the latest one.
# Like aiocqhttp, once a client is connected, even the adapter was updated and restarted, the existing client connection will not be affected.
pipeline_uuid = selected_query.pipeline_uuid
if pipeline_uuid:
pipeline = await self.ap.pipeline_mgr.get_pipeline_by_uuid(pipeline_uuid)
if pipeline:
await pipeline.run(selected_query)
async with self.ap.query_pool:
(await self.ap.sess_mgr.get_session(selected_query))._semaphore.release()
# 通知其他协程,有新的请求可以处理了
self.ap.query_pool.condition.notify_all()
self.ap.task_mgr.create_task(
_process_query(selected_query),
kind='query',
name=f'query-{selected_query.query_id}',
scopes=[
core_entities.LifecycleControlScope.APPLICATION,
core_entities.LifecycleControlScope.PLATFORM,
],
)
except Exception as e:
# traceback.print_exc()
self.ap.logger.error(f'控制器循环出错: {e}')
self.ap.logger.error(f'Traceback: {traceback.format_exc()}')
async def run(self):
"""运行控制器"""
await self.consumer()

View File

@@ -0,0 +1,40 @@
from __future__ import annotations
import enum
import typing
import pydantic
import langbot_plugin.api.entities.builtin.pipeline.query as pipeline_query
import langbot_plugin.api.entities.builtin.platform.message as platform_message
class ResultType(enum.Enum):
CONTINUE = enum.auto()
"""继续流水线"""
INTERRUPT = enum.auto()
"""中断流水线"""
class StageProcessResult(pydantic.BaseModel):
result_type: ResultType
new_query: pipeline_query.Query
user_notice: typing.Optional[
typing.Union[
str,
list[platform_message.MessageComponent],
platform_message.MessageChain,
None,
]
] = []
"""只要设置了就会发送给用户"""
console_notice: typing.Optional[str] = ''
"""只要设置了就会输出到控制台"""
debug_notice: typing.Optional[str] = ''
error_notice: typing.Optional[str] = ''

View File

@@ -0,0 +1,97 @@
from __future__ import annotations
import os
import traceback
from . import strategy
from .. import stage, entities
import langbot_plugin.api.entities.builtin.platform.message as platform_message
from ...utils import importutil
import langbot_plugin.api.entities.builtin.pipeline.query as pipeline_query
from . import strategies
importutil.import_modules_in_pkg(strategies)
@stage.stage_class('LongTextProcessStage')
class LongTextProcessStage(stage.PipelineStage):
"""Long message processing stage
Rewrite:
- resp_message_chain
"""
strategy_impl: strategy.LongTextStrategy | None
async def initialize(self, pipeline_config: dict):
config = pipeline_config['output']['long-text-processing']
if config['strategy'] == 'none':
self.strategy_impl = None
return
if config['strategy'] == 'image':
use_font = config['font-path']
try:
# 检查是否存在
if not os.path.exists(use_font):
# 若是windows系统使用微软雅黑
if os.name == 'nt':
use_font = 'C:/Windows/Fonts/msyh.ttc'
if not os.path.exists(use_font):
self.ap.logger.warn(
'Font file not found, and Windows system font cannot be used, switch to forward message component to send long messages, you can adjust the related settings in the configuration file.'
)
config['blob_message_strategy'] = 'forward'
else:
self.ap.logger.info('Using Windows system font: ' + use_font)
config['font-path'] = use_font
else:
self.ap.logger.warn(
'Font file not found, and system font cannot be used, switch to forward message component to send long messages, you can adjust the related settings in the configuration file.'
)
pipeline_config['output']['long-text-processing']['strategy'] = 'forward'
except Exception:
traceback.print_exc()
self.ap.logger.error(
'Failed to load font file ({}), switch to forward message component to send long messages, you can adjust the related settings in the configuration file.'.format(
use_font
)
)
pipeline_config['output']['long-text-processing']['strategy'] = 'forward'
for strategy_cls in strategy.preregistered_strategies:
if strategy_cls.name == config['strategy']:
self.strategy_impl = strategy_cls(self.ap)
break
else:
raise ValueError(f'Long message processing strategy not found: {config["strategy"]}')
await self.strategy_impl.initialize()
async def process(self, query: pipeline_query.Query, stage_inst_name: str) -> entities.StageProcessResult:
if self.strategy_impl is None:
self.ap.logger.debug('Long message processing strategy is not set, skip long message processing.')
return entities.StageProcessResult(result_type=entities.ResultType.CONTINUE, new_query=query)
# 检查是否包含非 Plain 组件
contains_non_plain = False
for msg in query.resp_message_chain[-1]:
if not isinstance(msg, platform_message.Plain):
contains_non_plain = True
break
if contains_non_plain:
self.ap.logger.debug('Message contains non-Plain components, skip long message processing.')
elif (
len(str(query.resp_message_chain[-1]))
> query.pipeline_config['output']['long-text-processing']['threshold']
):
query.resp_message_chain[-1] = platform_message.MessageChain(
await self.strategy_impl.process(str(query.resp_message_chain[-1]), query)
)
return entities.StageProcessResult(result_type=entities.ResultType.CONTINUE, new_query=query)

View File

@@ -0,0 +1,35 @@
# 转发消息组件
from __future__ import annotations
from .. import strategy as strategy_model
import langbot_plugin.api.entities.builtin.pipeline.query as pipeline_query
import langbot_plugin.api.entities.builtin.platform.message as platform_message
ForwardMessageDiaplay = platform_message.ForwardMessageDiaplay
Forward = platform_message.Forward
@strategy_model.strategy_class('forward')
class ForwardComponentStrategy(strategy_model.LongTextStrategy):
async def process(self, message: str, query: pipeline_query.Query) -> list[platform_message.MessageComponent]:
display = ForwardMessageDiaplay(
title='Group chat history',
brief='[Chat history]',
source='Chat history',
preview=['User: ' + message],
summary='View 1 forwarded message',
)
node_list = [
platform_message.ForwardMessageNode(
sender_id=query.adapter.bot_account_id,
sender_name='User',
message_chain=platform_message.MessageChain([platform_message.Plain(text=message)]),
)
]
forward = Forward(display=display, node_list=node_list)
return [forward]

View File

@@ -0,0 +1,211 @@
from __future__ import annotations
import os
import base64
import time
import re
from PIL import Image, ImageDraw, ImageFont
import functools
from .. import strategy as strategy_model
import langbot_plugin.api.entities.builtin.pipeline.query as pipeline_query
import langbot_plugin.api.entities.builtin.platform.message as platform_message
@strategy_model.strategy_class('image')
class Text2ImageStrategy(strategy_model.LongTextStrategy):
async def initialize(self):
pass
@functools.lru_cache(maxsize=16)
def get_font(self, font_path: str):
return ImageFont.truetype(
font_path,
32,
encoding='utf-8',
)
async def process(self, message: str, query: pipeline_query.Query) -> list[platform_message.MessageComponent]:
img_path = self.text_to_image(
text_str=message,
save_as='temp/{}.png'.format(int(time.time())),
query=query,
)
compressed_path, size = self.compress_image(img_path, outfile='temp/{}_compressed.png'.format(int(time.time())))
with open(compressed_path, 'rb') as f:
img = f.read()
b64 = base64.b64encode(img)
# 删除图片
os.remove(img_path)
if os.path.exists(compressed_path):
os.remove(compressed_path)
return [
platform_message.Image(
base64=b64.decode('utf-8'),
)
]
def indexNumber(self, path=''):
"""
查找字符串中数字所在串中的位置
:param path:目标字符串
:return:<class 'list'>: <class 'list'>: [['1', 16], ['2', 35], ['1', 51]]
"""
kv = []
nums = []
beforeDatas = re.findall('[\\d]+', path)
for num in beforeDatas:
indexV = []
times = path.count(num)
if times > 1:
if num not in nums:
indexs = re.finditer(num, path)
for index in indexs:
iV = []
i = index.span()[0]
iV.append(num)
iV.append(i)
kv.append(iV)
nums.append(num)
else:
index = path.find(num)
indexV.append(num)
indexV.append(index)
kv.append(indexV)
# 根据数字位置排序
indexSort = []
resultIndex = []
for vi in kv:
indexSort.append(vi[1])
indexSort.sort()
for i in indexSort:
for v in kv:
if i == v[1]:
resultIndex.append(v)
return resultIndex
def get_size(self, file):
# 获取文件大小:KB
size = os.path.getsize(file)
return size / 1024
def get_outfile(self, infile, outfile):
if outfile:
return outfile
dir, suffix = os.path.splitext(infile)
outfile = '{}-out{}'.format(dir, suffix)
return outfile
def compress_image(self, infile, outfile='', kb=100, step=20, quality=90):
"""不改变图片尺寸压缩到指定大小
:param infile: 压缩源文件
:param outfile: 压缩文件保存地址
:param mb: 压缩目标,KB
:param step: 每次调整的压缩比率
:param quality: 初始压缩比率
:return: 压缩文件地址,压缩文件大小
"""
o_size = self.get_size(infile)
if o_size <= kb:
return infile, o_size
outfile = self.get_outfile(infile, outfile)
while o_size > kb:
im = Image.open(infile)
im.save(outfile, quality=quality)
if quality - step < 0:
break
quality -= step
o_size = self.get_size(outfile)
return outfile, self.get_size(outfile)
def text_to_image(
self,
text_str: str,
save_as='temp.png',
width=800,
query: pipeline_query.Query = None,
):
text_str = text_str.replace('\t', ' ')
# 分行
lines = text_str.split('\n')
# 计算并分割
final_lines = []
text_width = width - 80
self.ap.logger.debug('lines: {}, text_width: {}'.format(lines, text_width))
for line in lines:
# 如果长了就分割
line_width = self.get_font(query.pipeline_config['output']['long-text-processing']['font-path']).getlength(
line
)
self.ap.logger.debug('line_width: {}'.format(line_width))
if line_width < text_width:
final_lines.append(line)
continue
else:
rest_text = line
while True:
# 分割最前面的一行
point = int(len(rest_text) * (text_width / line_width))
# 检查断点是否在数字中间
numbers = self.indexNumber(rest_text)
for number in numbers:
if number[1] < point < number[1] + len(number[0]) and number[1] != 0:
point = number[1]
break
final_lines.append(rest_text[:point])
rest_text = rest_text[point:]
line_width = self.get_font(
query.pipeline_config['output']['long-text-processing']['font-path']
).getlength(rest_text)
if line_width < text_width:
final_lines.append(rest_text)
break
else:
continue
# 准备画布
img = Image.new('RGBA', (width, max(280, len(final_lines) * 35 + 65)), (255, 255, 255, 255))
draw = ImageDraw.Draw(img, mode='RGBA')
self.ap.logger.debug('正在绘制图片...')
# 绘制正文
line_number = 0
offset_x = 20
offset_y = 30
for final_line in final_lines:
draw.text(
(offset_x, offset_y + 35 * line_number),
final_line,
fill=(0, 0, 0),
font=self.get_font(query.pipeline_config['output']['long-text-processing']['font-path']),
)
# 遍历此行,检查是否有emoji
idx_in_line = 0
for ch in final_line:
# 检查字符占位宽
char_code = ord(ch)
if char_code >= 127:
idx_in_line += 1
else:
idx_in_line += 0.5
line_number += 1
self.ap.logger.debug('正在保存图片...')
img.save(save_as)
return save_as

View File

@@ -0,0 +1,65 @@
from __future__ import annotations
import abc
import typing
from ...core import app
import langbot_plugin.api.entities.builtin.platform.message as platform_message
import langbot_plugin.api.entities.builtin.pipeline.query as pipeline_query
preregistered_strategies: list[typing.Type[LongTextStrategy]] = []
def strategy_class(
name: str,
) -> typing.Callable[[typing.Type[LongTextStrategy]], typing.Type[LongTextStrategy]]:
"""Long text processing strategy class decorator
Args:
name (str): Strategy name
Returns:
typing.Callable[[typing.Type[LongTextStrategy]], typing.Type[LongTextStrategy]]: Decorator
"""
def decorator(cls: typing.Type[LongTextStrategy]) -> typing.Type[LongTextStrategy]:
assert issubclass(cls, LongTextStrategy)
cls.name = name
preregistered_strategies.append(cls)
return cls
return decorator
class LongTextStrategy(metaclass=abc.ABCMeta):
"""Long text processing strategy abstract class"""
name: str
ap: app.Application
def __init__(self, ap: app.Application):
self.ap = ap
async def initialize(self):
pass
@abc.abstractmethod
async def process(self, message: str, query: pipeline_query.Query) -> list[platform_message.MessageComponent]:
"""处理长文本
If the text length exceeds the threshold, this method will be called.
Args:
message (str): Message
query (core_entities.Query): Query object
Returns:
list[platform_message.MessageComponent]: Converted platform message components
"""
return []

View File

@@ -0,0 +1,35 @@
from __future__ import annotations
from .. import stage, entities
from . import truncator
from ...utils import importutil
import langbot_plugin.api.entities.builtin.pipeline.query as pipeline_query
from . import truncators
importutil.import_modules_in_pkg(truncators)
@stage.stage_class('ConversationMessageTruncator')
class ConversationMessageTruncator(stage.PipelineStage):
"""Conversation message truncator
Used to truncate the conversation message chain to adapt to the LLM message length limit.
"""
trun: truncator.Truncator
async def initialize(self, pipeline_config: dict):
use_method = 'round'
for trun in truncator.preregistered_truncators:
if trun.name == use_method:
self.trun = trun(self.ap)
break
else:
raise ValueError(f'Unknown truncator: {use_method}')
async def process(self, query: pipeline_query.Query, stage_inst_name: str) -> entities.StageProcessResult:
"""处理"""
query = await self.trun.truncate(query)
return entities.StageProcessResult(result_type=entities.ResultType.CONTINUE, new_query=query)

View File

@@ -0,0 +1,56 @@
from __future__ import annotations
import typing
import abc
from ...core import app
import langbot_plugin.api.entities.builtin.pipeline.query as pipeline_query
preregistered_truncators: list[typing.Type[Truncator]] = []
def truncator_class(
name: str,
) -> typing.Callable[[typing.Type[Truncator]], typing.Type[Truncator]]:
"""截断器类装饰器
Args:
name (str): 截断器名称
Returns:
typing.Callable[[typing.Type[Truncator]], typing.Type[Truncator]]: 装饰器
"""
def decorator(cls: typing.Type[Truncator]) -> typing.Type[Truncator]:
assert issubclass(cls, Truncator)
cls.name = name
preregistered_truncators.append(cls)
return cls
return decorator
class Truncator(abc.ABC):
"""消息截断器基类"""
name: str
ap: app.Application
def __init__(self, ap: app.Application):
self.ap = ap
async def initialize(self):
pass
@abc.abstractmethod
async def truncate(self, query: pipeline_query.Query) -> pipeline_query.Query:
"""截断
一般只需要操作query.messages也可以扩展操作query.prompt, query.user_message。
请勿操作其他字段。
"""
pass

View File

@@ -0,0 +1,30 @@
from __future__ import annotations
from .. import truncator
import langbot_plugin.api.entities.builtin.pipeline.query as pipeline_query
@truncator.truncator_class('round')
class RoundTruncator(truncator.Truncator):
"""Truncate the conversation message chain to adapt to the LLM message length limit."""
async def truncate(self, query: pipeline_query.Query) -> pipeline_query.Query:
"""截断"""
max_round = query.pipeline_config['ai']['local-agent']['max-round']
temp_messages = []
current_round = 0
# Traverse from back to front
for msg in query.messages[::-1]:
if current_round < max_round:
temp_messages.append(msg)
if msg.role == 'user':
current_round += 1
else:
break
query.messages = temp_messages[::-1]
return query

View File

@@ -0,0 +1,306 @@
from __future__ import annotations
import typing
import traceback
import sqlalchemy
from ..core import app
from . import entities as pipeline_entities
from ..entity.persistence import pipeline as persistence_pipeline
from . import stage
import langbot_plugin.api.entities.builtin.platform.message as platform_message
import langbot_plugin.api.entities.builtin.platform.events as platform_events
import langbot_plugin.api.entities.events as events
from ..utils import importutil
import langbot_plugin.api.entities.builtin.provider.session as provider_session
import langbot_plugin.api.entities.builtin.pipeline.query as pipeline_query
from . import (
resprule,
bansess,
cntfilter,
process,
longtext,
respback,
wrapper,
preproc,
ratelimit,
msgtrun,
)
importutil.import_modules_in_pkgs(
[
resprule,
bansess,
cntfilter,
process,
longtext,
respback,
wrapper,
preproc,
ratelimit,
msgtrun,
]
)
class StageInstContainer:
"""阶段实例容器"""
inst_name: str
inst: stage.PipelineStage
def __init__(self, inst_name: str, inst: stage.PipelineStage):
self.inst_name = inst_name
self.inst = inst
class RuntimePipeline:
"""运行时流水线"""
ap: app.Application
pipeline_entity: persistence_pipeline.LegacyPipeline
"""流水线实体"""
stage_containers: list[StageInstContainer]
"""阶段实例容器"""
bound_plugins: list[str]
"""绑定到此流水线的插件列表格式author/plugin_name"""
bound_mcp_servers: list[str]
"""绑定到此流水线的MCP服务器列表格式uuid"""
def __init__(
self,
ap: app.Application,
pipeline_entity: persistence_pipeline.LegacyPipeline,
stage_containers: list[StageInstContainer],
):
self.ap = ap
self.pipeline_entity = pipeline_entity
self.stage_containers = stage_containers
# Extract bound plugins and MCP servers from extensions_preferences
extensions_prefs = pipeline_entity.extensions_preferences or {}
plugin_list = extensions_prefs.get('plugins', [])
self.bound_plugins = [f'{p["author"]}/{p["name"]}' for p in plugin_list] if plugin_list else []
mcp_server_list = extensions_prefs.get('mcp_servers', [])
self.bound_mcp_servers = mcp_server_list if mcp_server_list else []
async def run(self, query: pipeline_query.Query):
query.pipeline_config = self.pipeline_entity.config
# Store bound plugins and MCP servers in query for filtering
query.variables['_pipeline_bound_plugins'] = self.bound_plugins
query.variables['_pipeline_bound_mcp_servers'] = self.bound_mcp_servers
await self.process_query(query)
async def _check_output(self, query: pipeline_query.Query, result: pipeline_entities.StageProcessResult):
"""检查输出"""
if result.user_notice:
# 处理str类型
if isinstance(result.user_notice, str):
result.user_notice = platform_message.MessageChain([platform_message.Plain(text=result.user_notice)])
elif isinstance(result.user_notice, list):
result.user_notice = platform_message.MessageChain(*result.user_notice)
if query.pipeline_config['output']['misc']['at-sender'] and isinstance(
query.message_event, platform_events.GroupMessage
):
result.user_notice.insert(0, platform_message.At(target=query.message_event.sender.id))
if await query.adapter.is_stream_output_supported():
await query.adapter.reply_message_chunk(
message_source=query.message_event,
bot_message=query.resp_messages[-1],
message=result.user_notice,
quote_origin=query.pipeline_config['output']['misc']['quote-origin'],
is_final=[msg.is_final for msg in query.resp_messages][0],
)
else:
await query.adapter.reply_message(
message_source=query.message_event,
message=result.user_notice,
quote_origin=query.pipeline_config['output']['misc']['quote-origin'],
)
if result.debug_notice:
self.ap.logger.debug(result.debug_notice)
if result.console_notice:
self.ap.logger.info(result.console_notice)
if result.error_notice:
self.ap.logger.error(result.error_notice)
async def _execute_from_stage(
self,
stage_index: int,
query: pipeline_query.Query,
):
"""从指定阶段开始执行,实现了责任链模式和基于生成器的阶段分叉功能。
如何看懂这里为什么这么写?
去问 GPT-4:
Q1: 现在有一个责任链其中有多个stagequery对象在其中传递stage.process可能返回Result也有可能返回typing.AsyncGenerator[Result, None]
如果返回的是生成器需要挨个生成result检查是否result中是否要求继续如果要求继续就进行下一个stage。如果此次生成器产生的result处理完了就继续生成下一个result
调用后续的stage直到该生成器全部生成完。责任链中可能有多个stage会返回生成器
Q2: 不是这样的你可能理解有误。如果我们责任链上有这些Stage
A B C D E F G
如果所有的stage都返回Result且所有Result都要求继续那么执行顺序是
A B C D E F G
现在假设C返回的是AsyncGenerator那么执行顺序是
A B C D E F G C D E F G C D E F G ...
Q3: 但是如果不止一个stage会返回生成器呢
"""
i = stage_index
while i < len(self.stage_containers):
stage_container = self.stage_containers[i]
query.current_stage_name = stage_container.inst_name # 标记到 Query 对象里
result = stage_container.inst.process(query, stage_container.inst_name)
if isinstance(result, typing.Coroutine):
result = await result
if isinstance(result, pipeline_entities.StageProcessResult): # 直接返回结果
self.ap.logger.debug(
f'Stage {stage_container.inst_name} processed query {query.query_id} res {result.result_type}'
)
await self._check_output(query, result)
if result.result_type == pipeline_entities.ResultType.INTERRUPT:
self.ap.logger.debug(f'Stage {stage_container.inst_name} interrupted query {query.query_id}')
break
elif result.result_type == pipeline_entities.ResultType.CONTINUE:
query = result.new_query
elif isinstance(result, typing.AsyncGenerator): # 生成器
self.ap.logger.debug(f'Stage {stage_container.inst_name} processed query {query.query_id} gen')
async for sub_result in result:
self.ap.logger.debug(
f'Stage {stage_container.inst_name} processed query {query.query_id} res {sub_result.result_type}'
)
await self._check_output(query, sub_result)
if sub_result.result_type == pipeline_entities.ResultType.INTERRUPT:
self.ap.logger.debug(f'Stage {stage_container.inst_name} interrupted query {query.query_id}')
break
elif sub_result.result_type == pipeline_entities.ResultType.CONTINUE:
query = sub_result.new_query
await self._execute_from_stage(i + 1, query)
break
i += 1
async def process_query(self, query: pipeline_query.Query):
"""处理请求"""
try:
# Get bound plugins for this pipeline
bound_plugins = query.variables.get('_pipeline_bound_plugins', None)
# ======== 触发 MessageReceived 事件 ========
event_type = (
events.PersonMessageReceived
if query.launcher_type == provider_session.LauncherTypes.PERSON
else events.GroupMessageReceived
)
event_obj = event_type(
query=query,
launcher_type=query.launcher_type.value,
launcher_id=query.launcher_id,
sender_id=query.sender_id,
message_chain=query.message_chain,
)
event_ctx = await self.ap.plugin_connector.emit_event(event_obj, bound_plugins)
if event_ctx.is_prevented_default():
return
self.ap.logger.debug(f'Processing query {query.query_id}')
await self._execute_from_stage(0, query)
except Exception as e:
inst_name = query.current_stage_name if query.current_stage_name else 'unknown'
self.ap.logger.error(f'Error processing query {query.query_id} stage={inst_name} : {e}')
self.ap.logger.error(f'Traceback: {traceback.format_exc()}')
finally:
self.ap.logger.debug(f'Query {query.query_id} processed')
del self.ap.query_pool.cached_queries[query.query_id]
class PipelineManager:
"""流水线管理器"""
# ====== 4.0 ======
ap: app.Application
pipelines: list[RuntimePipeline]
stage_dict: dict[str, type[stage.PipelineStage]]
def __init__(self, ap: app.Application):
self.ap = ap
self.pipelines = []
async def initialize(self):
self.stage_dict = {name: cls for name, cls in stage.preregistered_stages.items()}
await self.load_pipelines_from_db()
async def load_pipelines_from_db(self):
self.ap.logger.info('Loading pipelines from db...')
result = await self.ap.persistence_mgr.execute_async(sqlalchemy.select(persistence_pipeline.LegacyPipeline))
pipelines = result.all()
# load pipelines
for pipeline in pipelines:
await self.load_pipeline(pipeline)
async def load_pipeline(
self,
pipeline_entity: persistence_pipeline.LegacyPipeline
| sqlalchemy.Row[persistence_pipeline.LegacyPipeline]
| dict,
):
if isinstance(pipeline_entity, sqlalchemy.Row):
pipeline_entity = persistence_pipeline.LegacyPipeline(**pipeline_entity._mapping)
elif isinstance(pipeline_entity, dict):
pipeline_entity = persistence_pipeline.LegacyPipeline(**pipeline_entity)
# initialize stage containers according to pipeline_entity.stages
stage_containers: list[StageInstContainer] = []
for stage_name in pipeline_entity.stages:
stage_containers.append(StageInstContainer(inst_name=stage_name, inst=self.stage_dict[stage_name](self.ap)))
for stage_container in stage_containers:
await stage_container.inst.initialize(pipeline_entity.config)
runtime_pipeline = RuntimePipeline(self.ap, pipeline_entity, stage_containers)
self.pipelines.append(runtime_pipeline)
async def get_pipeline_by_uuid(self, uuid: str) -> RuntimePipeline | None:
for pipeline in self.pipelines:
if pipeline.pipeline_entity.uuid == uuid:
return pipeline
return None
async def remove_pipeline(self, uuid: str):
for pipeline in self.pipelines:
if pipeline.pipeline_entity.uuid == uuid:
self.pipelines.remove(pipeline)
return

View File

@@ -0,0 +1,71 @@
from __future__ import annotations
import asyncio
import typing
import langbot_plugin.api.entities.builtin.platform.message as platform_message
import langbot_plugin.api.entities.builtin.platform.events as platform_events
import langbot_plugin.api.entities.builtin.provider.session as provider_session
import langbot_plugin.api.entities.builtin.pipeline.query as pipeline_query
import langbot_plugin.api.definition.abstract.platform.adapter as abstract_platform_adapter
class QueryPool:
"""请求池请求获得调度进入pipeline之前保存在这里"""
query_id_counter: int = 0
pool_lock: asyncio.Lock
queries: list[pipeline_query.Query]
cached_queries: dict[int, pipeline_query.Query]
"""Cached queries, used for plugin backward api call, will be removed after the query completely processed"""
condition: asyncio.Condition
def __init__(self):
self.query_id_counter = 0
self.pool_lock = asyncio.Lock()
self.queries = []
self.cached_queries = {}
self.condition = asyncio.Condition(self.pool_lock)
async def add_query(
self,
bot_uuid: str,
launcher_type: provider_session.LauncherTypes,
launcher_id: typing.Union[int, str],
sender_id: typing.Union[int, str],
message_event: platform_events.MessageEvent,
message_chain: platform_message.MessageChain,
adapter: abstract_platform_adapter.AbstractMessagePlatformAdapter,
pipeline_uuid: typing.Optional[str] = None,
) -> pipeline_query.Query:
async with self.condition:
query_id = self.query_id_counter
query = pipeline_query.Query(
bot_uuid=bot_uuid,
query_id=query_id,
launcher_type=launcher_type,
launcher_id=launcher_id,
sender_id=sender_id,
message_event=message_event,
message_chain=message_chain,
variables={},
resp_messages=[],
resp_message_chain=[],
adapter=adapter,
pipeline_uuid=pipeline_uuid,
)
self.queries.append(query)
self.cached_queries[query_id] = query
self.query_id_counter += 1
self.condition.notify_all()
async def __aenter__(self):
await self.pool_lock.acquire()
return self
async def __aexit__(self, exc_type, exc_val, exc_tb):
self.pool_lock.release()

View File

@@ -0,0 +1,147 @@
from __future__ import annotations
import datetime
from .. import stage, entities
from langbot_plugin.api.entities.builtin.provider import message as provider_message
import langbot_plugin.api.entities.events as events
import langbot_plugin.api.entities.builtin.platform.message as platform_message
import langbot_plugin.api.entities.builtin.pipeline.query as pipeline_query
@stage.stage_class('PreProcessor')
class PreProcessor(stage.PipelineStage):
"""Request pre-processing stage
Check out session, prompt, context, model, and content functions.
Rewrite:
- session
- prompt
- messages
- user_message
- use_model
- use_funcs
"""
async def process(
self,
query: pipeline_query.Query,
stage_inst_name: str,
) -> entities.StageProcessResult:
"""Process"""
selected_runner = query.pipeline_config['ai']['runner']['runner']
session = await self.ap.sess_mgr.get_session(query)
# When not local-agent, llm_model is None
try:
llm_model = (
await self.ap.model_mgr.get_model_by_uuid(query.pipeline_config['ai']['local-agent']['model'])
if selected_runner == 'local-agent'
else None
)
except ValueError:
self.ap.logger.warning(
f'LLM model {query.pipeline_config["ai"]["local-agent"]["model"] + " "}not found or not configured'
)
llm_model = None
conversation = await self.ap.sess_mgr.get_conversation(
query,
session,
query.pipeline_config['ai']['local-agent']['prompt'],
query.pipeline_uuid,
query.bot_uuid,
)
# 设置query
query.session = session
query.prompt = conversation.prompt.copy()
query.messages = conversation.messages.copy()
if selected_runner == 'local-agent' and llm_model:
query.use_funcs = []
query.use_llm_model_uuid = llm_model.model_entity.uuid
if llm_model.model_entity.abilities.__contains__('func_call'):
# Get bound plugins and MCP servers for filtering tools
bound_plugins = query.variables.get('_pipeline_bound_plugins', None)
bound_mcp_servers = query.variables.get('_pipeline_bound_mcp_servers', None)
query.use_funcs = await self.ap.tool_mgr.get_all_tools(bound_plugins, bound_mcp_servers)
self.ap.logger.debug(f'Bound plugins: {bound_plugins}')
self.ap.logger.debug(f'Bound MCP servers: {bound_mcp_servers}')
self.ap.logger.debug(f'Use funcs: {query.use_funcs}')
variables = {
'session_id': f'{query.session.launcher_type.value}_{query.session.launcher_id}',
'conversation_id': conversation.uuid,
'msg_create_time': (
int(query.message_event.time) if query.message_event.time else int(datetime.datetime.now().timestamp())
),
}
query.variables.update(variables)
# Check if this model supports vision, if not, remove all images
# TODO this checking should be performed in runner, and in this stage, the image should be reserved
if (
selected_runner == 'local-agent'
and llm_model
and not llm_model.model_entity.abilities.__contains__('vision')
):
for msg in query.messages:
if isinstance(msg.content, list):
for me in msg.content:
if me.type == 'image_url':
msg.content.remove(me)
content_list: list[provider_message.ContentElement] = []
plain_text = ''
qoute_msg = query.pipeline_config['trigger'].get('misc', '').get('combine-quote-message')
for me in query.message_chain:
if isinstance(me, platform_message.Plain):
content_list.append(provider_message.ContentElement.from_text(me.text))
plain_text += me.text
elif isinstance(me, platform_message.Image):
if selected_runner != 'local-agent' or (
llm_model and llm_model.model_entity.abilities.__contains__('vision')
):
if me.base64 is not None:
content_list.append(provider_message.ContentElement.from_image_base64(me.base64))
elif isinstance(me, platform_message.File):
# if me.url is not None:
content_list.append(provider_message.ContentElement.from_file_url(me.url, me.name))
elif isinstance(me, platform_message.Quote) and qoute_msg:
for msg in me.origin:
if isinstance(msg, platform_message.Plain):
content_list.append(provider_message.ContentElement.from_text(msg.text))
elif isinstance(msg, platform_message.Image):
if selected_runner != 'local-agent' or (
llm_model and llm_model.model_entity.abilities.__contains__('vision')
):
if msg.base64 is not None:
content_list.append(provider_message.ContentElement.from_image_base64(msg.base64))
query.variables['user_message_text'] = plain_text
query.user_message = provider_message.Message(role='user', content=content_list)
# =========== 触发事件 PromptPreProcessing
event = events.PromptPreProcessing(
session_name=f'{query.session.launcher_type.value}_{query.session.launcher_id}',
default_prompt=query.prompt.messages,
prompt=query.messages,
query=query,
)
# Get bound plugins for filtering
bound_plugins = query.variables.get('_pipeline_bound_plugins', None)
event_ctx = await self.ap.plugin_connector.emit_event(event, bound_plugins)
query.prompt.messages = event_ctx.event.default_prompt
query.messages = event_ctx.event.prompt
return entities.StageProcessResult(result_type=entities.ResultType.CONTINUE, new_query=query)

View File

@@ -0,0 +1,33 @@
from __future__ import annotations
import abc
from ...core import app
from .. import entities
import langbot_plugin.api.entities.builtin.pipeline.query as pipeline_query
class MessageHandler(metaclass=abc.ABCMeta):
ap: app.Application
def __init__(self, ap: app.Application):
self.ap = ap
async def initialize(self):
pass
@abc.abstractmethod
async def handle(
self,
query: pipeline_query.Query,
) -> entities.StageProcessResult:
raise NotImplementedError
def cut_str(self, s: str) -> str:
"""
Take the first line of the string, up to 20 characters, if there are multiple lines, or more than 20 characters, add an ellipsis
"""
s0 = s.split('\n')[0]
if len(s0) > 20 or '\n' in s:
s0 = s0[:20] + '...'
return s0

View File

@@ -0,0 +1,128 @@
from __future__ import annotations
import uuid
import typing
import traceback
from .. import handler
from ... import entities
from ....provider import runner as runner_module
import langbot_plugin.api.entities.events as events
from ....utils import importutil
from ....provider import runners
import langbot_plugin.api.entities.builtin.provider.session as provider_session
import langbot_plugin.api.entities.builtin.pipeline.query as pipeline_query
importutil.import_modules_in_pkg(runners)
class ChatMessageHandler(handler.MessageHandler):
async def handle(
self,
query: pipeline_query.Query,
) -> typing.AsyncGenerator[entities.StageProcessResult, None]:
"""处理"""
# 调API
# 生成器
# 触发插件事件
event_class = (
events.PersonNormalMessageReceived
if query.launcher_type == provider_session.LauncherTypes.PERSON
else events.GroupNormalMessageReceived
)
event = event_class(
launcher_type=query.launcher_type.value,
launcher_id=query.launcher_id,
sender_id=query.sender_id,
text_message=str(query.message_chain),
query=query,
)
# Get bound plugins for filtering
bound_plugins = query.variables.get('_pipeline_bound_plugins', None)
event_ctx = await self.ap.plugin_connector.emit_event(event, bound_plugins)
is_create_card = False # 判断下是否需要创建流式卡片
if event_ctx.is_prevented_default():
if event_ctx.event.reply_message_chain is not None:
mc = event_ctx.event.reply_message_chain
query.resp_messages.append(mc)
yield entities.StageProcessResult(result_type=entities.ResultType.CONTINUE, new_query=query)
else:
yield entities.StageProcessResult(result_type=entities.ResultType.INTERRUPT, new_query=query)
else:
if event_ctx.event.user_message_alter is not None:
# if isinstance(event_ctx.event, str): # 现在暂时不考虑多模态alter
query.user_message.content = event_ctx.event.user_message_alter
text_length = 0
try:
is_stream = await query.adapter.is_stream_output_supported()
except AttributeError:
is_stream = False
try:
for r in runner_module.preregistered_runners:
if r.name == query.pipeline_config['ai']['runner']['runner']:
runner = r(self.ap, query.pipeline_config)
break
else:
raise ValueError(f'未找到请求运行器: {query.pipeline_config["ai"]["runner"]["runner"]}')
if is_stream:
resp_message_id = uuid.uuid4()
async for result in runner.run(query):
result.resp_message_id = str(resp_message_id)
if query.resp_messages:
query.resp_messages.pop()
if query.resp_message_chain:
query.resp_message_chain.pop()
# 此时连接外部 AI 服务正常,创建卡片
if not is_create_card: # 只有不是第一次才创建卡片
await query.adapter.create_message_card(str(resp_message_id), query.message_event)
is_create_card = True
query.resp_messages.append(result)
self.ap.logger.info(f'对话({query.query_id})流式响应: {self.cut_str(result.readable_str())}')
if result.content is not None:
text_length += len(result.content)
yield entities.StageProcessResult(result_type=entities.ResultType.CONTINUE, new_query=query)
else:
async for result in runner.run(query):
query.resp_messages.append(result)
self.ap.logger.info(f'对话({query.query_id})响应: {self.cut_str(result.readable_str())}')
if result.content is not None:
text_length += len(result.content)
yield entities.StageProcessResult(result_type=entities.ResultType.CONTINUE, new_query=query)
query.session.using_conversation.messages.append(query.user_message)
query.session.using_conversation.messages.extend(query.resp_messages)
except Exception as e:
self.ap.logger.error(f'对话({query.query_id})请求失败: {type(e).__name__} {str(e)}')
traceback.print_exc()
hide_exception_info = query.pipeline_config['output']['misc']['hide-exception']
yield entities.StageProcessResult(
result_type=entities.ResultType.INTERRUPT,
new_query=query,
user_notice='请求失败' if hide_exception_info else f'{e}',
error_notice=f'{e}',
debug_notice=traceback.format_exc(),
)
finally:
# TODO statistics
pass

View File

@@ -0,0 +1,110 @@
from __future__ import annotations
import typing
from .. import handler
from ... import entities
import langbot_plugin.api.entities.builtin.provider.message as provider_message
import langbot_plugin.api.entities.builtin.provider.session as provider_session
import langbot_plugin.api.entities.builtin.pipeline.query as pipeline_query
import langbot_plugin.api.entities.events as events
class CommandHandler(handler.MessageHandler):
async def handle(
self,
query: pipeline_query.Query,
) -> typing.AsyncGenerator[entities.StageProcessResult, None]:
"""Process"""
full_command_text = str(query.message_chain).strip()
command_text = full_command_text[1:]
privilege = 1
if f'{query.launcher_type.value}_{query.launcher_id}' in self.ap.instance_config.data['admins']:
privilege = 2
spt = command_text.split(' ')
event_class = (
events.PersonCommandSent
if query.launcher_type == provider_session.LauncherTypes.PERSON
else events.GroupCommandSent
)
event = event_class(
launcher_type=query.launcher_type.value,
launcher_id=query.launcher_id,
sender_id=query.sender_id,
command=spt[0],
params=spt[1:] if len(spt) > 1 else [],
text_message=full_command_text,
is_admin=(privilege == 2),
query=query,
)
# Get bound plugins for filtering
bound_plugins = query.variables.get('_pipeline_bound_plugins', None)
event_ctx = await self.ap.plugin_connector.emit_event(event, bound_plugins)
if event_ctx.is_prevented_default():
if event_ctx.event.reply_message_chain is not None:
mc = event_ctx.event.reply_message_chain
query.resp_messages.append(mc)
yield entities.StageProcessResult(result_type=entities.ResultType.CONTINUE, new_query=query)
else:
yield entities.StageProcessResult(result_type=entities.ResultType.INTERRUPT, new_query=query)
else:
session = await self.ap.sess_mgr.get_session(query)
async for ret in self.ap.cmd_mgr.execute(
command_text=command_text, full_command_text=full_command_text, query=query, session=session
):
if ret.error is not None:
query.resp_messages.append(
provider_message.Message(
role='command',
content=str(ret.error),
)
)
self.ap.logger.info(f'Command({query.query_id}) error: {self.cut_str(str(ret.error))}')
yield entities.StageProcessResult(result_type=entities.ResultType.CONTINUE, new_query=query)
elif (
ret.text is not None
or ret.image_url is not None
or ret.image_base64 is not None
or ret.file_url is not None
):
content: list[provider_message.ContentElement] = []
if ret.text is not None:
content.append(provider_message.ContentElement.from_text(ret.text))
if ret.image_url is not None:
content.append(provider_message.ContentElement.from_image_url(ret.image_url))
if ret.image_base64 is not None:
content.append(provider_message.ContentElement.from_image_base64(ret.image_base64))
if ret.file_url is not None:
# 此时为 file 类型
content.append(provider_message.ContentElement.from_file_url(ret.file_url, ret.file_name))
query.resp_messages.append(
provider_message.Message(
role='command',
content=content,
)
)
self.ap.logger.info(f'Command returned: {self.cut_str(str(content[0]))}')
yield entities.StageProcessResult(result_type=entities.ResultType.CONTINUE, new_query=query)
else:
yield entities.StageProcessResult(result_type=entities.ResultType.INTERRUPT, new_query=query)

View File

@@ -0,0 +1,55 @@
from __future__ import annotations
from . import handler
from .handlers import chat, command
from .. import entities
from .. import stage
import langbot_plugin.api.entities.builtin.pipeline.query as pipeline_query
@stage.stage_class('MessageProcessor')
class Processor(stage.PipelineStage):
"""请求实际处理阶段
通过命令处理器和聊天处理器处理消息。
改写:
- resp_messages
"""
cmd_handler: handler.MessageHandler
chat_handler: handler.MessageHandler
async def initialize(self, pipeline_config: dict):
self.cmd_handler = command.CommandHandler(self.ap)
self.chat_handler = chat.ChatMessageHandler(self.ap)
await self.cmd_handler.initialize()
await self.chat_handler.initialize()
async def process(
self,
query: pipeline_query.Query,
stage_inst_name: str,
) -> entities.StageProcessResult:
"""Process"""
message_text = str(query.message_chain).strip()
self.ap.logger.info(
f'Processing request from {query.launcher_type.value}_{query.launcher_id} ({query.query_id}): {message_text}'
)
async def generator():
cmd_prefix = self.ap.instance_config.data['command']['prefix']
cmd_enable = self.ap.instance_config.data['command'].get('enable', True)
if cmd_enable and any(message_text.startswith(prefix) for prefix in cmd_prefix):
handler_to_use = self.cmd_handler
else:
handler_to_use = self.chat_handler
async for result in handler_to_use.handle(query):
yield result
return generator()

View File

@@ -0,0 +1,68 @@
from __future__ import annotations
import abc
import typing
from ...core import app
import langbot_plugin.api.entities.builtin.pipeline.query as pipeline_query
preregistered_algos: list[typing.Type[ReteLimitAlgo]] = []
def algo_class(name: str):
def decorator(cls: typing.Type[ReteLimitAlgo]) -> typing.Type[ReteLimitAlgo]:
cls.name = name
preregistered_algos.append(cls)
return cls
return decorator
class ReteLimitAlgo(metaclass=abc.ABCMeta):
"""限流算法抽象类"""
name: str = None
ap: app.Application
def __init__(self, ap: app.Application):
self.ap = ap
async def initialize(self):
pass
@abc.abstractmethod
async def require_access(
self,
query: pipeline_query.Query,
launcher_type: str,
launcher_id: typing.Union[int, str],
) -> bool:
"""进入处理流程
这个方法对等待是友好的,意味着算法可以实现在这里等待一段时间以控制速率。
Args:
launcher_type (str): 请求者类型 群聊为 group 私聊为 person
launcher_id (int): 请求者ID
Returns:
bool: 是否允许进入处理流程若返回false则直接丢弃该请求
"""
raise NotImplementedError
@abc.abstractmethod
async def release_access(
self,
query: pipeline_query.Query,
launcher_type: str,
launcher_id: typing.Union[int, str],
):
"""退出处理流程
Args:
launcher_type (str): 请求者类型 群聊为 group 私聊为 person
launcher_id (int): 请求者ID
"""
raise NotImplementedError

View File

@@ -0,0 +1,98 @@
from __future__ import annotations
import asyncio
import time
import typing
from .. import algo
import langbot_plugin.api.entities.builtin.pipeline.query as pipeline_query
# 固定窗口算法
class SessionContainer:
wait_lock: asyncio.Lock
records: dict[int, int]
"""访问记录key为每窗口长度的起始时间戳value为访问次数"""
def __init__(self):
self.wait_lock = asyncio.Lock()
self.records = {}
@algo.algo_class('fixwin')
class FixedWindowAlgo(algo.ReteLimitAlgo):
containers_lock: asyncio.Lock
"""访问记录容器锁"""
containers: dict[str, SessionContainer]
"""访问记录容器key为launcher_type launcher_id"""
async def initialize(self):
self.containers_lock = asyncio.Lock()
self.containers = {}
async def require_access(
self,
query: pipeline_query.Query,
launcher_type: str,
launcher_id: typing.Union[int, str],
) -> bool:
# 加锁,找容器
container: SessionContainer = None
session_name = f'{launcher_type}_{launcher_id}'
async with self.containers_lock:
container = self.containers.get(session_name)
if container is None:
container = SessionContainer()
self.containers[session_name] = container
# 等待锁
async with container.wait_lock:
# 获取窗口大小和限制
window_size = query.pipeline_config['safety']['rate-limit']['window-length']
limitation = query.pipeline_config['safety']['rate-limit']['limitation']
# TODO revert it
# if session_name in self.ap.pipeline_cfg.data['rate-limit']['fixwin']:
# window_size = self.ap.pipeline_cfg.data['rate-limit']['fixwin'][session_name]['window-size']
# limitation = self.ap.pipeline_cfg.data['rate-limit']['fixwin'][session_name]['limit']
# 获取当前时间戳
now = int(time.time())
# 获取当前窗口的起始时间戳
now = now - now % window_size
# 获取当前窗口的访问次数
count = container.records.get(now, 0)
# 如果访问次数超过了限制
if count >= limitation:
if query.pipeline_config['safety']['rate-limit']['strategy'] == 'drop':
return False
elif query.pipeline_config['safety']['rate-limit']['strategy'] == 'wait':
# 等待下一窗口
await asyncio.sleep(window_size - time.time() % window_size)
now = int(time.time())
now = now - now % window_size
if now not in container.records:
container.records = {}
container.records[now] = 1
else:
# 访问次数加一
container.records[now] = count + 1
# 返回True
return True
async def release_access(
self,
query: pipeline_query.Query,
launcher_type: str,
launcher_id: typing.Union[int, str],
):
pass

View File

@@ -0,0 +1,76 @@
from __future__ import annotations
import typing
from .. import entities, stage
from . import algo
from ...utils import importutil
import langbot_plugin.api.entities.builtin.pipeline.query as pipeline_query
from . import algos
importutil.import_modules_in_pkg(algos)
@stage.stage_class('RequireRateLimitOccupancy')
@stage.stage_class('ReleaseRateLimitOccupancy')
class RateLimit(stage.PipelineStage):
"""限速器控制阶段
不改写query只检查是否需要限速。
"""
algo: algo.ReteLimitAlgo
async def initialize(self, pipeline_config: dict):
algo_name = 'fixwin'
algo_class = None
for algo_cls in algo.preregistered_algos:
if algo_cls.name == algo_name:
algo_class = algo_cls
break
else:
raise ValueError(f'未知的限速算法: {algo_name}')
self.algo = algo_class(self.ap)
await self.algo.initialize()
async def process(
self,
query: pipeline_query.Query,
stage_inst_name: str,
) -> typing.Union[
entities.StageProcessResult,
typing.AsyncGenerator[entities.StageProcessResult, None],
]:
"""处理"""
if stage_inst_name == 'RequireRateLimitOccupancy':
if await self.algo.require_access(
query,
query.launcher_type.value,
query.launcher_id,
):
return entities.StageProcessResult(
result_type=entities.ResultType.CONTINUE,
new_query=query,
)
else:
return entities.StageProcessResult(
result_type=entities.ResultType.INTERRUPT,
new_query=query,
console_notice=f'根据限速规则忽略 {query.launcher_type.value}:{query.launcher_id} 消息',
user_notice='请求数超过限速器设定值,已丢弃本消息。',
)
elif stage_inst_name == 'ReleaseRateLimitOccupancy':
await self.algo.release_access(
query,
query.launcher_type.value,
query.launcher_id,
)
return entities.StageProcessResult(
result_type=entities.ResultType.CONTINUE,
new_query=query,
)

View File

@@ -0,0 +1,58 @@
from __future__ import annotations
import random
import asyncio
import langbot_plugin.api.entities.builtin.platform.events as platform_events
import langbot_plugin.api.entities.builtin.platform.message as platform_message
import langbot_plugin.api.entities.builtin.provider.message as provider_message
from .. import stage, entities
import langbot_plugin.api.entities.builtin.pipeline.query as pipeline_query
@stage.stage_class('SendResponseBackStage')
class SendResponseBackStage(stage.PipelineStage):
"""发送响应消息"""
async def process(self, query: pipeline_query.Query, stage_inst_name: str) -> entities.StageProcessResult:
"""处理"""
random_range = (
query.pipeline_config['output']['force-delay']['min'],
query.pipeline_config['output']['force-delay']['max'],
)
random_delay = random.uniform(*random_range)
self.ap.logger.debug('根据规则强制延迟回复: %s s', random_delay)
await asyncio.sleep(random_delay)
if query.pipeline_config['output']['misc']['at-sender'] and isinstance(
query.message_event, platform_events.GroupMessage
):
query.resp_message_chain[-1].insert(0, platform_message.At(target=query.message_event.sender.id))
quote_origin = query.pipeline_config['output']['misc']['quote-origin']
has_chunks = any(isinstance(msg, provider_message.MessageChunk) for msg in query.resp_messages)
# TODO 命令与流式的兼容性问题
if await query.adapter.is_stream_output_supported() and has_chunks:
is_final = [msg.is_final for msg in query.resp_messages][0]
await query.adapter.reply_message_chunk(
message_source=query.message_event,
bot_message=query.resp_messages[-1],
message=query.resp_message_chain[-1],
quote_origin=quote_origin,
is_final=is_final,
)
else:
await query.adapter.reply_message(
message_source=query.message_event,
message=query.resp_message_chain[-1],
quote_origin=quote_origin,
)
return entities.StageProcessResult(result_type=entities.ResultType.CONTINUE, new_query=query)

View File

@@ -0,0 +1,9 @@
import pydantic
import langbot_plugin.api.entities.builtin.platform.message as platform_message
class RuleJudgeResult(pydantic.BaseModel):
matching: bool = False
replacement: platform_message.MessageChain = None

View File

@@ -0,0 +1,58 @@
from __future__ import annotations
from . import rule
from .. import stage, entities
from ...utils import importutil
import langbot_plugin.api.entities.builtin.pipeline.query as pipeline_query
from . import rules
importutil.import_modules_in_pkg(rules)
@stage.stage_class('GroupRespondRuleCheckStage')
class GroupRespondRuleCheckStage(stage.PipelineStage):
"""群组响应规则检查器
仅检查群消息是否符合规则。
"""
rule_matchers: list[rule.GroupRespondRule]
"""检查器实例"""
async def initialize(self, pipeline_config: dict):
"""初始化检查器"""
self.rule_matchers = []
for rule_matcher in rule.preregisetered_rules:
rule_inst = rule_matcher(self.ap)
await rule_inst.initialize()
self.rule_matchers.append(rule_inst)
async def process(self, query: pipeline_query.Query, stage_inst_name: str) -> entities.StageProcessResult:
if query.launcher_type.value != 'group': # 只处理群消息
return entities.StageProcessResult(result_type=entities.ResultType.CONTINUE, new_query=query)
rules = query.pipeline_config['trigger']['group-respond-rules']
use_rule = rules
# TODO revert it
# if str(query.launcher_id) in rules:
# use_rule = rules[str(query.launcher_id)]
for rule_matcher in self.rule_matchers: # 任意一个匹配就放行
res = await rule_matcher.match(str(query.message_chain), query.message_chain, use_rule, query)
if res.matching:
query.message_chain = res.replacement
return entities.StageProcessResult(
result_type=entities.ResultType.CONTINUE,
new_query=query,
)
return entities.StageProcessResult(result_type=entities.ResultType.INTERRUPT, new_query=query)

View File

@@ -0,0 +1,46 @@
from __future__ import annotations
import abc
import typing
from ...core import app
from . import entities
import langbot_plugin.api.entities.builtin.platform.message as platform_message
import langbot_plugin.api.entities.builtin.pipeline.query as pipeline_query
preregisetered_rules: list[typing.Type[GroupRespondRule]] = []
def rule_class(name: str):
def decorator(cls: typing.Type[GroupRespondRule]) -> typing.Type[GroupRespondRule]:
cls.name = name
preregisetered_rules.append(cls)
return cls
return decorator
class GroupRespondRule(metaclass=abc.ABCMeta):
"""群组响应规则的抽象类"""
name: str
ap: app.Application
def __init__(self, ap: app.Application):
self.ap = ap
async def initialize(self):
pass
@abc.abstractmethod
async def match(
self,
message_text: str,
message_chain: platform_message.MessageChain,
rule_dict: dict,
query: pipeline_query.Query,
) -> entities.RuleJudgeResult:
"""判断消息是否匹配规则"""
raise NotImplementedError

View File

@@ -0,0 +1,34 @@
from __future__ import annotations
from .. import rule as rule_model
from .. import entities
import langbot_plugin.api.entities.builtin.platform.message as platform_message
import langbot_plugin.api.entities.builtin.pipeline.query as pipeline_query
@rule_model.rule_class('at-bot')
class AtBotRule(rule_model.GroupRespondRule):
async def match(
self,
message_text: str,
message_chain: platform_message.MessageChain,
rule_dict: dict,
query: pipeline_query.Query,
) -> entities.RuleJudgeResult:
found = False
def remove_at(message_chain: platform_message.MessageChain):
nonlocal found
for component in message_chain.root:
if isinstance(component, platform_message.At) and str(component.target) == str(
query.adapter.bot_account_id
):
message_chain.remove(component)
found = True
break
remove_at(message_chain)
remove_at(message_chain) # 回复消息时会at两次检查并删除重复的
return entities.RuleJudgeResult(matching=found, replacement=message_chain)

View File

@@ -0,0 +1,30 @@
from .. import rule as rule_model
from .. import entities
import langbot_plugin.api.entities.builtin.platform.message as platform_message
import langbot_plugin.api.entities.builtin.pipeline.query as pipeline_query
@rule_model.rule_class('prefix')
class PrefixRule(rule_model.GroupRespondRule):
async def match(
self,
message_text: str,
message_chain: platform_message.MessageChain,
rule_dict: dict,
query: pipeline_query.Query,
) -> entities.RuleJudgeResult:
prefixes = rule_dict['prefix']
for prefix in prefixes:
if message_text.startswith(prefix):
# 查找第一个plain元素
for me in message_chain:
if isinstance(me, platform_message.Plain):
me.text = me.text[len(prefix) :]
return entities.RuleJudgeResult(
matching=True,
replacement=message_chain,
)
return entities.RuleJudgeResult(matching=False, replacement=message_chain)

View File

@@ -0,0 +1,21 @@
import random
from .. import rule as rule_model
from .. import entities
import langbot_plugin.api.entities.builtin.platform.message as platform_message
import langbot_plugin.api.entities.builtin.pipeline.query as pipeline_query
@rule_model.rule_class('random')
class RandomRespRule(rule_model.GroupRespondRule):
async def match(
self,
message_text: str,
message_chain: platform_message.MessageChain,
rule_dict: dict,
query: pipeline_query.Query,
) -> entities.RuleJudgeResult:
random_rate = rule_dict['random']
return entities.RuleJudgeResult(matching=random.random() < random_rate, replacement=message_chain)

View File

@@ -0,0 +1,30 @@
import re
from .. import rule as rule_model
from .. import entities
import langbot_plugin.api.entities.builtin.platform.message as platform_message
import langbot_plugin.api.entities.builtin.pipeline.query as pipeline_query
@rule_model.rule_class('regexp')
class RegExpRule(rule_model.GroupRespondRule):
async def match(
self,
message_text: str,
message_chain: platform_message.MessageChain,
rule_dict: dict,
query: pipeline_query.Query,
) -> entities.RuleJudgeResult:
regexps = rule_dict['regexp']
for regexp in regexps:
match = re.match(regexp, message_text)
if match:
return entities.RuleJudgeResult(
matching=True,
replacement=message_chain,
)
return entities.RuleJudgeResult(matching=False, replacement=message_chain)

View File

@@ -0,0 +1,44 @@
from __future__ import annotations
import abc
import typing
from ..core import app
from . import entities
import langbot_plugin.api.entities.builtin.pipeline.query as pipeline_query
preregistered_stages: dict[str, type[PipelineStage]] = {}
def stage_class(name: str) -> typing.Callable[[type[PipelineStage]], type[PipelineStage]]:
def decorator(cls: type[PipelineStage]) -> type[PipelineStage]:
preregistered_stages[name] = cls
return cls
return decorator
class PipelineStage(metaclass=abc.ABCMeta):
"""流水线阶段"""
ap: app.Application
def __init__(self, ap: app.Application):
self.ap = ap
async def initialize(self, pipeline_config: dict):
"""初始化"""
pass
@abc.abstractmethod
async def process(
self,
query: pipeline_query.Query,
stage_inst_name: str,
) -> typing.Union[
entities.StageProcessResult,
typing.AsyncGenerator[entities.StageProcessResult, None],
]:
"""处理"""
raise NotImplementedError

View File

@@ -0,0 +1,141 @@
from __future__ import annotations
import typing
from .. import entities
from .. import stage
import langbot_plugin.api.entities.builtin.platform.message as platform_message
import langbot_plugin.api.entities.builtin.pipeline.query as pipeline_query
import langbot_plugin.api.entities.events as events
@stage.stage_class('ResponseWrapper')
class ResponseWrapper(stage.PipelineStage):
"""回复包装阶段
把回复的 message 包装成人类识读的形式。
改写:
- resp_message_chain
"""
async def initialize(self, pipeline_config: dict):
pass
async def process(
self,
query: pipeline_query.Query,
stage_inst_name: str,
) -> typing.AsyncGenerator[entities.StageProcessResult, None]:
"""处理"""
# 如果 resp_messages[-1] 已经是 MessageChain 了
if isinstance(query.resp_messages[-1], platform_message.MessageChain):
query.resp_message_chain.append(query.resp_messages[-1])
yield entities.StageProcessResult(result_type=entities.ResultType.CONTINUE, new_query=query)
else:
if query.resp_messages[-1].role == 'command':
query.resp_message_chain.append(
query.resp_messages[-1].get_content_platform_message_chain(prefix_text='[bot] ')
)
yield entities.StageProcessResult(result_type=entities.ResultType.CONTINUE, new_query=query)
elif query.resp_messages[-1].role == 'plugin':
query.resp_message_chain.append(query.resp_messages[-1].get_content_platform_message_chain())
yield entities.StageProcessResult(result_type=entities.ResultType.CONTINUE, new_query=query)
else:
if query.resp_messages[-1].role == 'assistant':
result = query.resp_messages[-1]
session = await self.ap.sess_mgr.get_session(query)
reply_text = ''
if result.content: # 有内容
reply_text = str(result.get_content_platform_message_chain())
# ============= 触发插件事件 ===============
event = events.NormalMessageResponded(
launcher_type=query.launcher_type.value,
launcher_id=query.launcher_id,
sender_id=query.sender_id,
session=session,
prefix='',
response_text=reply_text,
finish_reason='stop',
funcs_called=[fc.function.name for fc in result.tool_calls]
if result.tool_calls is not None
else [],
query=query,
)
# Get bound plugins for filtering
bound_plugins = query.variables.get('_pipeline_bound_plugins', None)
event_ctx = await self.ap.plugin_connector.emit_event(event, bound_plugins)
if event_ctx.is_prevented_default():
yield entities.StageProcessResult(
result_type=entities.ResultType.INTERRUPT,
new_query=query,
)
else:
if event_ctx.event.reply_message_chain is not None:
query.resp_message_chain.append(event_ctx.event.reply_message_chain)
else:
query.resp_message_chain.append(result.get_content_platform_message_chain())
yield entities.StageProcessResult(
result_type=entities.ResultType.CONTINUE,
new_query=query,
)
if result.tool_calls is not None and len(result.tool_calls) > 0: # 有函数调用
function_names = [tc.function.name for tc in result.tool_calls]
reply_text = f'调用函数 {".".join(function_names)}...'
query.resp_message_chain.append(
platform_message.MessageChain([platform_message.Plain(text=reply_text)])
)
if query.pipeline_config['output']['misc']['track-function-calls']:
event = events.NormalMessageResponded(
launcher_type=query.launcher_type.value,
launcher_id=query.launcher_id,
sender_id=query.sender_id,
session=session,
prefix='',
response_text=reply_text,
finish_reason='stop',
funcs_called=[fc.function.name for fc in result.tool_calls]
if result.tool_calls is not None
else [],
query=query,
)
# Get bound plugins for filtering
bound_plugins = query.variables.get('_pipeline_bound_plugins', None)
event_ctx = await self.ap.plugin_connector.emit_event(event, bound_plugins)
if event_ctx.is_prevented_default():
yield entities.StageProcessResult(
result_type=entities.ResultType.INTERRUPT,
new_query=query,
)
else:
if event_ctx.event.reply_message_chain is not None:
query.resp_message_chain.append(event_ctx.event.reply_message_chain)
else:
query.resp_message_chain.append(
platform_message.MessageChain([platform_message.Plain(text=reply_text)])
)
yield entities.StageProcessResult(
result_type=entities.ResultType.CONTINUE,
new_query=query,
)