style: introduce ruff as linter and formatter (#1356)

* style: remove necessary imports

* style: fix F841

* style: fix F401

* style: fix F811

* style: fix E402

* style: fix E721

* style: fix E722

* style: fix E722

* style: fix F541

* style: ruff format

* style: all passed

* style: add ruff in deps

* style: more ignores in ruff.toml

* style: add pre-commit
This commit is contained in:
Junyan Qin (Chin)
2025-04-29 17:24:07 +08:00
committed by GitHub
parent 09e70d70e9
commit 209f16af76
240 changed files with 5307 additions and 4689 deletions

View File

@@ -1,15 +1,13 @@
from __future__ import annotations
import re
from .. import stage, entities
from ...core import entities as core_entities
from ...config import manager as cfg_mgr
@stage.stage_class('BanSessionCheckStage')
class BanSessionCheckStage(stage.PipelineStage):
"""访问控制处理阶段
仅检查query中群号或个人号是否在访问控制列表中。
"""
@@ -17,26 +15,24 @@ class BanSessionCheckStage(stage.PipelineStage):
pass
async def process(
self,
query: core_entities.Query,
stage_inst_name: str
self, query: core_entities.Query, stage_inst_name: str
) -> entities.StageProcessResult:
found = False
mode = query.pipeline_config['trigger']['access-control']['mode']
sess_list = query.pipeline_config['trigger']['access-control'][mode]
if (query.launcher_type.value == 'group' and 'group_*' in sess_list) \
or (query.launcher_type.value == 'person' and 'person_*' in sess_list):
if (query.launcher_type.value == 'group' and 'group_*' in sess_list) or (
query.launcher_type.value == 'person' and 'person_*' in sess_list
):
found = True
else:
for sess in sess_list:
if sess == f"{query.launcher_type.value}_{query.launcher_id}":
if sess == f'{query.launcher_type.value}_{query.launcher_id}':
found = True
break
ctn = False
if mode == 'whitelist':
@@ -45,7 +41,11 @@ class BanSessionCheckStage(stage.PipelineStage):
ctn = not found
return entities.StageProcessResult(
result_type=entities.ResultType.CONTINUE if ctn else entities.ResultType.INTERRUPT,
result_type=entities.ResultType.CONTINUE
if ctn
else entities.ResultType.INTERRUPT,
new_query=query,
console_notice=f'根据访问控制忽略消息: {query.launcher_type.value}_{query.launcher_id}' if not ctn else ''
console_notice=f'根据访问控制忽略消息: {query.launcher_type.value}_{query.launcher_id}'
if not ctn
else '',
)

View File

@@ -4,20 +4,21 @@ from ...core import app
from .. import stage, entities
from ...core import entities as core_entities
from ...config import manager as cfg_mgr
from . import filter as filter_model, entities as filter_entities
from .filters import cntignore, banwords, baiduexamine
from ...provider import entities as llm_entities
from ...platform.types import message as platform_message
from ...platform.types import events as platform_events
from ...platform.types import entities as platform_entities
from ...utils import importutil
from . import filters
importutil.import_modules_in_pkg(filters)
@stage.stage_class('PostContentFilterStage')
@stage.stage_class('PreContentFilterStage')
class ContentFilterStage(stage.PipelineStage):
"""内容过滤阶段
前置:
检查消息是否符合规则,不符合则拦截。
改写:
@@ -36,13 +37,12 @@ class ContentFilterStage(stage.PipelineStage):
super().__init__(ap)
async def initialize(self, pipeline_config: dict):
filters_required = [
"content-ignore",
'content-ignore',
]
if pipeline_config['safety']['content-filter']['check-sensitive-words']:
filters_required.append("ban-word-filter")
filters_required.append('ban-word-filter')
# TODO revert it
# if self.ap.pipeline_cfg.data['baidu-cloud-examine']['enable']:
@@ -50,9 +50,7 @@ class ContentFilterStage(stage.PipelineStage):
for filter in filter_model.preregistered_filters:
if filter.name in filters_required:
self.filter_chain.append(
filter(self.ap)
)
self.filter_chain.append(filter(self.ap))
for filter in self.filter_chain:
await filter.initialize()
@@ -68,8 +66,7 @@ class ContentFilterStage(stage.PipelineStage):
if query.pipeline_config['safety']['content-filter']['scope'] == 'output-msg':
return entities.StageProcessResult(
result_type=entities.ResultType.CONTINUE,
new_query=query
result_type=entities.ResultType.CONTINUE, new_query=query
)
else:
for filter in self.filter_chain:
@@ -78,26 +75,25 @@ class ContentFilterStage(stage.PipelineStage):
if result.level in [
filter_entities.ResultLevel.BLOCK,
filter_entities.ResultLevel.MASKED
filter_entities.ResultLevel.MASKED,
]:
return entities.StageProcessResult(
result_type=entities.ResultType.INTERRUPT,
new_query=query,
user_notice=result.user_notice,
console_notice=result.console_notice
console_notice=result.console_notice,
)
elif result.level == filter_entities.ResultLevel.PASS: # 传到下一个
message = result.replacement
query.message_chain = platform_message.MessageChain(
platform_message.Plain(message)
)
return entities.StageProcessResult(
result_type=entities.ResultType.CONTINUE,
new_query=query
result_type=entities.ResultType.CONTINUE, new_query=query
)
async def _post_process(
self,
message: str,
@@ -108,8 +104,7 @@ class ContentFilterStage(stage.PipelineStage):
"""
if query.pipeline_config['safety']['content-filter']['scope'] == 'income-msg':
return entities.StageProcessResult(
result_type=entities.ResultType.CONTINUE,
new_query=query
result_type=entities.ResultType.CONTINUE, new_query=query
)
else:
message = message.strip()
@@ -122,30 +117,25 @@ class ContentFilterStage(stage.PipelineStage):
result_type=entities.ResultType.INTERRUPT,
new_query=query,
user_notice=result.user_notice,
console_notice=result.console_notice
console_notice=result.console_notice,
)
elif result.level in [
filter_entities.ResultLevel.PASS,
filter_entities.ResultLevel.MASKED
filter_entities.ResultLevel.MASKED,
]:
message = result.replacement
query.resp_messages[-1].content = message
return entities.StageProcessResult(
result_type=entities.ResultType.CONTINUE,
new_query=query
result_type=entities.ResultType.CONTINUE, new_query=query
)
async def process(
self,
query: core_entities.Query,
stage_inst_name: str
self, query: core_entities.Query, stage_inst_name: str
) -> entities.StageProcessResult:
"""处理
"""
"""处理"""
if stage_inst_name == 'PreContentFilterStage':
contain_non_text = False
text_components = [platform_message.Plain, platform_message.Source]
@@ -156,28 +146,24 @@ class ContentFilterStage(stage.PipelineStage):
break
if contain_non_text:
self.ap.logger.debug(f"消息中包含非文本消息,跳过内容过滤器检查。")
self.ap.logger.debug('消息中包含非文本消息,跳过内容过滤器检查。')
return entities.StageProcessResult(
result_type=entities.ResultType.CONTINUE,
new_query=query
result_type=entities.ResultType.CONTINUE, new_query=query
)
return await self._pre_process(
str(query.message_chain).strip(),
query
)
return await self._pre_process(str(query.message_chain).strip(), query)
elif stage_inst_name == 'PostContentFilterStage':
# 仅处理 query.resp_messages[-1].content 是 str 的情况
if isinstance(query.resp_messages[-1], llm_entities.Message) and isinstance(query.resp_messages[-1].content, str):
return await self._post_process(
query.resp_messages[-1].content,
query
)
if isinstance(query.resp_messages[-1], llm_entities.Message) and isinstance(
query.resp_messages[-1].content, str
):
return await self._post_process(query.resp_messages[-1].content, query)
else:
self.ap.logger.debug(f"resp_messages[-1] 不是 Message 类型或 query.resp_messages[-1].content 不是 str 类型,跳过内容过滤器检查。")
self.ap.logger.debug(
'resp_messages[-1] 不是 Message 类型或 query.resp_messages[-1].content 不是 str 类型,跳过内容过滤器检查。'
)
return entities.StageProcessResult(
result_type=entities.ResultType.CONTINUE,
new_query=query
result_type=entities.ResultType.CONTINUE, new_query=query
)
else:
raise ValueError(f'未知的 stage_inst_name: {stage_inst_name}')

View File

@@ -1,14 +1,11 @@
import typing
import enum
import pydantic.v1 as pydantic
from ...provider import entities as llm_entities
class ResultLevel(enum.Enum):
"""结果等级"""
PASS = enum.auto()
"""通过"""
@@ -24,6 +21,7 @@ class ResultLevel(enum.Enum):
class EnableStage(enum.Enum):
"""启用阶段"""
PRE = enum.auto()
"""预处理"""
@@ -55,14 +53,15 @@ class FilterResult(pydantic.BaseModel):
class ManagerResultLevel(enum.Enum):
"""处理器结果等级"""
CONTINUE = enum.auto()
"""继续"""
INTERRUPT = enum.auto()
"""中断"""
class FilterManagerResult(pydantic.BaseModel):
class FilterManagerResult(pydantic.BaseModel):
level: ManagerResultLevel
replacement: str

View File

