style: restrict line-length

This commit is contained in:
Junyan Qin
2025-05-10 18:04:58 +08:00
parent b30016ed08
commit 055b389353
134 changed files with 1096 additions and 2595 deletions
+3 -9
View File
@@ -14,9 +14,7 @@ class BanSessionCheckStage(stage.PipelineStage):
async def initialize(self, pipeline_config: dict):
pass
async def process(
self, query: core_entities.Query, stage_inst_name: str
) -> entities.StageProcessResult:
async def process(self, query: core_entities.Query, stage_inst_name: str) -> entities.StageProcessResult:
found = False
mode = query.pipeline_config['trigger']['access-control']['mode']
@@ -41,11 +39,7 @@ class BanSessionCheckStage(stage.PipelineStage):
ctn = not found
return entities.StageProcessResult(
result_type=entities.ResultType.CONTINUE
if ctn
else entities.ResultType.INTERRUPT,
result_type=entities.ResultType.CONTINUE if ctn else entities.ResultType.INTERRUPT,
new_query=query,
console_notice=f'根据访问控制忽略消息: {query.launcher_type.value}_{query.launcher_id}'
if not ctn
else '',
console_notice=f'根据访问控制忽略消息: {query.launcher_type.value}_{query.launcher_id}' if not ctn else '',
)
+8 -24
View File
@@ -65,9 +65,7 @@ class ContentFilterStage(stage.PipelineStage):
"""
if query.pipeline_config['safety']['content-filter']['scope'] == 'output-msg':
return entities.StageProcessResult(
result_type=entities.ResultType.CONTINUE, new_query=query
)
return entities.StageProcessResult(result_type=entities.ResultType.CONTINUE, new_query=query)
else:
for filter in self.filter_chain:
if filter_entities.EnableStage.PRE in filter.enable_stages:
@@ -86,13 +84,9 @@ class ContentFilterStage(stage.PipelineStage):
elif result.level == filter_entities.ResultLevel.PASS: # 传到下一个
message = result.replacement
query.message_chain = platform_message.MessageChain(
platform_message.Plain(message)
)
query.message_chain = platform_message.MessageChain(platform_message.Plain(message))
return entities.StageProcessResult(
result_type=entities.ResultType.CONTINUE, new_query=query
)
return entities.StageProcessResult(result_type=entities.ResultType.CONTINUE, new_query=query)
async def _post_process(
self,
@@ -103,9 +97,7 @@ class ContentFilterStage(stage.PipelineStage):
只要是 PASS 或者 MASKED 的就通过此 filter,将其 replacement 设置为message,进入下一个 filter
"""
if query.pipeline_config['safety']['content-filter']['scope'] == 'income-msg':
return entities.StageProcessResult(
result_type=entities.ResultType.CONTINUE, new_query=query
)
return entities.StageProcessResult(result_type=entities.ResultType.CONTINUE, new_query=query)
else:
message = message.strip()
for filter in self.filter_chain:
@@ -127,13 +119,9 @@ class ContentFilterStage(stage.PipelineStage):
query.resp_messages[-1].content = message
return entities.StageProcessResult(
result_type=entities.ResultType.CONTINUE, new_query=query
)
return entities.StageProcessResult(result_type=entities.ResultType.CONTINUE, new_query=query)
async def process(
self, query: core_entities.Query, stage_inst_name: str
) -> entities.StageProcessResult:
async def process(self, query: core_entities.Query, stage_inst_name: str) -> entities.StageProcessResult:
"""处理"""
if stage_inst_name == 'PreContentFilterStage':
contain_non_text = False
@@ -147,9 +135,7 @@ class ContentFilterStage(stage.PipelineStage):
if contain_non_text:
self.ap.logger.debug('消息中包含非文本消息,跳过内容过滤器检查。')
return entities.StageProcessResult(
result_type=entities.ResultType.CONTINUE, new_query=query
)
return entities.StageProcessResult(result_type=entities.ResultType.CONTINUE, new_query=query)
return await self._pre_process(str(query.message_chain).strip(), query)
elif stage_inst_name == 'PostContentFilterStage':
@@ -162,8 +148,6 @@ class ContentFilterStage(stage.PipelineStage):
self.ap.logger.debug(
'resp_messages[-1] 不是 Message 类型或 query.resp_messages[-1].content 不是 str 类型,跳过内容过滤器检查。'
)
return entities.StageProcessResult(
result_type=entities.ResultType.CONTINUE, new_query=query
)
return entities.StageProcessResult(result_type=entities.ResultType.CONTINUE, new_query=query)
else:
raise ValueError(f'未知的 stage_inst_name: {stage_inst_name}')
+1 -3
View File
@@ -60,9 +60,7 @@ class ContentFilter(metaclass=abc.ABCMeta):
pass
@abc.abstractmethod
async def process(
self, query: core_entities.Query, message: str = None, image_url=None
) -> entities.FilterResult:
async def process(self, query: core_entities.Query, message: str = None, image_url=None) -> entities.FilterResult:
"""处理消息
分为前后阶段,具体取决于 enable_stages 的值。
@@ -21,19 +21,13 @@ class BaiduCloudExamine(filter_model.ContentFilter):
BAIDU_EXAMINE_TOKEN_URL,
params={
'grant_type': 'client_credentials',
'client_id': self.ap.pipeline_cfg.data['baidu-cloud-examine'][
'api-key'
],
'client_secret': self.ap.pipeline_cfg.data['baidu-cloud-examine'][
'api-secret'
],
'client_id': self.ap.pipeline_cfg.data['baidu-cloud-examine']['api-key'],
'client_secret': self.ap.pipeline_cfg.data['baidu-cloud-examine']['api-secret'],
},
) as resp:
return (await resp.json())['access_token']
async def process(
self, query: core_entities.Query, message: str
) -> entities.FilterResult:
async def process(self, query: core_entities.Query, message: str) -> entities.FilterResult:
async with aiohttp.ClientSession() as session:
async with session.post(
BAIDU_EXAMINE_URL.format(await self._get_token()),
+2 -6
View File
@@ -13,9 +13,7 @@ class BanWordFilter(filter_model.ContentFilter):
async def initialize(self):
pass
async def process(
self, query: core_entities.Query, message: str
) -> entities.FilterResult:
async def process(self, query: core_entities.Query, message: str) -> entities.FilterResult:
found = False
for word in self.ap.sensitive_meta.data['words']:
@@ -31,9 +29,7 @@ class BanWordFilter(filter_model.ContentFilter):
self.ap.sensitive_meta.data['mask'] * len(match[i]),
)
else:
message = message.replace(
match[i], self.ap.sensitive_meta.data['mask_word']
)
message = message.replace(match[i], self.ap.sensitive_meta.data['mask_word'])
return entities.FilterResult(
level=entities.ResultLevel.MASKED if found else entities.ResultLevel.PASS,
+1 -3
View File
@@ -16,9 +16,7 @@ class ContentIgnore(filter_model.ContentFilter):
entities.EnableStage.PRE,
]
async def process(
self, query: core_entities.Query, message: str
) -> entities.FilterResult:
async def process(self, query: core_entities.Query, message: str) -> entities.FilterResult:
if 'prefix' in query.pipeline_config['trigger']['ignore-rules']:
for rule in query.pipeline_config['trigger']['ignore-rules']['prefix']:
if message.startswith(rule):
+6 -16
View File
@@ -16,9 +16,7 @@ class Controller:
def __init__(self, ap: app.Application):
self.ap = ap
self.semaphore = asyncio.Semaphore(
self.ap.instance_config.data['concurrency']['pipeline']
)
self.semaphore = asyncio.Semaphore(self.ap.instance_config.data['concurrency']['pipeline'])
async def consumer(self):
"""事件处理循环"""
@@ -32,9 +30,7 @@ class Controller:
for query in queries:
session = await self.ap.sess_mgr.get_session(query)
self.ap.logger.debug(
f'Checking query {query} session {session}'
)
self.ap.logger.debug(f'Checking query {query} session {session}')
if not session.semaphore.locked():
selected_query = query
@@ -55,22 +51,16 @@ class Controller:
# find pipeline
# Here firstly find the bot, then find the pipeline, in case the bot adapter's config is not the latest one.
# Like aiocqhttp, once a client is connected, even the adapter was updated and restarted, the existing client connection will not be affected.
bot = await self.ap.platform_mgr.get_bot_by_uuid(
selected_query.bot_uuid
)
bot = await self.ap.platform_mgr.get_bot_by_uuid(selected_query.bot_uuid)
if bot:
pipeline = (
await self.ap.pipeline_mgr.get_pipeline_by_uuid(
bot.bot_entity.use_pipeline_uuid
)
pipeline = await self.ap.pipeline_mgr.get_pipeline_by_uuid(
bot.bot_entity.use_pipeline_uuid
)
if pipeline:
await pipeline.run(selected_query)
async with self.ap.query_pool:
(
await self.ap.sess_mgr.get_session(selected_query)
).semaphore.release()
(await self.ap.sess_mgr.get_session(selected_query)).semaphore.release()
# 通知其他协程,有新的请求可以处理了
self.ap.query_pool.condition.notify_all()
+5 -15
View File
@@ -47,9 +47,7 @@ class LongTextProcessStage(stage.PipelineStage):
'未找到字体文件,且无法使用系统自带字体,更换为转发消息组件以发送长消息,您可以在配置文件中调整相关设置。'
)
pipeline_config['output']['long-text-processing'][
'strategy'
] = 'forward'
pipeline_config['output']['long-text-processing']['strategy'] = 'forward'
except Exception:
traceback.print_exc()
self.ap.logger.error(
@@ -58,9 +56,7 @@ class LongTextProcessStage(stage.PipelineStage):
)
)
pipeline_config['output']['long-text-processing']['strategy'] = (
'forward'
)
pipeline_config['output']['long-text-processing']['strategy'] = 'forward'
for strategy_cls in strategy.preregistered_strategies:
if strategy_cls.name == config['strategy']:
@@ -71,9 +67,7 @@ class LongTextProcessStage(stage.PipelineStage):
await self.strategy_impl.initialize()
async def process(
self, query: core_entities.Query, stage_inst_name: str
) -> entities.StageProcessResult:
async def process(self, query: core_entities.Query, stage_inst_name: str) -> entities.StageProcessResult:
# 检查是否包含非 Plain 组件
contains_non_plain = False
@@ -89,11 +83,7 @@ class LongTextProcessStage(stage.PipelineStage):
> query.pipeline_config['output']['long-text-processing']['threshold']
):
query.resp_message_chain[-1] = platform_message.MessageChain(
await self.strategy_impl.process(
str(query.resp_message_chain[-1]), query
)
await self.strategy_impl.process(str(query.resp_message_chain[-1]), query)
)
return entities.StageProcessResult(
result_type=entities.ResultType.CONTINUE, new_query=query
)
return entities.StageProcessResult(result_type=entities.ResultType.CONTINUE, new_query=query)
+1 -3
View File
@@ -13,9 +13,7 @@ Forward = platform_message.Forward
@strategy_model.strategy_class('forward')
class ForwardComponentStrategy(strategy_model.LongTextStrategy):
async def process(
self, message: str, query: core_entities.Query
) -> list[platform_message.MessageComponent]:
async def process(self, message: str, query: core_entities.Query) -> list[platform_message.MessageComponent]:
display = ForwardMessageDiaplay(
title='群聊的聊天记录',
brief='[聊天记录]',
+4 -13
View File
@@ -27,18 +27,14 @@ class Text2ImageStrategy(strategy_model.LongTextStrategy):
encoding='utf-8',
)
async def process(
self, message: str, query: core_entities.Query
) -> list[platform_message.MessageComponent]:
async def process(self, message: str, query: core_entities.Query) -> list[platform_message.MessageComponent]:
img_path = self.text_to_image(
text_str=message,
save_as='temp/{}.png'.format(int(time.time())),
query=query,
)
compressed_path, size = self.compress_image(
img_path, outfile='temp/{}_compressed.png'.format(int(time.time()))
)
compressed_path, size = self.compress_image(img_path, outfile='temp/{}_compressed.png'.format(int(time.time())))
with open(compressed_path, 'rb') as f:
img = f.read()
@@ -165,10 +161,7 @@ class Text2ImageStrategy(strategy_model.LongTextStrategy):
numbers = self.indexNumber(rest_text)
for number in numbers:
if (
number[1] < point < number[1] + len(number[0])
and number[1] != 0
):
if number[1] < point < number[1] + len(number[0]) and number[1] != 0:
point = number[1]
break
@@ -181,9 +174,7 @@ class Text2ImageStrategy(strategy_model.LongTextStrategy):
else:
continue
# 准备画布
img = Image.new(
'RGBA', (width, max(280, len(final_lines) * 35 + 65)), (255, 255, 255, 255)
)
img = Image.new('RGBA', (width, max(280, len(final_lines) * 35 + 65)), (255, 255, 255, 255))
draw = ImageDraw.Draw(img, mode='RGBA')
self.ap.logger.debug('正在绘制图片...')
+1 -3
View File
@@ -49,9 +49,7 @@ class LongTextStrategy(metaclass=abc.ABCMeta):
pass
@abc.abstractmethod
async def process(
self, message: str, query: core_entities.Query
) -> list[platform_message.MessageComponent]:
async def process(self, message: str, query: core_entities.Query) -> list[platform_message.MessageComponent]:
"""处理长文本
在 platform.json 中配置 long-text-process 字段,只要 文本长度超过了 threshold 就会调用此方法
+2 -6
View File
@@ -29,12 +29,8 @@ class ConversationMessageTruncator(stage.PipelineStage):
else:
raise ValueError(f'未知的截断器: {use_method}')
async def process(
self, query: core_entities.Query, stage_inst_name: str
) -> entities.StageProcessResult:
async def process(self, query: core_entities.Query, stage_inst_name: str) -> entities.StageProcessResult:
"""处理"""
query = await self.trun.truncate(query)
return entities.StageProcessResult(
result_type=entities.ResultType.CONTINUE, new_query=query
)
return entities.StageProcessResult(result_type=entities.ResultType.CONTINUE, new_query=query)
+15 -47
View File
@@ -79,26 +79,20 @@ class RuntimePipeline:
query.pipeline_config = self.pipeline_entity.config
await self.process_query(query)
async def _check_output(
self, query: entities.Query, result: pipeline_entities.StageProcessResult
):
async def _check_output(self, query: entities.Query, result: pipeline_entities.StageProcessResult):
"""检查输出"""
if result.user_notice:
# 处理str类型
if isinstance(result.user_notice, str):
result.user_notice = platform_message.MessageChain(
platform_message.Plain(result.user_notice)
)
result.user_notice = platform_message.MessageChain(platform_message.Plain(result.user_notice))
elif isinstance(result.user_notice, list):
result.user_notice = platform_message.MessageChain(*result.user_notice)
if query.pipeline_config['output']['misc']['at-sender'] and isinstance(
query.message_event, platform_events.GroupMessage
):
result.user_notice.insert(
0, platform_message.At(query.message_event.sender.id)
)
result.user_notice.insert(0, platform_message.At(query.message_event.sender.id))
await query.adapter.reply_message(
message_source=query.message_event,
@@ -150,37 +144,25 @@ class RuntimePipeline:
result = await result
if isinstance(result, pipeline_entities.StageProcessResult): # 直接返回结果
self.ap.logger.debug(
f'Stage {stage_container.inst_name} processed query {query} res {result}'
)
self.ap.logger.debug(f'Stage {stage_container.inst_name} processed query {query} res {result}')
await self._check_output(query, result)
if result.result_type == pipeline_entities.ResultType.INTERRUPT:
self.ap.logger.debug(
f'Stage {stage_container.inst_name} interrupted query {query}'
)
self.ap.logger.debug(f'Stage {stage_container.inst_name} interrupted query {query}')
break
elif result.result_type == pipeline_entities.ResultType.CONTINUE:
query = result.new_query
elif isinstance(result, typing.AsyncGenerator): # 生成器
self.ap.logger.debug(
f'Stage {stage_container.inst_name} processed query {query} gen'
)
self.ap.logger.debug(f'Stage {stage_container.inst_name} processed query {query} gen')
async for sub_result in result:
self.ap.logger.debug(
f'Stage {stage_container.inst_name} processed query {query} res {sub_result}'
)
self.ap.logger.debug(f'Stage {stage_container.inst_name} processed query {query} res {sub_result}')
await self._check_output(query, sub_result)
if sub_result.result_type == pipeline_entities.ResultType.INTERRUPT:
self.ap.logger.debug(
f'Stage {stage_container.inst_name} interrupted query {query}'
)
self.ap.logger.debug(f'Stage {stage_container.inst_name} interrupted query {query}')
break
elif (
sub_result.result_type == pipeline_entities.ResultType.CONTINUE
):
elif sub_result.result_type == pipeline_entities.ResultType.CONTINUE:
query = sub_result.new_query
await self._execute_from_stage(i + 1, query)
break
@@ -214,12 +196,8 @@ class RuntimePipeline:
await self._execute_from_stage(0, query)
except Exception as e:
inst_name = (
query.current_stage.inst_name if query.current_stage else 'unknown'
)
self.ap.logger.error(
f'处理请求时出错 query_id={query.query_id} stage={inst_name} : {e}'
)
inst_name = query.current_stage.inst_name if query.current_stage else 'unknown'
self.ap.logger.error(f'处理请求时出错 query_id={query.query_id} stage={inst_name} : {e}')
self.ap.logger.debug(f'Traceback: {traceback.format_exc()}')
finally:
self.ap.logger.debug(f'Query {query} processed')
@@ -241,18 +219,14 @@ class PipelineManager:
self.pipelines = []
async def initialize(self):
self.stage_dict = {
name: cls for name, cls in stage.preregistered_stages.items()
}
self.stage_dict = {name: cls for name, cls in stage.preregistered_stages.items()}
await self.load_pipelines_from_db()
async def load_pipelines_from_db(self):
self.ap.logger.info('Loading pipelines from db...')
result = await self.ap.persistence_mgr.execute_async(
sqlalchemy.select(persistence_pipeline.LegacyPipeline)
)
result = await self.ap.persistence_mgr.execute_async(sqlalchemy.select(persistence_pipeline.LegacyPipeline))
pipelines = result.all()
@@ -267,20 +241,14 @@ class PipelineManager:
| dict,
):
if isinstance(pipeline_entity, sqlalchemy.Row):
pipeline_entity = persistence_pipeline.LegacyPipeline(
**pipeline_entity._mapping
)
pipeline_entity = persistence_pipeline.LegacyPipeline(**pipeline_entity._mapping)
elif isinstance(pipeline_entity, dict):
pipeline_entity = persistence_pipeline.LegacyPipeline(**pipeline_entity)
# initialize stage containers according to pipeline_entity.stages
stage_containers: list[StageInstContainer] = []
for stage_name in pipeline_entity.stages:
stage_containers.append(
StageInstContainer(
inst_name=stage_name, inst=self.stage_dict[stage_name](self.ap)
)
)
stage_containers.append(StageInstContainer(inst_name=stage_name, inst=self.stage_dict[stage_name](self.ap)))
for stage_container in stage_containers:
await stage_container.inst.initialize(pipeline_entity.config)
+9 -17
View File
@@ -44,9 +44,7 @@ class PreProcessor(stage.PipelineStage):
query.use_llm_model = conversation.use_llm_model
query.use_funcs = (
conversation.use_funcs
if query.use_llm_model.model_entity.abilities.__contains__('tool_call')
else None
conversation.use_funcs if query.use_llm_model.model_entity.abilities.__contains__('tool_call') else None
)
query.variables = {
@@ -59,10 +57,9 @@ class PreProcessor(stage.PipelineStage):
# Check if this model supports vision, if not, remove all images
# TODO this checking should be performed in runner, and in this stage, the image should be reserved
if (
query.pipeline_config['ai']['runner']['runner'] == 'local-agent'
and not query.use_llm_model.model_entity.abilities.__contains__('vision')
):
if query.pipeline_config['ai']['runner'][
'runner'
] == 'local-agent' and not query.use_llm_model.model_entity.abilities.__contains__('vision'):
for msg in query.messages:
if isinstance(msg.content, list):
for me in msg.content:
@@ -78,14 +75,11 @@ class PreProcessor(stage.PipelineStage):
content_list.append(llm_entities.ContentElement.from_text(me.text))
plain_text += me.text
elif isinstance(me, platform_message.Image):
if (
query.pipeline_config['ai']['runner']['runner'] != 'local-agent'
or query.use_llm_model.model_entity.abilities.__contains__('vision')
):
if query.pipeline_config['ai']['runner'][
'runner'
] != 'local-agent' or query.use_llm_model.model_entity.abilities.__contains__('vision'):
if me.base64 is not None:
content_list.append(
llm_entities.ContentElement.from_image_base64(me.base64)
)
content_list.append(llm_entities.ContentElement.from_image_base64(me.base64))
query.variables['user_message_text'] = plain_text
@@ -104,6 +98,4 @@ class PreProcessor(stage.PipelineStage):
query.prompt.messages = event_ctx.event.default_prompt
query.messages = event_ctx.event.prompt
return entities.StageProcessResult(
result_type=entities.ResultType.CONTINUE, new_query=query
)
return entities.StageProcessResult(result_type=entities.ResultType.CONTINUE, new_query=query)
+7 -21
View File
@@ -49,13 +49,9 @@ class ChatMessageHandler(handler.MessageHandler):
query.resp_messages.append(mc)
yield entities.StageProcessResult(
result_type=entities.ResultType.CONTINUE, new_query=query
)
yield entities.StageProcessResult(result_type=entities.ResultType.CONTINUE, new_query=query)
else:
yield entities.StageProcessResult(
result_type=entities.ResultType.INTERRUPT, new_query=query
)
yield entities.StageProcessResult(result_type=entities.ResultType.INTERRUPT, new_query=query)
else:
if event_ctx.event.alter is not None:
# if isinstance(event_ctx.event, str): # 现在暂时不考虑多模态alter
@@ -69,34 +65,24 @@ class ChatMessageHandler(handler.MessageHandler):
runner = r(self.ap, query.pipeline_config)
break
else:
raise ValueError(
f'未找到请求运行器: {query.pipeline_config["ai"]["runner"]["runner"]}'
)
raise ValueError(f'未找到请求运行器: {query.pipeline_config["ai"]["runner"]["runner"]}')
async for result in runner.run(query):
query.resp_messages.append(result)
self.ap.logger.info(
f'对话({query.query_id})响应: {self.cut_str(result.readable_str())}'
)
self.ap.logger.info(f'对话({query.query_id})响应: {self.cut_str(result.readable_str())}')
if result.content is not None:
text_length += len(result.content)
yield entities.StageProcessResult(
result_type=entities.ResultType.CONTINUE, new_query=query
)
yield entities.StageProcessResult(result_type=entities.ResultType.CONTINUE, new_query=query)
query.session.using_conversation.messages.append(query.user_message)
query.session.using_conversation.messages.extend(query.resp_messages)
except Exception as e:
self.ap.logger.error(
f'对话({query.query_id})请求失败: {type(e).__name__} {str(e)}'
)
self.ap.logger.error(f'对话({query.query_id})请求失败: {type(e).__name__} {str(e)}')
hide_exception_info = query.pipeline_config['output']['misc'][
'hide-exception'
]
hide_exception_info = query.pipeline_config['output']['misc']['hide-exception']
yield entities.StageProcessResult(
result_type=entities.ResultType.INTERRUPT,
+10 -31
View File
@@ -21,10 +21,7 @@ class CommandHandler(handler.MessageHandler):
privilege = 1
if (
f'{query.launcher_type.value}_{query.launcher_id}'
in self.ap.instance_config.data['admins']
):
if f'{query.launcher_type.value}_{query.launcher_id}' in self.ap.instance_config.data['admins']:
privilege = 2
spt = command_text.split(' ')
@@ -54,25 +51,17 @@ class CommandHandler(handler.MessageHandler):
query.resp_messages.append(mc)
yield entities.StageProcessResult(
result_type=entities.ResultType.CONTINUE, new_query=query
)
yield entities.StageProcessResult(result_type=entities.ResultType.CONTINUE, new_query=query)
else:
yield entities.StageProcessResult(
result_type=entities.ResultType.INTERRUPT, new_query=query
)
yield entities.StageProcessResult(result_type=entities.ResultType.INTERRUPT, new_query=query)
else:
if event_ctx.event.alter is not None:
query.message_chain = platform_message.MessageChain(
[platform_message.Plain(event_ctx.event.alter)]
)
query.message_chain = platform_message.MessageChain([platform_message.Plain(event_ctx.event.alter)])
session = await self.ap.sess_mgr.get_session(query)
async for ret in self.ap.cmd_mgr.execute(
command_text=command_text, query=query, session=session
):
async for ret in self.ap.cmd_mgr.execute(command_text=command_text, query=query, session=session):
if ret.error is not None:
query.resp_messages.append(
llm_entities.Message(
@@ -81,13 +70,9 @@ class CommandHandler(handler.MessageHandler):
)
)
self.ap.logger.info(
f'命令({query.query_id})报错: {self.cut_str(str(ret.error))}'
)
self.ap.logger.info(f'命令({query.query_id})报错: {self.cut_str(str(ret.error))}')
yield entities.StageProcessResult(
result_type=entities.ResultType.CONTINUE, new_query=query
)
yield entities.StageProcessResult(result_type=entities.ResultType.CONTINUE, new_query=query)
elif ret.text is not None or ret.image_url is not None:
content: list[llm_entities.ContentElement] = []
@@ -95,9 +80,7 @@ class CommandHandler(handler.MessageHandler):
content.append(llm_entities.ContentElement.from_text(ret.text))
if ret.image_url is not None:
content.append(
llm_entities.ContentElement.from_image_url(ret.image_url)
)
content.append(llm_entities.ContentElement.from_image_url(ret.image_url))
query.resp_messages.append(
llm_entities.Message(
@@ -108,10 +91,6 @@ class CommandHandler(handler.MessageHandler):
self.ap.logger.info(f'命令返回: {self.cut_str(str(content[0]))}')
yield entities.StageProcessResult(
result_type=entities.ResultType.CONTINUE, new_query=query
)
yield entities.StageProcessResult(result_type=entities.ResultType.CONTINUE, new_query=query)
else:
yield entities.StageProcessResult(
result_type=entities.ResultType.INTERRUPT, new_query=query
)
yield entities.StageProcessResult(result_type=entities.ResultType.INTERRUPT, new_query=query)
+1 -3
View File
@@ -72,9 +72,7 @@ class FixedWindowAlgo(algo.ReteLimitAlgo):
if count >= limitation:
if query.pipeline_config['safety']['rate-limit']['strategy'] == 'drop':
return False
elif (
query.pipeline_config['safety']['rate-limit']['strategy'] == 'wait'
):
elif query.pipeline_config['safety']['rate-limit']['strategy'] == 'wait':
# 等待下一窗口
await asyncio.sleep(window_size - time.time() % window_size)
+3 -9
View File
@@ -15,9 +15,7 @@ from ...core import entities as core_entities
class SendResponseBackStage(stage.PipelineStage):
"""发送响应消息"""
async def process(
self, query: core_entities.Query, stage_inst_name: str
) -> entities.StageProcessResult:
async def process(self, query: core_entities.Query, stage_inst_name: str) -> entities.StageProcessResult:
"""处理"""
random_range = (
@@ -34,9 +32,7 @@ class SendResponseBackStage(stage.PipelineStage):
if query.pipeline_config['output']['misc']['at-sender'] and isinstance(
query.message_event, platform_events.GroupMessage
):
query.resp_message_chain[-1].insert(
0, platform_message.At(query.message_event.sender.id)
)
query.resp_message_chain[-1].insert(0, platform_message.At(query.message_event.sender.id))
quote_origin = query.pipeline_config['output']['misc']['quote-origin']
@@ -46,6 +42,4 @@ class SendResponseBackStage(stage.PipelineStage):
quote_origin=quote_origin,
)
return entities.StageProcessResult(
result_type=entities.ResultType.CONTINUE, new_query=query
)
return entities.StageProcessResult(result_type=entities.ResultType.CONTINUE, new_query=query)
+4 -12
View File
@@ -32,13 +32,9 @@ class GroupRespondRuleCheckStage(stage.PipelineStage):
await rule_inst.initialize()
self.rule_matchers.append(rule_inst)
async def process(
self, query: core_entities.Query, stage_inst_name: str
) -> entities.StageProcessResult:
async def process(self, query: core_entities.Query, stage_inst_name: str) -> entities.StageProcessResult:
if query.launcher_type.value != 'group': # 只处理群消息
return entities.StageProcessResult(
result_type=entities.ResultType.CONTINUE, new_query=query
)
return entities.StageProcessResult(result_type=entities.ResultType.CONTINUE, new_query=query)
rules = query.pipeline_config['trigger']['group-respond-rules']
@@ -49,9 +45,7 @@ class GroupRespondRuleCheckStage(stage.PipelineStage):
# use_rule = rules[str(query.launcher_id)]
for rule_matcher in self.rule_matchers: # 任意一个匹配就放行
res = await rule_matcher.match(
str(query.message_chain), query.message_chain, use_rule, query
)
res = await rule_matcher.match(str(query.message_chain), query.message_chain, use_rule, query)
if res.matching:
query.message_chain = res.replacement
@@ -60,6 +54,4 @@ class GroupRespondRuleCheckStage(stage.PipelineStage):
new_query=query,
)
return entities.StageProcessResult(
result_type=entities.ResultType.INTERRUPT, new_query=query
)
return entities.StageProcessResult(result_type=entities.ResultType.INTERRUPT, new_query=query)
+1 -4
View File
@@ -16,10 +16,7 @@ class AtBotRule(rule_model.GroupRespondRule):
rule_dict: dict,
query: core_entities.Query,
) -> entities.RuleJudgeResult:
if (
message_chain.has(platform_message.At(query.adapter.bot_account_id))
and rule_dict['at']
):
if message_chain.has(platform_message.At(query.adapter.bot_account_id)) and rule_dict['at']:
message_chain.remove(platform_message.At(query.adapter.bot_account_id))
if message_chain.has(
+1 -3
View File
@@ -18,6 +18,4 @@ class RandomRespRule(rule_model.GroupRespondRule):
) -> entities.RuleJudgeResult:
random_rate = rule_dict['random']
return entities.RuleJudgeResult(
matching=random.random() < random_rate, replacement=message_chain
)
return entities.RuleJudgeResult(matching=random.random() < random_rate, replacement=message_chain)
+14 -42
View File
@@ -34,29 +34,19 @@ class ResponseWrapper(stage.PipelineStage):
if isinstance(query.resp_messages[-1], platform_message.MessageChain):
query.resp_message_chain.append(query.resp_messages[-1])
yield entities.StageProcessResult(
result_type=entities.ResultType.CONTINUE, new_query=query
)
yield entities.StageProcessResult(result_type=entities.ResultType.CONTINUE, new_query=query)
else:
if query.resp_messages[-1].role == 'command':
query.resp_message_chain.append(
query.resp_messages[-1].get_content_platform_message_chain(
prefix_text='[bot] '
)
query.resp_messages[-1].get_content_platform_message_chain(prefix_text='[bot] ')
)
yield entities.StageProcessResult(
result_type=entities.ResultType.CONTINUE, new_query=query
)
yield entities.StageProcessResult(result_type=entities.ResultType.CONTINUE, new_query=query)
elif query.resp_messages[-1].role == 'plugin':
query.resp_message_chain.append(
query.resp_messages[-1].get_content_platform_message_chain()
)
query.resp_message_chain.append(query.resp_messages[-1].get_content_platform_message_chain())
yield entities.StageProcessResult(
result_type=entities.ResultType.CONTINUE, new_query=query
)
yield entities.StageProcessResult(result_type=entities.ResultType.CONTINUE, new_query=query)
else:
if query.resp_messages[-1].role == 'assistant':
result = query.resp_messages[-1]
@@ -77,9 +67,7 @@ class ResponseWrapper(stage.PipelineStage):
prefix='',
response_text=reply_text,
finish_reason='stop',
funcs_called=[
fc.function.name for fc in result.tool_calls
]
funcs_called=[fc.function.name for fc in result.tool_calls]
if result.tool_calls is not None
else [],
query=query,
@@ -92,36 +80,26 @@ class ResponseWrapper(stage.PipelineStage):
)
else:
if event_ctx.event.reply is not None:
query.resp_message_chain.append(
platform_message.MessageChain(event_ctx.event.reply)
)
query.resp_message_chain.append(platform_message.MessageChain(event_ctx.event.reply))
else:
query.resp_message_chain.append(
result.get_content_platform_message_chain()
)
query.resp_message_chain.append(result.get_content_platform_message_chain())
yield entities.StageProcessResult(
result_type=entities.ResultType.CONTINUE,
new_query=query,
)
if (
result.tool_calls is not None and len(result.tool_calls) > 0
): # 有函数调用
if result.tool_calls is not None and len(result.tool_calls) > 0: # 有函数调用
function_names = [tc.function.name for tc in result.tool_calls]
reply_text = f'调用函数 {".".join(function_names)}...'
query.resp_message_chain.append(
platform_message.MessageChain(
[platform_message.Plain(reply_text)]
)
platform_message.MessageChain([platform_message.Plain(reply_text)])
)
if query.pipeline_config['output']['misc'][
'track-function-calls'
]:
if query.pipeline_config['output']['misc']['track-function-calls']:
event_ctx = await self.ap.plugin_mgr.emit_event(
event=events.NormalMessageResponded(
launcher_type=query.launcher_type.value,
@@ -131,9 +109,7 @@ class ResponseWrapper(stage.PipelineStage):
prefix='',
response_text=reply_text,
finish_reason='stop',
funcs_called=[
fc.function.name for fc in result.tool_calls
]
funcs_called=[fc.function.name for fc in result.tool_calls]
if result.tool_calls is not None
else [],
query=query,
@@ -148,16 +124,12 @@ class ResponseWrapper(stage.PipelineStage):
else:
if event_ctx.event.reply is not None:
query.resp_message_chain.append(
platform_message.MessageChain(
event_ctx.event.reply
)
platform_message.MessageChain(event_ctx.event.reply)
)
else:
query.resp_message_chain.append(
platform_message.MessageChain(
[platform_message.Plain(reply_text)]
)
platform_message.MessageChain([platform_message.Plain(reply_text)])
)
yield entities.StageProcessResult(