@@ -5,14 +5,13 @@ import typing
from ...core import app, entities as core_entities
from . import entities
from ...provider import entities as llm_entities
preregistered_filters: list[typing.Type[ContentFilter]] = []
def filter_class(
name: str
name: str,
) -> typing.Callable[[typing.Type[ContentFilter]], typing.Type[ContentFilter]]:
"""内容过滤器类装饰器
@@ -22,6 +21,7 @@ def filter_class(
Returns:
typing.Callable[[typing.Type[ContentFilter]], typing.Type[ContentFilter]]: 装饰器
"""
def decorator(cls: typing.Type[ContentFilter]) -> typing.Type[ContentFilter]:
assert issubclass(cls, ContentFilter)
@@ -53,23 +53,21 @@ class ContentFilter(metaclass=abc.ABCMeta):
entity.EnableStage.PRE: 消息请求AI前此时需要检查的内容是用户的输入消息。
entity.EnableStage.POST: 消息请求AI后此时需要检查的内容是AI的回复消息。
"""
return [
entities.EnableStage.PRE,
entities.EnableStage.POST
]
return [entities.EnableStage.PRE, entities.EnableStage.POST]
async def initialize(self):
"""初始化过滤器
"""
"""初始化过滤器"""
pass
@abc.abstractmethod
async def process(self, query: core_entities.Query, message: str=None, image_url=None) -> entities.FilterResult:
async def process(
self, query: core_entities.Query, message: str = None, image_url=None
) -> entities.FilterResult:
"""处理消息
分为前后阶段,具体取决于 enable_stages 的值。
对于内容过滤器来说,不需要考虑消息所处的阶段,只需要检查消息内容即可。
Args:
message (str): 需要检查的内容
image_url (str): 要检查的图片的 URL

View File

@@ -7,11 +7,11 @@ from .. import filter as filter_model
from ....core import entities as core_entities
BAIDU_EXAMINE_URL = "https://aip.baidubce.com/rest/2.0/solution/v1/text_censor/v2/user_defined?access_token={}"
BAIDU_EXAMINE_TOKEN_URL = "https://aip.baidubce.com/oauth/2.0/token"
BAIDU_EXAMINE_URL = 'https://aip.baidubce.com/rest/2.0/solution/v1/text_censor/v2/user_defined?access_token={}'
BAIDU_EXAMINE_TOKEN_URL = 'https://aip.baidubce.com/oauth/2.0/token'
@filter_model.filter_class("baidu-cloud-examine")
@filter_model.filter_class('baidu-cloud-examine')
class BaiduCloudExamine(filter_model.ContentFilter):
"""百度云内容审核"""
@@ -20,44 +20,52 @@ class BaiduCloudExamine(filter_model.ContentFilter):
async with session.post(
BAIDU_EXAMINE_TOKEN_URL,
params={
"grant_type": "client_credentials",
"client_id": self.ap.pipeline_cfg.data['baidu-cloud-examine']['api-key'],
"client_secret": self.ap.pipeline_cfg.data['baidu-cloud-examine']['api-secret']
}
'grant_type': 'client_credentials',
'client_id': self.ap.pipeline_cfg.data['baidu-cloud-examine'][
'api-key'
],
'client_secret': self.ap.pipeline_cfg.data['baidu-cloud-examine'][
'api-secret'
],
},
) as resp:
return (await resp.json())['access_token']
async def process(self, query: core_entities.Query, message: str) -> entities.FilterResult:
async def process(
self, query: core_entities.Query, message: str
) -> entities.FilterResult:
async with aiohttp.ClientSession() as session:
async with session.post(
BAIDU_EXAMINE_URL.format(await self._get_token()),
headers={'Content-Type': 'application/x-www-form-urlencoded', 'Accept': 'application/json'},
data=f"text={message}".encode('utf-8')
headers={
'Content-Type': 'application/x-www-form-urlencoded',
'Accept': 'application/json',
},
data=f'text={message}'.encode('utf-8'),
) as resp:
result = await resp.json()
if "error_code" in result:
if 'error_code' in result:
return entities.FilterResult(
level=entities.ResultLevel.BLOCK,
replacement=message,
user_notice='',
console_notice=f"百度云判定出错,错误信息:{result['error_msg']}"
console_notice=f'百度云判定出错,错误信息:{result["error_msg"]}',
)
else:
conclusion = result["conclusion"]
conclusion = result['conclusion']
if conclusion in ("合规"):
if conclusion in ('合规'):
return entities.FilterResult(
level=entities.ResultLevel.PASS,
replacement=message,
user_notice='',
console_notice=f"百度云判定结果:{conclusion}"
console_notice=f'百度云判定结果:{conclusion}',
)
else:
return entities.FilterResult(
level=entities.ResultLevel.BLOCK,
replacement=message,
user_notice="消息中存在不合适的内容, 请修改",
console_notice=f"百度云判定结果:{conclusion}"
user_notice='消息中存在不合适的内容, 请修改',
console_notice=f'百度云判定结果:{conclusion}',
)

View File

@@ -6,14 +6,16 @@ from .. import entities
from ....core import entities as core_entities
@filter_model.filter_class("ban-word-filter")
@filter_model.filter_class('ban-word-filter')
class BanWordFilter(filter_model.ContentFilter):
"""根据内容过滤"""
async def initialize(self):
pass
async def process(self, query: core_entities.Query, message: str) -> entities.FilterResult:
async def process(
self, query: core_entities.Query, message: str
) -> entities.FilterResult:
found = False
for word in self.ap.sensitive_meta.data['words']:
@@ -23,9 +25,10 @@ class BanWordFilter(filter_model.ContentFilter):
found = True
for i in range(len(match)):
if self.ap.sensitive_meta.data['mask_word'] == "":
if self.ap.sensitive_meta.data['mask_word'] == '':
message = message.replace(
match[i], self.ap.sensitive_meta.data['mask'] * len(match[i])
match[i],
self.ap.sensitive_meta.data['mask'] * len(match[i]),
)
else:
message = message.replace(
@@ -36,5 +39,5 @@ class BanWordFilter(filter_model.ContentFilter):
level=entities.ResultLevel.MASKED if found else entities.ResultLevel.PASS,
replacement=message,
user_notice='消息中存在不合适的内容, 请修改' if found else '',
console_notice=''
)
console_notice='',
)

View File

@@ -6,7 +6,7 @@ from .. import filter as filter_model
from ....core import entities as core_entities
@filter_model.filter_class("content-ignore")
@filter_model.filter_class('content-ignore')
class ContentIgnore(filter_model.ContentFilter):
"""根据内容忽略消息"""
@@ -16,7 +16,9 @@ class ContentIgnore(filter_model.ContentFilter):
entities.EnableStage.PRE,
]
async def process(self, query: core_entities.Query, message: str) -> entities.FilterResult:
async def process(
self, query: core_entities.Query, message: str
) -> entities.FilterResult:
if 'prefix' in query.pipeline_config['trigger']['ignore-rules']:
for rule in query.pipeline_config['trigger']['ignore-rules']['prefix']:
if message.startswith(rule):
@@ -24,9 +26,9 @@ class ContentIgnore(filter_model.ContentFilter):
level=entities.ResultLevel.BLOCK,
replacement='',
user_notice='',
console_notice='根据 ignore_rules 中的 prefix 规则,忽略消息'
console_notice='根据 ignore_rules 中的 prefix 规则,忽略消息',
)
if 'regexp' in query.pipeline_config['trigger']['ignore-rules']:
for rule in query.pipeline_config['trigger']['ignore-rules']['regexp']:
if re.search(rule, message):
@@ -34,12 +36,12 @@ class ContentIgnore(filter_model.ContentFilter):
level=entities.ResultLevel.BLOCK,
replacement='',
user_notice='',
console_notice='根据 ignore_rules 中的 regexp 规则,忽略消息'
console_notice='根据 ignore_rules 中的 regexp 规则,忽略消息',
)
return entities.FilterResult(
level=entities.ResultLevel.PASS,
replacement=message,
user_notice='',
console_notice=''
)
console_notice='',
)

View File

@@ -1,18 +1,14 @@
from __future__ import annotations
import asyncio
import typing
import traceback
from ..core import app, entities
from . import entities as pipeline_entities
from ..plugin import events
from ..platform.types import message as platform_message
class Controller:
"""总控制器
"""
"""总控制器"""
ap: app.Application
semaphore: asyncio.Semaphore = None
@@ -20,11 +16,12 @@ class Controller:
def __init__(self, ap: app.Application):
self.ap = ap
self.semaphore = asyncio.Semaphore(self.ap.instance_config.data['concurrency']['pipeline'])
self.semaphore = asyncio.Semaphore(
self.ap.instance_config.data['concurrency']['pipeline']
)
async def consumer(self):
"""事件处理循环
"""
"""事件处理循环"""
try:
while True:
selected_query: entities.Query = None
@@ -35,7 +32,9 @@ class Controller:
for query in queries:
session = await self.ap.sess_mgr.get_session(query)
self.ap.logger.debug(f"Checking query {query} session {session}")
self.ap.logger.debug(
f'Checking query {query} session {session}'
)
if not session.semaphore.locked():
selected_query = query
@@ -56,30 +55,40 @@ class Controller:
# find pipeline
# Here firstly find the bot, then find the pipeline, in case the bot adapter's config is not the latest one.
# Like aiocqhttp, once a client is connected, even the adapter was updated and restarted, the existing client connection will not be affected.
bot = await self.ap.platform_mgr.get_bot_by_uuid(selected_query.bot_uuid)
bot = await self.ap.platform_mgr.get_bot_by_uuid(
selected_query.bot_uuid
)
if bot:
pipeline = await self.ap.pipeline_mgr.get_pipeline_by_uuid(bot.bot_entity.use_pipeline_uuid)
pipeline = (
await self.ap.pipeline_mgr.get_pipeline_by_uuid(
bot.bot_entity.use_pipeline_uuid
)
)
if pipeline:
await pipeline.run(selected_query)
async with self.ap.query_pool:
(await self.ap.sess_mgr.get_session(selected_query)).semaphore.release()
(
await self.ap.sess_mgr.get_session(selected_query)
).semaphore.release()
# 通知其他协程,有新的请求可以处理了
self.ap.query_pool.condition.notify_all()
self.ap.task_mgr.create_task(
_process_query(selected_query),
kind="query",
name=f"query-{selected_query.query_id}",
scopes=[entities.LifecycleControlScope.APPLICATION, entities.LifecycleControlScope.PLATFORM],
kind='query',
name=f'query-{selected_query.query_id}',
scopes=[
entities.LifecycleControlScope.APPLICATION,
entities.LifecycleControlScope.PLATFORM,
],
)
except Exception as e:
# traceback.print_exc()
self.ap.logger.error(f"控制器循环出错: {e}")
self.ap.logger.error(f"Traceback: {traceback.format_exc()}")
self.ap.logger.error(f'控制器循环出错: {e}')
self.ap.logger.error(f'Traceback: {traceback.format_exc()}')
async def run(self):
"""运行控制器
"""
"""运行控制器"""
await self.consumer()

View File

@@ -10,7 +10,6 @@ from ..core import entities
class ResultType(enum.Enum):
CONTINUE = enum.auto()
"""继续流水线"""
@@ -19,12 +18,18 @@ class ResultType(enum.Enum):
class StageProcessResult(pydantic.BaseModel):
result_type: ResultType
new_query: entities.Query
user_notice: typing.Optional[typing.Union[str, list[platform_message.MessageComponent], platform_message.MessageChain, None]] = []
user_notice: typing.Optional[
typing.Union[
str,
list[platform_message.MessageComponent],
platform_message.MessageChain,
None,
]
] = []
"""只要设置了就会发送给用户"""
console_notice: typing.Optional[str] = ''

View File

@@ -2,18 +2,19 @@ from __future__ import annotations
import os
import traceback
from PIL import Image, ImageDraw, ImageFont
from ...core import app
from . import strategy
from .strategies import image, forward
from .. import stage, entities
from ...core import entities as core_entities
from ...config import manager as cfg_mgr
from ...platform.types import message as platform_message
from ...utils import importutil
from . import strategies
importutil.import_modules_in_pkg(strategies)
@stage.stage_class("LongTextProcessStage")
@stage.stage_class('LongTextProcessStage')
class LongTextProcessStage(stage.PipelineStage):
"""长消息处理阶段
@@ -31,34 +32,48 @@ class LongTextProcessStage(stage.PipelineStage):
# 检查是否存在
if not os.path.exists(use_font):
# 若是windows系统使用微软雅黑
if os.name == "nt":
use_font = "C:/Windows/Fonts/msyh.ttc"
if os.name == 'nt':
use_font = 'C:/Windows/Fonts/msyh.ttc'
if not os.path.exists(use_font):
self.ap.logger.warn("未找到字体文件且无法使用Windows自带字体更换为转发消息组件以发送长消息您可以在配置文件中调整相关设置。")
config['blob_message_strategy'] = "forward"
self.ap.logger.warn(
'未找到字体文件且无法使用Windows自带字体更换为转发消息组件以发送长消息您可以在配置文件中调整相关设置。'
)
config['blob_message_strategy'] = 'forward'
else:
self.ap.logger.info("使用Windows自带字体" + use_font)
self.ap.logger.info('使用Windows自带字体' + use_font)
config['font-path'] = use_font
else:
self.ap.logger.warn("未找到字体文件,且无法使用系统自带字体,更换为转发消息组件以发送长消息,您可以在配置文件中调整相关设置。")
self.ap.logger.warn(
'未找到字体文件,且无法使用系统自带字体,更换为转发消息组件以发送长消息,您可以在配置文件中调整相关设置。'
)
pipeline_config['output']['long-text-processing']['strategy'] = "forward"
except:
pipeline_config['output']['long-text-processing'][
'strategy'
] = 'forward'
except Exception:
traceback.print_exc()
self.ap.logger.error("加载字体文件失败({}),更换为转发消息组件以发送长消息,您可以在配置文件中调整相关设置。".format(use_font))
self.ap.logger.error(
'加载字体文件失败({}),更换为转发消息组件以发送长消息,您可以在配置文件中调整相关设置。'.format(
use_font
)
)
pipeline_config['output']['long-text-processing']['strategy'] = "forward"
pipeline_config['output']['long-text-processing']['strategy'] = (
'forward'
)
for strategy_cls in strategy.preregistered_strategies:
if strategy_cls.name == config['strategy']:
self.strategy_impl = strategy_cls(self.ap)
break
else:
raise ValueError(f"未找到名为 {config['strategy']} 的长消息处理策略")
raise ValueError(f'未找到名为 {config["strategy"]} 的长消息处理策略')
await self.strategy_impl.initialize()
async def process(self, query: core_entities.Query, stage_inst_name: str) -> entities.StageProcessResult:
async def process(
self, query: core_entities.Query, stage_inst_name: str
) -> entities.StageProcessResult:
# 检查是否包含非 Plain 组件
contains_non_plain = False
@@ -66,13 +81,19 @@ class LongTextProcessStage(stage.PipelineStage):
if not isinstance(msg, platform_message.Plain):
contains_non_plain = True
break
if contains_non_plain:
self.ap.logger.debug("消息中包含非 Plain 组件,跳过长消息处理。")
elif len(str(query.resp_message_chain[-1])) > query.pipeline_config['output']['long-text-processing']['threshold']:
query.resp_message_chain[-1] = platform_message.MessageChain(await self.strategy_impl.process(str(query.resp_message_chain[-1]), query))
self.ap.logger.debug('消息中包含非 Plain 组件,跳过长消息处理。')
elif (
len(str(query.resp_message_chain[-1]))
> query.pipeline_config['output']['long-text-processing']['threshold']
):
query.resp_message_chain[-1] = platform_message.MessageChain(
await self.strategy_impl.process(
str(query.resp_message_chain[-1]), query
)
)
return entities.StageProcessResult(
result_type=entities.ResultType.CONTINUE,
new_query=query
result_type=entities.ResultType.CONTINUE, new_query=query
)

View File

@@ -1,8 +1,6 @@
# 转发消息组件
from __future__ import annotations
import typing
import pydantic.v1 as pydantic
from .. import strategy as strategy_model
from ....core import entities as core_entities
@@ -13,29 +11,27 @@ ForwardMessageDiaplay = platform_message.ForwardMessageDiaplay
Forward = platform_message.Forward
@strategy_model.strategy_class("forward")
@strategy_model.strategy_class('forward')
class ForwardComponentStrategy(strategy_model.LongTextStrategy):
async def process(self, message: str, query: core_entities.Query) -> list[platform_message.MessageComponent]:
async def process(
self, message: str, query: core_entities.Query
) -> list[platform_message.MessageComponent]:
display = ForwardMessageDiaplay(
title="群聊的聊天记录",
brief="[聊天记录]",
source="聊天记录",
preview=["QQ用户: "+message],
summary="查看1条转发消息"
title='群聊的聊天记录',
brief='[聊天记录]',
source='聊天记录',
preview=['QQ用户: ' + message],
summary='查看1条转发消息',
)
node_list = [
platform_message.ForwardMessageNode(
sender_id=query.adapter.bot_account_id,
sender_name='QQ用户',
message_chain=platform_message.MessageChain([message])
message_chain=platform_message.MessageChain([message]),
)
]
forward = Forward(
display=display,
node_list=node_list
)
forward = Forward(display=display, node_list=node_list)
return [forward]

View File

@@ -1,6 +1,5 @@
from __future__ import annotations
import typing
import os
import base64
import time
@@ -15,26 +14,30 @@ from .. import strategy as strategy_model
from ....core import entities as core_entities
@strategy_model.strategy_class("image")
@strategy_model.strategy_class('image')
class Text2ImageStrategy(strategy_model.LongTextStrategy):
async def initialize(self):
pass
@functools.lru_cache(maxsize=16)
def get_font(self, query: core_entities.Query):
return ImageFont.truetype(query.pipeline_config['output']['long-text-processing']['font-path'], 32, encoding="utf-8")
async def process(self, message: str, query: core_entities.Query) -> list[platform_message.MessageComponent]:
return ImageFont.truetype(
query.pipeline_config['output']['long-text-processing']['font-path'],
32,
encoding='utf-8',
)
async def process(
self, message: str, query: core_entities.Query
) -> list[platform_message.MessageComponent]:
img_path = self.text_to_image(
text_str=message,
save_as='temp/{}.png'.format(int(time.time())),
query=query
query=query,
)
compressed_path, size = self.compress_image(
img_path,
outfile="temp/{}_compressed.png".format(int(time.time()))
img_path, outfile='temp/{}_compressed.png'.format(int(time.time()))
)
with open(compressed_path, 'rb') as f:
@@ -93,13 +96,11 @@ class Text2ImageStrategy(strategy_model.LongTextStrategy):
resultIndex.append(v)
return resultIndex
def get_size(self, file):
# 获取文件大小:KB
size = os.path.getsize(file)
return size / 1024
def get_outfile(self, infile, outfile):
if outfile:
return outfile
@@ -107,7 +108,6 @@ class Text2ImageStrategy(strategy_model.LongTextStrategy):
outfile = '{}-out{}'.format(dir, suffix)
return outfile
def compress_image(self, infile, outfile='', kb=100, step=20, quality=90):
"""不改变图片尺寸压缩到指定大小
:param infile: 压缩源文件
@@ -130,24 +130,28 @@ class Text2ImageStrategy(strategy_model.LongTextStrategy):
o_size = self.get_size(outfile)
return outfile, self.get_size(outfile)
def text_to_image(
self,
text_str: str,
save_as='temp.png',
width=800,
query: core_entities.Query = None,
):
text_str = text_str.replace('\t', ' ')
def text_to_image(self, text_str: str, save_as="temp.png", width=800, query: core_entities.Query = None):
text_str = text_str.replace("\t", " ")
# 分行
lines = text_str.split('\n')
# 计算并分割
final_lines = []
text_width = width-80
text_width = width - 80
self.ap.logger.debug("lines: {}, text_width: {}".format(lines, text_width))
self.ap.logger.debug('lines: {}, text_width: {}'.format(lines, text_width))
for line in lines:
# 如果长了就分割
line_width = self.get_font(query).getlength(line)
self.ap.logger.debug("line_width: {}".format(line_width))
self.ap.logger.debug('line_width: {}'.format(line_width))
if line_width < text_width:
final_lines.append(line)
continue
@@ -161,7 +165,10 @@ class Text2ImageStrategy(strategy_model.LongTextStrategy):
numbers = self.indexNumber(rest_text)
for number in numbers:
if number[1] < point < number[1] + len(number[0]) and number[1] != 0:
if (
number[1] < point < number[1] + len(number[0])
and number[1] != 0
):
point = number[1]
break
@@ -174,16 +181,23 @@ class Text2ImageStrategy(strategy_model.LongTextStrategy):
else:
continue
# 准备画布
img = Image.new('RGBA', (width, max(280, len(final_lines) * 35 + 65)), (255, 255, 255, 255))
img = Image.new(
'RGBA', (width, max(280, len(final_lines) * 35 + 65)), (255, 255, 255, 255)
)
draw = ImageDraw.Draw(img, mode='RGBA')
self.ap.logger.debug("正在绘制图片...")
self.ap.logger.debug('正在绘制图片...')
# 绘制正文
line_number = 0
offset_x = 20
offset_y = 30
for final_line in final_lines:
draw.text((offset_x, offset_y + 35 * line_number), final_line, fill=(0, 0, 0), font=self.text_render_font)
draw.text(
(offset_x, offset_y + 35 * line_number),
final_line,
fill=(0, 0, 0),
font=self.text_render_font,
)
# 遍历此行,检查是否有emoji
idx_in_line = 0
for ch in final_line:
@@ -196,7 +210,7 @@ class Text2ImageStrategy(strategy_model.LongTextStrategy):
line_number += 1
self.ap.logger.debug("正在保存图片...")
self.ap.logger.debug('正在保存图片...')
img.save(save_as)
return save_as

View File

@@ -12,7 +12,7 @@ preregistered_strategies: list[typing.Type[LongTextStrategy]] = []
def strategy_class(
name: str
name: str,
) -> typing.Callable[[typing.Type[LongTextStrategy]], typing.Type[LongTextStrategy]]:
"""长文本处理策略类装饰器
@@ -36,8 +36,7 @@ def strategy_class(
class LongTextStrategy(metaclass=abc.ABCMeta):
"""长文本处理策略抽象类
"""
"""长文本处理策略抽象类"""
name: str
@@ -45,12 +44,14 @@ class LongTextStrategy(metaclass=abc.ABCMeta):
def __init__(self, ap: app.Application):
self.ap = ap
async def initialize(self):
pass
@abc.abstractmethod
async def process(self, message: str, query: core_entities.Query) -> list[platform_message.MessageComponent]:
async def process(
self, message: str, query: core_entities.Query
) -> list[platform_message.MessageComponent]:
"""处理长文本
在 platform.json 中配置 long-text-process 字段,只要 文本长度超过了 threshold 就会调用此方法

View File

@@ -3,33 +3,38 @@ from __future__ import annotations
from .. import stage, entities
from ...core import entities as core_entities
from . import truncator
from .truncators import round
from ...utils import importutil
from . import truncators
importutil.import_modules_in_pkg(truncators)
@stage.stage_class("ConversationMessageTruncator")
@stage.stage_class('ConversationMessageTruncator')
class ConversationMessageTruncator(stage.PipelineStage):
"""会话消息截断器
用于截断会话消息链,以适应平台消息长度限制。
"""
trun: truncator.Truncator
async def initialize(self, pipeline_config: dict):
use_method = "round"
use_method = 'round'
for trun in truncator.preregistered_truncators:
if trun.name == use_method:
self.trun = trun(self.ap)
break
else:
raise ValueError(f"未知的截断器: {use_method}")
raise ValueError(f'未知的截断器: {use_method}')
async def process(self, query: core_entities.Query, stage_inst_name: str) -> entities.StageProcessResult:
"""处理
"""
async def process(
self, query: core_entities.Query, stage_inst_name: str
) -> entities.StageProcessResult:
"""处理"""
query = await self.trun.truncate(query)
return entities.StageProcessResult(
result_type=entities.ResultType.CONTINUE,
new_query=query
)
result_type=entities.ResultType.CONTINUE, new_query=query
)

View File

@@ -10,7 +10,7 @@ preregistered_truncators: list[typing.Type[Truncator]] = []
def truncator_class(
name: str
name: str,
) -> typing.Callable[[typing.Type[Truncator]], typing.Type[Truncator]]:
"""截断器类装饰器
@@ -20,6 +20,7 @@ def truncator_class(
Returns:
typing.Callable[[typing.Type[Truncator]], typing.Type[Truncator]]: 装饰器
"""
def decorator(cls: typing.Type[Truncator]) -> typing.Type[Truncator]:
assert issubclass(cls, Truncator)
@@ -33,13 +34,12 @@ def truncator_class(
class Truncator(abc.ABC):
"""消息截断器基类
"""
"""消息截断器基类"""
name: str
ap: app.Application
def __init__(self, ap: app.Application):
self.ap = ap

View File

@@ -4,14 +4,12 @@ from .. import truncator
from ....core import entities as core_entities
@truncator.truncator_class("round")
@truncator.truncator_class('round')
class RoundTruncator(truncator.Truncator):
"""前文回合数阶段器
"""
"""前文回合数阶段器"""
async def truncate(self, query: core_entities.Query) -> core_entities.Query:
"""截断
"""
"""截断"""
max_round = query.pipeline_config['ai']['local-agent']['max-round']
temp_messages = []
@@ -26,7 +24,7 @@ class RoundTruncator(truncator.Truncator):
current_round += 1
else:
break
query.messages = temp_messages[::-1]
return query

View File

@@ -11,22 +11,39 @@ from ..entity.persistence import pipeline as persistence_pipeline
from . import stage
from ..platform.types import message as platform_message, events as platform_events
from ..plugin import events
from ..utils import importutil
from .resprule import resprule
from .bansess import bansess
from .cntfilter import cntfilter
from .process import process
from .longtext import longtext
from .respback import respback
from .wrapper import wrapper
from .preproc import preproc
from .ratelimit import ratelimit
from .msgtrun import msgtrun
from . import (
resprule,
bansess,
cntfilter,
process,
longtext,
respback,
wrapper,
preproc,
ratelimit,
msgtrun,
)
importutil.import_modules_in_pkgs(
[
resprule,
bansess,
cntfilter,
process,
longtext,
respback,
wrapper,
preproc,
ratelimit,
msgtrun,
]
)
class StageInstContainer():
"""阶段实例容器
"""
class StageInstContainer:
"""阶段实例容器"""
inst_name: str
@@ -48,7 +65,12 @@ class RuntimePipeline:
stage_containers: list[StageInstContainer]
"""阶段实例容器"""
def __init__(self, ap: app.Application, pipeline_entity: persistence_pipeline.LegacyPipeline, stage_containers: list[StageInstContainer]):
def __init__(
self,
ap: app.Application,
pipeline_entity: persistence_pipeline.LegacyPipeline,
stage_containers: list[StageInstContainer],
):
self.ap = ap
self.pipeline_entity = pipeline_entity
self.stage_containers = stage_containers
@@ -57,9 +79,10 @@ class RuntimePipeline:
query.pipeline_config = self.pipeline_entity.config
await self.process_query(query)
async def _check_output(self, query: entities.Query, result: pipeline_entities.StageProcessResult):
"""检查输出
"""
async def _check_output(
self, query: entities.Query, result: pipeline_entities.StageProcessResult
):
"""检查输出"""
if result.user_notice:
# 处理str类型
@@ -68,22 +91,19 @@ class RuntimePipeline:
platform_message.Plain(result.user_notice)
)
elif isinstance(result.user_notice, list):
result.user_notice = platform_message.MessageChain(
*result.user_notice
)
result.user_notice = platform_message.MessageChain(*result.user_notice)
if query.pipeline_config['output']['misc']['at-sender'] and isinstance(query.message_event, platform_events.GroupMessage):
if query.pipeline_config['output']['misc']['at-sender'] and isinstance(
query.message_event, platform_events.GroupMessage
):
result.user_notice.insert(
0,
platform_message.At(
query.message_event.sender.id
)
0, platform_message.At(query.message_event.sender.id)
)
await query.adapter.reply_message(
message_source=query.message_event,
message=result.user_notice,
quote_origin=query.pipeline_config['output']['misc']['quote-origin']
quote_origin=query.pipeline_config['output']['misc']['quote-origin'],
)
if result.debug_notice:
self.ap.logger.debug(result.debug_notice)
@@ -123,32 +143,44 @@ class RuntimePipeline:
stage_container = self.stage_containers[i]
query.current_stage = stage_container # 标记到 Query 对象里
result = stage_container.inst.process(query, stage_container.inst_name)
if isinstance(result, typing.Coroutine):
result = await result
if isinstance(result, pipeline_entities.StageProcessResult): # 直接返回结果
self.ap.logger.debug(f"Stage {stage_container.inst_name} processed query {query} res {result}")
self.ap.logger.debug(
f'Stage {stage_container.inst_name} processed query {query} res {result}'
)
await self._check_output(query, result)
if result.result_type == pipeline_entities.ResultType.INTERRUPT:
self.ap.logger.debug(f"Stage {stage_container.inst_name} interrupted query {query}")
self.ap.logger.debug(
f'Stage {stage_container.inst_name} interrupted query {query}'
)
break
elif result.result_type == pipeline_entities.ResultType.CONTINUE:
query = result.new_query
elif isinstance(result, typing.AsyncGenerator): # 生成器
self.ap.logger.debug(f"Stage {stage_container.inst_name} processed query {query} gen")
self.ap.logger.debug(
f'Stage {stage_container.inst_name} processed query {query} gen'
)
async for sub_result in result:
self.ap.logger.debug(f"Stage {stage_container.inst_name} processed query {query} res {sub_result}")
self.ap.logger.debug(
f'Stage {stage_container.inst_name} processed query {query} res {sub_result}'
)
await self._check_output(query, sub_result)
if sub_result.result_type == pipeline_entities.ResultType.INTERRUPT:
self.ap.logger.debug(f"Stage {stage_container.inst_name} interrupted query {query}")
self.ap.logger.debug(
f'Stage {stage_container.inst_name} interrupted query {query}'
)
break
elif sub_result.result_type == pipeline_entities.ResultType.CONTINUE:
elif (
sub_result.result_type == pipeline_entities.ResultType.CONTINUE
):
query = sub_result.new_query
await self._execute_from_stage(i + 1, query)
break
@@ -156,12 +188,14 @@ class RuntimePipeline:
i += 1
async def process_query(self, query: entities.Query):
"""处理请求
"""
"""处理请求"""
try:
# ======== 触发 MessageReceived 事件 ========
event_type = events.PersonMessageReceived if query.launcher_type == entities.LauncherTypes.PERSON else events.GroupMessageReceived
event_type = (
events.PersonMessageReceived
if query.launcher_type == entities.LauncherTypes.PERSON
else events.GroupMessageReceived
)
event_ctx = await self.ap.plugin_mgr.emit_event(
event=event_type(
@@ -169,22 +203,26 @@ class RuntimePipeline:
launcher_id=query.launcher_id,
sender_id=query.sender_id,
message_chain=query.message_chain,
query=query
query=query,
)
)
if event_ctx.is_prevented_default():
return
self.ap.logger.debug(f"Processing query {query}")
self.ap.logger.debug(f'Processing query {query}')
await self._execute_from_stage(0, query)
except Exception as e:
inst_name = query.current_stage.inst_name if query.current_stage else 'unknown'
self.ap.logger.error(f"处理请求时出错 query_id={query.query_id} stage={inst_name} : {e}")
self.ap.logger.debug(f"Traceback: {traceback.format_exc()}")
inst_name = (
query.current_stage.inst_name if query.current_stage else 'unknown'
)
self.ap.logger.error(
f'处理请求时出错 query_id={query.query_id} stage={inst_name} : {e}'
)
self.ap.logger.debug(f'Traceback: {traceback.format_exc()}')
finally:
self.ap.logger.debug(f"Query {query} processed")
self.ap.logger.debug(f'Query {query} processed')
class PipelineManager:
@@ -203,7 +241,9 @@ class PipelineManager:
self.pipelines = []
async def initialize(self):
self.stage_dict = {name: cls for name, cls in stage.preregistered_stages.items()}
self.stage_dict = {
name: cls for name, cls in stage.preregistered_stages.items()
}
await self.load_pipelines_from_db()
@@ -220,24 +260,31 @@ class PipelineManager:
for pipeline in pipelines:
await self.load_pipeline(pipeline)
async def load_pipeline(self, pipeline_entity: persistence_pipeline.LegacyPipeline | sqlalchemy.Row[persistence_pipeline.LegacyPipeline] | dict):
async def load_pipeline(
self,
pipeline_entity: persistence_pipeline.LegacyPipeline
| sqlalchemy.Row[persistence_pipeline.LegacyPipeline]
| dict,
):
if isinstance(pipeline_entity, sqlalchemy.Row):
pipeline_entity = persistence_pipeline.LegacyPipeline(**pipeline_entity._mapping)
pipeline_entity = persistence_pipeline.LegacyPipeline(
**pipeline_entity._mapping
)
elif isinstance(pipeline_entity, dict):
pipeline_entity = persistence_pipeline.LegacyPipeline(**pipeline_entity)
# initialize stage containers according to pipeline_entity.stages
stage_containers: list[StageInstContainer] = []
for stage_name in pipeline_entity.stages:
stage_containers.append(StageInstContainer(
inst_name=stage_name,
inst=self.stage_dict[stage_name](self.ap)
))
stage_containers.append(
StageInstContainer(
inst_name=stage_name, inst=self.stage_dict[stage_name](self.ap)
)
)
for stage_container in stage_containers:
await stage_container.inst.initialize(pipeline_entity.config)
runtime_pipeline = RuntimePipeline(self.ap, pipeline_entity, stage_containers)
self.pipelines.append(runtime_pipeline)
@@ -251,4 +298,4 @@ class PipelineManager:
for pipeline in self.pipelines:
if pipeline.pipeline_entity.uuid == uuid:
self.pipelines.remove(pipeline)
return
return

View File

@@ -47,7 +47,7 @@ class QueryPool:
message_chain=message_chain,
resp_messages=[],
resp_message_chain=[],
adapter=adapter
adapter=adapter,
)
self.queries.append(query)
self.query_id_counter += 1

View File

@@ -9,7 +9,7 @@ from ...plugin import events
from ...platform.types import message as platform_message
@stage.stage_class("PreProcessor")
@stage.stage_class('PreProcessor')
class PreProcessor(stage.PipelineStage):
"""请求预处理阶段
@@ -29,11 +29,12 @@ class PreProcessor(stage.PipelineStage):
query: core_entities.Query,
stage_inst_name: str,
) -> entities.StageProcessResult:
"""处理
"""
"""处理"""
session = await self.ap.sess_mgr.get_session(query)
conversation = await self.ap.sess_mgr.get_conversation(query, session, query.pipeline_config['ai']['local-agent']['prompt'])
conversation = await self.ap.sess_mgr.get_conversation(
query, session, query.pipeline_config['ai']['local-agent']['prompt']
)
# 设置query
query.session = session
@@ -42,17 +43,26 @@ class PreProcessor(stage.PipelineStage):
query.use_llm_model = conversation.use_llm_model
query.use_funcs = conversation.use_funcs if query.use_llm_model.model_entity.abilities.__contains__('tool_call') else None
query.use_funcs = (
conversation.use_funcs
if query.use_llm_model.model_entity.abilities.__contains__('tool_call')
else None
)
query.variables = {
"session_id": f"{query.session.launcher_type.value}_{query.session.launcher_id}",
"conversation_id": conversation.uuid,
"msg_create_time": int(query.message_event.time) if query.message_event.time else int(datetime.datetime.now().timestamp()),
'session_id': f'{query.session.launcher_type.value}_{query.session.launcher_id}',
'conversation_id': conversation.uuid,
'msg_create_time': int(query.message_event.time)
if query.message_event.time
else int(datetime.datetime.now().timestamp()),
}
# Check if this model supports vision, if not, remove all images
# TODO this checking should be performed in runner, and in this stage, the image should be reserved
if query.pipeline_config['ai']['runner']['runner'] == 'local-agent' and not query.use_llm_model.model_entity.abilities.__contains__('vision'):
if (
query.pipeline_config['ai']['runner']['runner'] == 'local-agent'
and not query.use_llm_model.model_entity.abilities.__contains__('vision')
):
for msg in query.messages:
if isinstance(msg.content, list):
for me in msg.content:
@@ -61,16 +71,17 @@ class PreProcessor(stage.PipelineStage):
content_list = []
plain_text = ""
plain_text = ''
for me in query.message_chain:
if isinstance(me, platform_message.Plain):
content_list.append(
llm_entities.ContentElement.from_text(me.text)
)
content_list.append(llm_entities.ContentElement.from_text(me.text))
plain_text += me.text
elif isinstance(me, platform_message.Image):
if query.pipeline_config['ai']['runner']['runner'] != 'local-agent' or query.use_llm_model.model_entity.abilities.__contains__('vision'):
if (
query.pipeline_config['ai']['runner']['runner'] != 'local-agent'
or query.use_llm_model.model_entity.abilities.__contains__('vision')
):
if me.base64 is not None:
content_list.append(
llm_entities.ContentElement.from_image_base64(me.base64)
@@ -78,10 +89,7 @@ class PreProcessor(stage.PipelineStage):
query.variables['user_message_text'] = plain_text
query.user_message = llm_entities.Message(
role='user',
content=content_list
)
query.user_message = llm_entities.Message(role='user', content=content_list)
# =========== 触发事件 PromptPreProcessing
event_ctx = await self.ap.plugin_mgr.emit_event(
@@ -89,7 +97,7 @@ class PreProcessor(stage.PipelineStage):
session_name=f'{query.session.launcher_type.value}_{query.session.launcher_id}',
default_prompt=query.prompt.messages,
prompt=query.messages,
query=query
query=query,
)
)
@@ -97,6 +105,5 @@ class PreProcessor(stage.PipelineStage):
query.messages = event_ctx.event.prompt
return entities.StageProcessResult(
result_type=entities.ResultType.CONTINUE,
new_query=query
result_type=entities.ResultType.CONTINUE, new_query=query
)

View File

@@ -8,7 +8,6 @@ from .. import entities
class MessageHandler(metaclass=abc.ABCMeta):
ap: app.Application
def __init__(self, ap: app.Application):

View File

@@ -3,33 +3,36 @@ from __future__ import annotations
import typing
import time
import traceback
import json
from .. import handler
from ... import entities
from ....core import entities as core_entities
from ....provider import entities as llm_entities
from ....provider import runner as runner_module
from ....provider.runners import localagent, difysvapi, dashscopeapi
from ....plugin import events
from ....platform.types import message as platform_message
from ....utils import importutil
from ....provider import runners
importutil.import_modules_in_pkg(runners)
class ChatMessageHandler(handler.MessageHandler):
async def handle(
self,
query: core_entities.Query,
) -> typing.AsyncGenerator[entities.StageProcessResult, None]:
"""处理
"""
"""处理"""
# 调API
# 生成器
# 触发插件事件
event_class = events.PersonNormalMessageReceived if query.launcher_type == core_entities.LauncherTypes.PERSON else events.GroupNormalMessageReceived
event_class = (
events.PersonNormalMessageReceived
if query.launcher_type == core_entities.LauncherTypes.PERSON
else events.GroupNormalMessageReceived
)
event_ctx = await self.ap.plugin_mgr.emit_event(
event=event_class(
@@ -37,7 +40,7 @@ class ChatMessageHandler(handler.MessageHandler):
launcher_id=query.launcher_id,
sender_id=query.sender_id,
text_message=str(query.message_chain),
query=query
query=query,
)
)
@@ -48,16 +51,13 @@ class ChatMessageHandler(handler.MessageHandler):
query.resp_messages.append(mc)
yield entities.StageProcessResult(
result_type=entities.ResultType.CONTINUE,
new_query=query
result_type=entities.ResultType.CONTINUE, new_query=query
)
else:
yield entities.StageProcessResult(
result_type=entities.ResultType.INTERRUPT,
new_query=query
result_type=entities.ResultType.INTERRUPT, new_query=query
)
else:
if event_ctx.event.alter is not None:
# if isinstance(event_ctx.event, str): # 现在暂时不考虑多模态alter
query.user_message.content = event_ctx.event.alter
@@ -67,48 +67,52 @@ class ChatMessageHandler(handler.MessageHandler):
start_time = time.time()
try:
for r in runner_module.preregistered_runners:
if r.name == query.pipeline_config["ai"]["runner"]["runner"]:
if r.name == query.pipeline_config['ai']['runner']['runner']:
runner = r(self.ap, query.pipeline_config)
break
else:
raise ValueError(f"未找到请求运行器: {query.pipeline_config['ai']['runner']['runner']}")
raise ValueError(
f'未找到请求运行器: {query.pipeline_config["ai"]["runner"]["runner"]}'
)
async for result in runner.run(query):
query.resp_messages.append(result)
self.ap.logger.info(f'对话({query.query_id})响应: {self.cut_str(result.readable_str())}')
self.ap.logger.info(
f'对话({query.query_id})响应: {self.cut_str(result.readable_str())}'
)
if result.content is not None:
text_length += len(result.content)
yield entities.StageProcessResult(
result_type=entities.ResultType.CONTINUE,
new_query=query
result_type=entities.ResultType.CONTINUE, new_query=query
)
query.session.using_conversation.messages.append(query.user_message)
query.session.using_conversation.messages.extend(query.resp_messages)
except Exception as e:
self.ap.logger.error(f'对话({query.query_id})请求失败: {type(e).__name__} {str(e)}')
self.ap.logger.error(
f'对话({query.query_id})请求失败: {type(e).__name__} {str(e)}'
)
hide_exception_info = query.pipeline_config['output']['misc']['hide-exception']
hide_exception_info = query.pipeline_config['output']['misc'][
'hide-exception'
]
yield entities.StageProcessResult(
result_type=entities.ResultType.INTERRUPT,
new_query=query,
user_notice='请求失败' if hide_exception_info else f'{e}',
error_notice=f'{e}',
debug_notice=traceback.format_exc()
debug_notice=traceback.format_exc(),
)
finally:
await self.ap.ctr_mgr.usage.post_query_record(
session_type=query.session.launcher_type.value,
session_id=str(query.session.launcher_id),
query_ability_provider="LangBot.Chat",
query_ability_provider='LangBot.Chat',
usage=text_length,
model_name=query.use_model.name,
response_seconds=int(time.time() - start_time),

View File

@@ -11,24 +11,29 @@ from ....platform.types import message as platform_message
class CommandHandler(handler.MessageHandler):
async def handle(
self,
query: core_entities.Query,
) -> typing.AsyncGenerator[entities.StageProcessResult, None]:
"""处理
"""
"""处理"""
command_text = str(query.message_chain).strip()[1:]
privilege = 1
if f'{query.launcher_type.value}_{query.launcher_id}' in self.ap.instance_config.data['admins']:
if (
f'{query.launcher_type.value}_{query.launcher_id}'
in self.ap.instance_config.data['admins']
):
privilege = 2
spt = command_text.split(' ')
event_class = events.PersonCommandSent if query.launcher_type == core_entities.LauncherTypes.PERSON else events.GroupCommandSent
event_class = (
events.PersonCommandSent
if query.launcher_type == core_entities.LauncherTypes.PERSON
else events.GroupCommandSent
)
event_ctx = await self.ap.plugin_mgr.emit_event(
event=event_class(
@@ -38,41 +43,35 @@ class CommandHandler(handler.MessageHandler):
command=spt[0],
params=spt[1:] if len(spt) > 1 else [],
text_message=str(query.message_chain),
is_admin=(privilege==2),
query=query
is_admin=(privilege == 2),
query=query,
)
)
if event_ctx.is_prevented_default():
if event_ctx.event.reply is not None:
mc = platform_message.MessageChain(event_ctx.event.reply)
query.resp_messages.append(mc)
yield entities.StageProcessResult(
result_type=entities.ResultType.CONTINUE,
new_query=query
result_type=entities.ResultType.CONTINUE, new_query=query
)
else:
yield entities.StageProcessResult(
result_type=entities.ResultType.INTERRUPT,
new_query=query
result_type=entities.ResultType.INTERRUPT, new_query=query
)
else:
if event_ctx.event.alter is not None:
query.message_chain = platform_message.MessageChain([
platform_message.Plain(event_ctx.event.alter)
])
query.message_chain = platform_message.MessageChain(
[platform_message.Plain(event_ctx.event.alter)]
)
session = await self.ap.sess_mgr.get_session(query)
async for ret in self.ap.cmd_mgr.execute(
command_text=command_text,
query=query,
session=session
command_text=command_text, query=query, session=session
):
if ret.error is not None:
query.resp_messages.append(
@@ -82,20 +81,18 @@ class CommandHandler(handler.MessageHandler):
)
)
self.ap.logger.info(f'命令({query.query_id})报错: {self.cut_str(str(ret.error))}')
self.ap.logger.info(
f'命令({query.query_id})报错: {self.cut_str(str(ret.error))}'
)
yield entities.StageProcessResult(
result_type=entities.ResultType.CONTINUE,
new_query=query
result_type=entities.ResultType.CONTINUE, new_query=query
)
elif ret.text is not None or ret.image_url is not None:
content: list[llm_entities.ContentElement]= []
content: list[llm_entities.ContentElement] = []
if ret.text is not None:
content.append(
llm_entities.ContentElement.from_text(ret.text)
)
content.append(llm_entities.ContentElement.from_text(ret.text))
if ret.image_url is not None:
content.append(
@@ -112,11 +109,9 @@ class CommandHandler(handler.MessageHandler):
self.ap.logger.info(f'命令返回: {self.cut_str(str(content[0]))}')
yield entities.StageProcessResult(
result_type=entities.ResultType.CONTINUE,
new_query=query
result_type=entities.ResultType.CONTINUE, new_query=query
)
else:
yield entities.StageProcessResult(
result_type=entities.ResultType.INTERRUPT,
new_query=query
result_type=entities.ResultType.INTERRUPT, new_query=query
)

View File

@@ -1,18 +1,16 @@
from __future__ import annotations
from ...core import app, entities as core_entities
from ...core import entities as core_entities
from . import handler
from .handlers import chat, command
from .. import entities
from .. import stage, entities
from ...core import entities as core_entities
from ...config import manager as cfg_mgr
from .. import stage
@stage.stage_class("MessageProcessor")
@stage.stage_class('MessageProcessor')
class Processor(stage.PipelineStage):
"""请求实际处理阶段
通过命令处理器和聊天处理器处理消息。
改写:
@@ -35,11 +33,12 @@ class Processor(stage.PipelineStage):
query: core_entities.Query,
stage_inst_name: str,
) -> entities.StageProcessResult:
"""处理
"""
"""处理"""
message_text = str(query.message_chain).strip()
self.ap.logger.info(f"处理 {query.launcher_type.value}_{query.launcher_id} 的请求({query.query_id}): {message_text}")
self.ap.logger.info(
f'处理 {query.launcher_type.value}_{query.launcher_id} 的请求({query.query_id}): {message_text}'
)
async def generator():
cmd_prefix = self.ap.instance_config.data['command']['prefix']
@@ -50,5 +49,5 @@ class Processor(stage.PipelineStage):
else:
async for result in self.chat_handler.handle(query):
yield result
return generator()

View File

@@ -7,19 +7,19 @@ from ...core import app, entities as core_entities
preregistered_algos: list[typing.Type[ReteLimitAlgo]] = []
def algo_class(name: str):
def decorator(cls: typing.Type[ReteLimitAlgo]) -> typing.Type[ReteLimitAlgo]:
cls.name = name
preregistered_algos.append(cls)
return cls
return decorator
class ReteLimitAlgo(metaclass=abc.ABCMeta):
"""限流算法抽象类"""
name: str = None
ap: app.Application
@@ -31,11 +31,16 @@ class ReteLimitAlgo(metaclass=abc.ABCMeta):
pass
@abc.abstractmethod
async def require_access(self, query: core_entities.Query, launcher_type: str, launcher_id: typing.Union[int, str]) -> bool:
async def require_access(
self,
query: core_entities.Query,
launcher_type: str,
launcher_id: typing.Union[int, str],
) -> bool:
"""进入处理流程
这个方法对等待是友好的,意味着算法可以实现在这里等待一段时间以控制速率。
Args:
launcher_type (str): 请求者类型 群聊为 group 私聊为 person
launcher_id (int): 请求者ID
@@ -44,15 +49,19 @@ class ReteLimitAlgo(metaclass=abc.ABCMeta):
bool: 是否允许进入处理流程若返回false则直接丢弃该请求
"""
raise NotImplementedError
@abc.abstractmethod
async def release_access(self, query: core_entities.Query, launcher_type: str, launcher_id: typing.Union[int, str]):
async def release_access(
self,
query: core_entities.Query,
launcher_type: str,
launcher_id: typing.Union[int, str],
):
"""退出处理流程
Args:
launcher_type (str): 请求者类型 群聊为 group 私聊为 person
launcher_id (int): 请求者ID
"""
raise NotImplementedError

View File

@@ -5,9 +5,9 @@ import typing
from .. import algo
from ....core import entities as core_entities
# 固定窗口算法
class SessionContainer:
wait_lock: asyncio.Lock
records: dict[int, int]
@@ -18,9 +18,8 @@ class SessionContainer:
self.records = {}
@algo.algo_class("fixwin")
@algo.algo_class('fixwin')
class FixedWindowAlgo(algo.ReteLimitAlgo):
containers_lock: asyncio.Lock
"""访问记录容器锁"""
@@ -31,7 +30,12 @@ class FixedWindowAlgo(algo.ReteLimitAlgo):
self.containers_lock = asyncio.Lock()
self.containers = {}
async def require_access(self, query: core_entities.Query, launcher_type: str, launcher_id: typing.Union[int, str]) -> bool:
async def require_access(
self,
query: core_entities.Query,
launcher_type: str,
launcher_id: typing.Union[int, str],
) -> bool:
# 加锁,找容器
container: SessionContainer = None
@@ -46,7 +50,6 @@ class FixedWindowAlgo(algo.ReteLimitAlgo):
# 等待锁
async with container.wait_lock:
# 获取窗口大小和限制
window_size = query.pipeline_config['safety']['rate-limit']['window-length']
limitation = query.pipeline_config['safety']['rate-limit']['limitation']
@@ -69,13 +72,15 @@ class FixedWindowAlgo(algo.ReteLimitAlgo):
if count >= limitation:
if query.pipeline_config['safety']['rate-limit']['strategy'] == 'drop':
return False
elif query.pipeline_config['safety']['rate-limit']['strategy'] == 'wait':
elif (
query.pipeline_config['safety']['rate-limit']['strategy'] == 'wait'
):
# 等待下一窗口
await asyncio.sleep(window_size - time.time() % window_size)
now = int(time.time())
now = now - now % window_size
if now not in container.records:
container.records = {}
container.records[now] = 1
@@ -85,6 +90,11 @@ class FixedWindowAlgo(algo.ReteLimitAlgo):
# 返回True
return True
async def release_access(self, query: core_entities.Query, launcher_type: str, launcher_id: typing.Union[int, str]):
async def release_access(
self,
query: core_entities.Query,
launcher_type: str,
launcher_id: typing.Union[int, str],
):
pass

View File

@@ -4,22 +4,25 @@ import typing
from .. import entities, stage
from . import algo
from .algos import fixedwin
from ...core import entities as core_entities
from ...utils import importutil
from . import algos
importutil.import_modules_in_pkg(algos)
@stage.stage_class("RequireRateLimitOccupancy")
@stage.stage_class("ReleaseRateLimitOccupancy")
@stage.stage_class('RequireRateLimitOccupancy')
@stage.stage_class('ReleaseRateLimitOccupancy')
class RateLimit(stage.PipelineStage):
"""限速器控制阶段
不改写query只检查是否需要限速。
"""
algo: algo.ReteLimitAlgo
async def initialize(self, pipeline_config: dict):
algo_name = 'fixwin'
algo_class = None
@@ -42,9 +45,8 @@ class RateLimit(stage.PipelineStage):
entities.StageProcessResult,
typing.AsyncGenerator[entities.StageProcessResult, None],
]:
"""处理
"""
if stage_inst_name == "RequireRateLimitOccupancy":
"""处理"""
if stage_inst_name == 'RequireRateLimitOccupancy':
if await self.algo.require_access(
query,
query.launcher_type.value,
@@ -58,10 +60,10 @@ class RateLimit(stage.PipelineStage):
return entities.StageProcessResult(
result_type=entities.ResultType.INTERRUPT,
new_query=query,
console_notice=f"根据限速规则忽略 {query.launcher_type.value}:{query.launcher_id} 消息",
user_notice=f"请求数超过限速器设定值,已丢弃本消息。"
console_notice=f'根据限速规则忽略 {query.launcher_type.value}:{query.launcher_id} 消息',
user_notice='请求数超过限速器设定值,已丢弃本消息。',
)
elif stage_inst_name == "ReleaseRateLimitOccupancy":
elif stage_inst_name == 'ReleaseRateLimitOccupancy':
await self.algo.release_access(
query,
query.launcher_type.value,

View File

@@ -4,41 +4,38 @@ import random
import asyncio
from ...core import app
from ...platform.types import events as platform_events
from ...platform.types import message as platform_message
from .. import stage, entities
from ...core import entities as core_entities
from ...config import manager as cfg_mgr
@stage.stage_class("SendResponseBackStage")
@stage.stage_class('SendResponseBackStage')
class SendResponseBackStage(stage.PipelineStage):
"""发送响应消息
"""
"""发送响应消息"""
async def process(self, query: core_entities.Query, stage_inst_name: str) -> entities.StageProcessResult:
"""处理
"""
async def process(
self, query: core_entities.Query, stage_inst_name: str
) -> entities.StageProcessResult:
"""处理"""
random_range = (query.pipeline_config['output']['force-delay']['min'], query.pipeline_config['output']['force-delay']['max'])
random_range = (
query.pipeline_config['output']['force-delay']['min'],
query.pipeline_config['output']['force-delay']['max'],
)
random_delay = random.uniform(*random_range)
self.ap.logger.debug(
"根据规则强制延迟回复: %s s",
random_delay
)
self.ap.logger.debug('根据规则强制延迟回复: %s s', random_delay)
await asyncio.sleep(random_delay)
if query.pipeline_config['output']['misc']['at-sender'] and isinstance(query.message_event, platform_events.GroupMessage):
if query.pipeline_config['output']['misc']['at-sender'] and isinstance(
query.message_event, platform_events.GroupMessage
):
query.resp_message_chain[-1].insert(
0,
platform_message.At(
query.message_event.sender.id
)
0, platform_message.At(query.message_event.sender.id)
)
quote_origin = query.pipeline_config['output']['misc']['quote-origin']
@@ -46,10 +43,9 @@ class SendResponseBackStage(stage.PipelineStage):
await query.adapter.reply_message(
message_source=query.message_event,
message=query.resp_message_chain[-1],
quote_origin=quote_origin
quote_origin=quote_origin,
)
return entities.StageProcessResult(
result_type=entities.ResultType.CONTINUE,
new_query=query
)
result_type=entities.ResultType.CONTINUE, new_query=query
)

View File

@@ -4,7 +4,6 @@ from ...platform.types import message as platform_message
class RuleJudgeResult(pydantic.BaseModel):
matching: bool = False
replacement: platform_message.MessageChain = None

View File

@@ -1,16 +1,18 @@
from __future__ import annotations
from ...core import app
from . import entities as rule_entities, rule
from .rules import atbot, prefix, regexp, random
from . import rule
from .. import stage, entities
from ...core import entities as core_entities
from ...config import manager as cfg_mgr
from ...utils import importutil
from . import rules
importutil.import_modules_in_pkg(rules)
@stage.stage_class("GroupRespondRuleCheckStage")
@stage.stage_class('GroupRespondRuleCheckStage')
class GroupRespondRuleCheckStage(stage.PipelineStage):
"""群组响应规则检查器
@@ -21,8 +23,7 @@ class GroupRespondRuleCheckStage(stage.PipelineStage):
"""检查器实例"""
async def initialize(self, pipeline_config: dict):
"""初始化检查器
"""
"""初始化检查器"""
self.rule_matchers = []
@@ -31,12 +32,12 @@ class GroupRespondRuleCheckStage(stage.PipelineStage):
await rule_inst.initialize()
self.rule_matchers.append(rule_inst)
async def process(self, query: core_entities.Query, stage_inst_name: str) -> entities.StageProcessResult:
async def process(
self, query: core_entities.Query, stage_inst_name: str
) -> entities.StageProcessResult:
if query.launcher_type.value != 'group': # 只处理群消息
return entities.StageProcessResult(
result_type=entities.ResultType.CONTINUE,
new_query=query
result_type=entities.ResultType.CONTINUE, new_query=query
)
rules = query.pipeline_config['trigger']['group-respond-rules']
@@ -48,7 +49,9 @@ class GroupRespondRuleCheckStage(stage.PipelineStage):
# use_rule = rules[str(query.launcher_id)]
for rule_matcher in self.rule_matchers: # 任意一个匹配就放行
res = await rule_matcher.match(str(query.message_chain), query.message_chain, use_rule, query)
res = await rule_matcher.match(
str(query.message_chain), query.message_chain, use_rule, query
)
if res.matching:
query.message_chain = res.replacement
@@ -56,8 +59,7 @@ class GroupRespondRuleCheckStage(stage.PipelineStage):
result_type=entities.ResultType.CONTINUE,
new_query=query,
)
return entities.StageProcessResult(
result_type=entities.ResultType.INTERRUPT,
new_query=query
result_type=entities.ResultType.INTERRUPT, new_query=query
)

View File

@@ -10,17 +10,19 @@ from ...platform.types import message as platform_message
preregisetered_rules: list[typing.Type[GroupRespondRule]] = []
def rule_class(name: str):
def decorator(cls: typing.Type[GroupRespondRule]) -> typing.Type[GroupRespondRule]:
cls.name = name
preregisetered_rules.append(cls)
return cls
return decorator
class GroupRespondRule(metaclass=abc.ABCMeta):
"""群组响应规则的抽象类
"""
"""群组响应规则的抽象类"""
name: str
ap: app.Application
@@ -37,8 +39,7 @@ class GroupRespondRule(metaclass=abc.ABCMeta):
message_text: str,
message_chain: platform_message.MessageChain,
rule_dict: dict,
query: core_entities.Query
query: core_entities.Query,
) -> entities.RuleJudgeResult:
"""判断消息是否匹配规则
"""
"""判断消息是否匹配规则"""
raise NotImplementedError

View File

@@ -7,21 +7,24 @@ from ....core import entities as core_entities
from ....platform.types import message as platform_message
@rule_model.rule_class("at-bot")
@rule_model.rule_class('at-bot')
class AtBotRule(rule_model.GroupRespondRule):
async def match(
self,
message_text: str,
message_chain: platform_message.MessageChain,
rule_dict: dict,
query: core_entities.Query
query: core_entities.Query,
) -> entities.RuleJudgeResult:
if message_chain.has(platform_message.At(query.adapter.bot_account_id)) and rule_dict['at']:
if (
message_chain.has(platform_message.At(query.adapter.bot_account_id))
and rule_dict['at']
):
message_chain.remove(platform_message.At(query.adapter.bot_account_id))
if message_chain.has(platform_message.At(query.adapter.bot_account_id)): # 回复消息时会at两次检查并删除重复的
if message_chain.has(
platform_message.At(query.adapter.bot_account_id)
): # 回复消息时会at两次检查并删除重复的
message_chain.remove(platform_message.At(query.adapter.bot_account_id))
return entities.RuleJudgeResult(
@@ -29,7 +32,4 @@ class AtBotRule(rule_model.GroupRespondRule):
replacement=message_chain,
)
return entities.RuleJudgeResult(
matching=False,
replacement = message_chain
)
return entities.RuleJudgeResult(matching=False, replacement=message_chain)

View File

@@ -1,36 +1,30 @@
from .. import rule as rule_model
from .. import entities
from ....core import entities as core_entities
from ....platform.types import message as platform_message
@rule_model.rule_class("prefix")
@rule_model.rule_class('prefix')
class PrefixRule(rule_model.GroupRespondRule):
async def match(
self,
message_text: str,
message_chain: platform_message.MessageChain,
rule_dict: dict,
query: core_entities.Query
query: core_entities.Query,
) -> entities.RuleJudgeResult:
prefixes = rule_dict['prefix']
for prefix in prefixes:
if message_text.startswith(prefix):
# 查找第一个plain元素
for me in message_chain:
if isinstance(me, platform_message.Plain):
me.text = me.text[len(prefix):]
me.text = me.text[len(prefix) :]
return entities.RuleJudgeResult(
matching=True,
replacement=message_chain,
)
return entities.RuleJudgeResult(
matching=False,
replacement=message_chain
)
return entities.RuleJudgeResult(matching=False, replacement=message_chain)

View File

@@ -7,19 +7,17 @@ from ....core import entities as core_entities
from ....platform.types import message as platform_message
@rule_model.rule_class("random")
@rule_model.rule_class('random')
class RandomRespRule(rule_model.GroupRespondRule):
async def match(
self,
message_text: str,
message_chain: platform_message.MessageChain,
rule_dict: dict,
query: core_entities.Query
query: core_entities.Query,
) -> entities.RuleJudgeResult:
random_rate = rule_dict['random']
return entities.RuleJudgeResult(
matching=random.random() < random_rate,
replacement=message_chain
)
matching=random.random() < random_rate, replacement=message_chain
)

View File

@@ -7,15 +7,14 @@ from ....core import entities as core_entities
from ....platform.types import message as platform_message
@rule_model.rule_class("regexp")
@rule_model.rule_class('regexp')
class RegExpRule(rule_model.GroupRespondRule):
async def match(
self,
message_text: str,
message_chain: platform_message.MessageChain,
rule_dict: dict,
query: core_entities.Query
query: core_entities.Query,
) -> entities.RuleJudgeResult:
regexps = rule_dict['regexp']
@@ -27,8 +26,5 @@ class RegExpRule(rule_model.GroupRespondRule):
matching=True,
replacement=message_chain,
)
return entities.RuleJudgeResult(
matching=False,
replacement=message_chain
)
return entities.RuleJudgeResult(matching=False, replacement=message_chain)

View File

@@ -11,17 +11,15 @@ preregistered_stages: dict[str, PipelineStage] = {}
def stage_class(name: str):
def decorator(cls):
preregistered_stages[name] = cls
return cls
return decorator
class PipelineStage(metaclass=abc.ABCMeta):
"""流水线阶段
"""
"""流水线阶段"""
ap: app.Application
@@ -29,8 +27,7 @@ class PipelineStage(metaclass=abc.ABCMeta):
self.ap = ap
async def initialize(self, pipeline_config: dict):
"""初始化
"""
"""初始化"""
pass
@abc.abstractmethod
@@ -42,6 +39,5 @@ class PipelineStage(metaclass=abc.ABCMeta):
entities.StageProcessResult,
typing.AsyncGenerator[entities.StageProcessResult, None],
]:
"""处理
"""
"""处理"""
raise NotImplementedError

View File

@@ -3,21 +3,19 @@ from __future__ import annotations
import typing
from ...core import app, entities as core_entities
from .. import entities
from .. import stage, entities
from ...core import entities as core_entities
from ...config import manager as cfg_mgr
from .. import entities
from .. import stage
from ...plugin import events
from ...platform.types import message as platform_message
@stage.stage_class("ResponseWrapper")
@stage.stage_class('ResponseWrapper')
class ResponseWrapper(stage.PipelineStage):
"""回复包装阶段
把回复的 message 包装成人类识读的形式。
改写:
- resp_message_chain
"""
@@ -30,36 +28,36 @@ class ResponseWrapper(stage.PipelineStage):
query: core_entities.Query,
stage_inst_name: str,
) -> typing.AsyncGenerator[entities.StageProcessResult, None]:
"""处理
"""
"""处理"""
# 如果 resp_messages[-1] 已经是 MessageChain 了
if isinstance(query.resp_messages[-1], platform_message.MessageChain):
query.resp_message_chain.append(query.resp_messages[-1])
yield entities.StageProcessResult(
result_type=entities.ResultType.CONTINUE,
new_query=query
result_type=entities.ResultType.CONTINUE, new_query=query
)
else:
if query.resp_messages[-1].role == 'command':
query.resp_message_chain.append(query.resp_messages[-1].get_content_platform_message_chain(prefix_text='[bot] '))
query.resp_message_chain.append(
query.resp_messages[-1].get_content_platform_message_chain(
prefix_text='[bot] '
)
)
yield entities.StageProcessResult(
result_type=entities.ResultType.CONTINUE,
new_query=query
result_type=entities.ResultType.CONTINUE, new_query=query
)
elif query.resp_messages[-1].role == 'plugin':
query.resp_message_chain.append(query.resp_messages[-1].get_content_platform_message_chain())
query.resp_message_chain.append(
query.resp_messages[-1].get_content_platform_message_chain()
)
yield entities.StageProcessResult(
result_type=entities.ResultType.CONTINUE,
new_query=query
result_type=entities.ResultType.CONTINUE, new_query=query
)
else:
if query.resp_messages[-1].role == 'assistant':
result = query.resp_messages[-1]
session = await self.ap.sess_mgr.get_session(query)
@@ -79,39 +77,51 @@ class ResponseWrapper(stage.PipelineStage):
prefix='',
response_text=reply_text,
finish_reason='stop',
funcs_called=[fc.function.name for fc in result.tool_calls] if result.tool_calls is not None else [],
query=query
funcs_called=[
fc.function.name for fc in result.tool_calls
]
if result.tool_calls is not None
else [],
query=query,
)
)
if event_ctx.is_prevented_default():
yield entities.StageProcessResult(
result_type=entities.ResultType.INTERRUPT,
new_query=query
new_query=query,
)
else:
if event_ctx.event.reply is not None:
query.resp_message_chain.append(platform_message.MessageChain(event_ctx.event.reply))
query.resp_message_chain.append(
platform_message.MessageChain(event_ctx.event.reply)
)
else:
query.resp_message_chain.append(result.get_content_platform_message_chain())
query.resp_message_chain.append(
result.get_content_platform_message_chain()
)
yield entities.StageProcessResult(
result_type=entities.ResultType.CONTINUE,
new_query=query
new_query=query,
)
if result.tool_calls is not None and len(result.tool_calls) > 0: # 有函数调用
if (
result.tool_calls is not None and len(result.tool_calls) > 0
): # 有函数调用
function_names = [tc.function.name for tc in result.tool_calls]
reply_text = f'调用函数 {".".join(function_names)}...'
query.resp_message_chain.append(platform_message.MessageChain([platform_message.Plain(reply_text)]))
query.resp_message_chain.append(
platform_message.MessageChain(
[platform_message.Plain(reply_text)]
)
)
if query.pipeline_config['output']['misc']['track-function-calls']:
if query.pipeline_config['output']['misc'][
'track-function-calls'
]:
event_ctx = await self.ap.plugin_mgr.emit_event(
event=events.NormalMessageResponded(
launcher_type=query.launcher_type.value,
@@ -121,26 +131,36 @@ class ResponseWrapper(stage.PipelineStage):
prefix='',
response_text=reply_text,
finish_reason='stop',
funcs_called=[fc.function.name for fc in result.tool_calls] if result.tool_calls is not None else [],
query=query
funcs_called=[
fc.function.name for fc in result.tool_calls
]
if result.tool_calls is not None
else [],
query=query,
)
)
if event_ctx.is_prevented_default():
yield entities.StageProcessResult(
result_type=entities.ResultType.INTERRUPT,
new_query=query
new_query=query,
)
else:
if event_ctx.event.reply is not None:
query.resp_message_chain.append(platform_message.MessageChain(event_ctx.event.reply))
query.resp_message_chain.append(
platform_message.MessageChain(
event_ctx.event.reply
)
)
else:
query.resp_message_chain.append(platform_message.MessageChain([platform_message.Plain(reply_text)]))
query.resp_message_chain.append(
platform_message.MessageChain(
[platform_message.Plain(reply_text)]
)
)
yield entities.StageProcessResult(
result_type=entities.ResultType.CONTINUE,
new_query=query
new_query=query,
)