mirror of
https://github.com/langbot-app/LangBot.git
synced 2026-06-26 15:34:26 +00:00
style: restrict line-length
This commit is contained in:
@@ -14,9 +14,7 @@ class BanSessionCheckStage(stage.PipelineStage):
|
||||
async def initialize(self, pipeline_config: dict):
|
||||
pass
|
||||
|
||||
async def process(
|
||||
self, query: core_entities.Query, stage_inst_name: str
|
||||
) -> entities.StageProcessResult:
|
||||
async def process(self, query: core_entities.Query, stage_inst_name: str) -> entities.StageProcessResult:
|
||||
found = False
|
||||
|
||||
mode = query.pipeline_config['trigger']['access-control']['mode']
|
||||
@@ -41,11 +39,7 @@ class BanSessionCheckStage(stage.PipelineStage):
|
||||
ctn = not found
|
||||
|
||||
return entities.StageProcessResult(
|
||||
result_type=entities.ResultType.CONTINUE
|
||||
if ctn
|
||||
else entities.ResultType.INTERRUPT,
|
||||
result_type=entities.ResultType.CONTINUE if ctn else entities.ResultType.INTERRUPT,
|
||||
new_query=query,
|
||||
console_notice=f'根据访问控制忽略消息: {query.launcher_type.value}_{query.launcher_id}'
|
||||
if not ctn
|
||||
else '',
|
||||
console_notice=f'根据访问控制忽略消息: {query.launcher_type.value}_{query.launcher_id}' if not ctn else '',
|
||||
)
|
||||
|
||||
@@ -65,9 +65,7 @@ class ContentFilterStage(stage.PipelineStage):
|
||||
"""
|
||||
|
||||
if query.pipeline_config['safety']['content-filter']['scope'] == 'output-msg':
|
||||
return entities.StageProcessResult(
|
||||
result_type=entities.ResultType.CONTINUE, new_query=query
|
||||
)
|
||||
return entities.StageProcessResult(result_type=entities.ResultType.CONTINUE, new_query=query)
|
||||
else:
|
||||
for filter in self.filter_chain:
|
||||
if filter_entities.EnableStage.PRE in filter.enable_stages:
|
||||
@@ -86,13 +84,9 @@ class ContentFilterStage(stage.PipelineStage):
|
||||
elif result.level == filter_entities.ResultLevel.PASS: # 传到下一个
|
||||
message = result.replacement
|
||||
|
||||
query.message_chain = platform_message.MessageChain(
|
||||
platform_message.Plain(message)
|
||||
)
|
||||
query.message_chain = platform_message.MessageChain(platform_message.Plain(message))
|
||||
|
||||
return entities.StageProcessResult(
|
||||
result_type=entities.ResultType.CONTINUE, new_query=query
|
||||
)
|
||||
return entities.StageProcessResult(result_type=entities.ResultType.CONTINUE, new_query=query)
|
||||
|
||||
async def _post_process(
|
||||
self,
|
||||
@@ -103,9 +97,7 @@ class ContentFilterStage(stage.PipelineStage):
|
||||
只要是 PASS 或者 MASKED 的就通过此 filter,将其 replacement 设置为message,进入下一个 filter
|
||||
"""
|
||||
if query.pipeline_config['safety']['content-filter']['scope'] == 'income-msg':
|
||||
return entities.StageProcessResult(
|
||||
result_type=entities.ResultType.CONTINUE, new_query=query
|
||||
)
|
||||
return entities.StageProcessResult(result_type=entities.ResultType.CONTINUE, new_query=query)
|
||||
else:
|
||||
message = message.strip()
|
||||
for filter in self.filter_chain:
|
||||
@@ -127,13 +119,9 @@ class ContentFilterStage(stage.PipelineStage):
|
||||
|
||||
query.resp_messages[-1].content = message
|
||||
|
||||
return entities.StageProcessResult(
|
||||
result_type=entities.ResultType.CONTINUE, new_query=query
|
||||
)
|
||||
return entities.StageProcessResult(result_type=entities.ResultType.CONTINUE, new_query=query)
|
||||
|
||||
async def process(
|
||||
self, query: core_entities.Query, stage_inst_name: str
|
||||
) -> entities.StageProcessResult:
|
||||
async def process(self, query: core_entities.Query, stage_inst_name: str) -> entities.StageProcessResult:
|
||||
"""处理"""
|
||||
if stage_inst_name == 'PreContentFilterStage':
|
||||
contain_non_text = False
|
||||
@@ -147,9 +135,7 @@ class ContentFilterStage(stage.PipelineStage):
|
||||
|
||||
if contain_non_text:
|
||||
self.ap.logger.debug('消息中包含非文本消息,跳过内容过滤器检查。')
|
||||
return entities.StageProcessResult(
|
||||
result_type=entities.ResultType.CONTINUE, new_query=query
|
||||
)
|
||||
return entities.StageProcessResult(result_type=entities.ResultType.CONTINUE, new_query=query)
|
||||
|
||||
return await self._pre_process(str(query.message_chain).strip(), query)
|
||||
elif stage_inst_name == 'PostContentFilterStage':
|
||||
@@ -162,8 +148,6 @@ class ContentFilterStage(stage.PipelineStage):
|
||||
self.ap.logger.debug(
|
||||
'resp_messages[-1] 不是 Message 类型或 query.resp_messages[-1].content 不是 str 类型,跳过内容过滤器检查。'
|
||||
)
|
||||
return entities.StageProcessResult(
|
||||
result_type=entities.ResultType.CONTINUE, new_query=query
|
||||
)
|
||||
return entities.StageProcessResult(result_type=entities.ResultType.CONTINUE, new_query=query)
|
||||
else:
|
||||
raise ValueError(f'未知的 stage_inst_name: {stage_inst_name}')
|
||||
|
||||
@@ -60,9 +60,7 @@ class ContentFilter(metaclass=abc.ABCMeta):
|
||||
pass
|
||||
|
||||
@abc.abstractmethod
|
||||
async def process(
|
||||
self, query: core_entities.Query, message: str = None, image_url=None
|
||||
) -> entities.FilterResult:
|
||||
async def process(self, query: core_entities.Query, message: str = None, image_url=None) -> entities.FilterResult:
|
||||
"""处理消息
|
||||
|
||||
分为前后阶段,具体取决于 enable_stages 的值。
|
||||
|
||||
@@ -21,19 +21,13 @@ class BaiduCloudExamine(filter_model.ContentFilter):
|
||||
BAIDU_EXAMINE_TOKEN_URL,
|
||||
params={
|
||||
'grant_type': 'client_credentials',
|
||||
'client_id': self.ap.pipeline_cfg.data['baidu-cloud-examine'][
|
||||
'api-key'
|
||||
],
|
||||
'client_secret': self.ap.pipeline_cfg.data['baidu-cloud-examine'][
|
||||
'api-secret'
|
||||
],
|
||||
'client_id': self.ap.pipeline_cfg.data['baidu-cloud-examine']['api-key'],
|
||||
'client_secret': self.ap.pipeline_cfg.data['baidu-cloud-examine']['api-secret'],
|
||||
},
|
||||
) as resp:
|
||||
return (await resp.json())['access_token']
|
||||
|
||||
async def process(
|
||||
self, query: core_entities.Query, message: str
|
||||
) -> entities.FilterResult:
|
||||
async def process(self, query: core_entities.Query, message: str) -> entities.FilterResult:
|
||||
async with aiohttp.ClientSession() as session:
|
||||
async with session.post(
|
||||
BAIDU_EXAMINE_URL.format(await self._get_token()),
|
||||
|
||||
@@ -13,9 +13,7 @@ class BanWordFilter(filter_model.ContentFilter):
|
||||
async def initialize(self):
|
||||
pass
|
||||
|
||||
async def process(
|
||||
self, query: core_entities.Query, message: str
|
||||
) -> entities.FilterResult:
|
||||
async def process(self, query: core_entities.Query, message: str) -> entities.FilterResult:
|
||||
found = False
|
||||
|
||||
for word in self.ap.sensitive_meta.data['words']:
|
||||
@@ -31,9 +29,7 @@ class BanWordFilter(filter_model.ContentFilter):
|
||||
self.ap.sensitive_meta.data['mask'] * len(match[i]),
|
||||
)
|
||||
else:
|
||||
message = message.replace(
|
||||
match[i], self.ap.sensitive_meta.data['mask_word']
|
||||
)
|
||||
message = message.replace(match[i], self.ap.sensitive_meta.data['mask_word'])
|
||||
|
||||
return entities.FilterResult(
|
||||
level=entities.ResultLevel.MASKED if found else entities.ResultLevel.PASS,
|
||||
|
||||
@@ -16,9 +16,7 @@ class ContentIgnore(filter_model.ContentFilter):
|
||||
entities.EnableStage.PRE,
|
||||
]
|
||||
|
||||
async def process(
|
||||
self, query: core_entities.Query, message: str
|
||||
) -> entities.FilterResult:
|
||||
async def process(self, query: core_entities.Query, message: str) -> entities.FilterResult:
|
||||
if 'prefix' in query.pipeline_config['trigger']['ignore-rules']:
|
||||
for rule in query.pipeline_config['trigger']['ignore-rules']['prefix']:
|
||||
if message.startswith(rule):
|
||||
|
||||
@@ -16,9 +16,7 @@ class Controller:
|
||||
|
||||
def __init__(self, ap: app.Application):
|
||||
self.ap = ap
|
||||
self.semaphore = asyncio.Semaphore(
|
||||
self.ap.instance_config.data['concurrency']['pipeline']
|
||||
)
|
||||
self.semaphore = asyncio.Semaphore(self.ap.instance_config.data['concurrency']['pipeline'])
|
||||
|
||||
async def consumer(self):
|
||||
"""事件处理循环"""
|
||||
@@ -32,9 +30,7 @@ class Controller:
|
||||
|
||||
for query in queries:
|
||||
session = await self.ap.sess_mgr.get_session(query)
|
||||
self.ap.logger.debug(
|
||||
f'Checking query {query} session {session}'
|
||||
)
|
||||
self.ap.logger.debug(f'Checking query {query} session {session}')
|
||||
|
||||
if not session.semaphore.locked():
|
||||
selected_query = query
|
||||
@@ -55,22 +51,16 @@ class Controller:
|
||||
# find pipeline
|
||||
# Here firstly find the bot, then find the pipeline, in case the bot adapter's config is not the latest one.
|
||||
# Like aiocqhttp, once a client is connected, even the adapter was updated and restarted, the existing client connection will not be affected.
|
||||
bot = await self.ap.platform_mgr.get_bot_by_uuid(
|
||||
selected_query.bot_uuid
|
||||
)
|
||||
bot = await self.ap.platform_mgr.get_bot_by_uuid(selected_query.bot_uuid)
|
||||
if bot:
|
||||
pipeline = (
|
||||
await self.ap.pipeline_mgr.get_pipeline_by_uuid(
|
||||
bot.bot_entity.use_pipeline_uuid
|
||||
)
|
||||
pipeline = await self.ap.pipeline_mgr.get_pipeline_by_uuid(
|
||||
bot.bot_entity.use_pipeline_uuid
|
||||
)
|
||||
if pipeline:
|
||||
await pipeline.run(selected_query)
|
||||
|
||||
async with self.ap.query_pool:
|
||||
(
|
||||
await self.ap.sess_mgr.get_session(selected_query)
|
||||
).semaphore.release()
|
||||
(await self.ap.sess_mgr.get_session(selected_query)).semaphore.release()
|
||||
# 通知其他协程,有新的请求可以处理了
|
||||
self.ap.query_pool.condition.notify_all()
|
||||
|
||||
|
||||
@@ -47,9 +47,7 @@ class LongTextProcessStage(stage.PipelineStage):
|
||||
'未找到字体文件,且无法使用系统自带字体,更换为转发消息组件以发送长消息,您可以在配置文件中调整相关设置。'
|
||||
)
|
||||
|
||||
pipeline_config['output']['long-text-processing'][
|
||||
'strategy'
|
||||
] = 'forward'
|
||||
pipeline_config['output']['long-text-processing']['strategy'] = 'forward'
|
||||
except Exception:
|
||||
traceback.print_exc()
|
||||
self.ap.logger.error(
|
||||
@@ -58,9 +56,7 @@ class LongTextProcessStage(stage.PipelineStage):
|
||||
)
|
||||
)
|
||||
|
||||
pipeline_config['output']['long-text-processing']['strategy'] = (
|
||||
'forward'
|
||||
)
|
||||
pipeline_config['output']['long-text-processing']['strategy'] = 'forward'
|
||||
|
||||
for strategy_cls in strategy.preregistered_strategies:
|
||||
if strategy_cls.name == config['strategy']:
|
||||
@@ -71,9 +67,7 @@ class LongTextProcessStage(stage.PipelineStage):
|
||||
|
||||
await self.strategy_impl.initialize()
|
||||
|
||||
async def process(
|
||||
self, query: core_entities.Query, stage_inst_name: str
|
||||
) -> entities.StageProcessResult:
|
||||
async def process(self, query: core_entities.Query, stage_inst_name: str) -> entities.StageProcessResult:
|
||||
# 检查是否包含非 Plain 组件
|
||||
contains_non_plain = False
|
||||
|
||||
@@ -89,11 +83,7 @@ class LongTextProcessStage(stage.PipelineStage):
|
||||
> query.pipeline_config['output']['long-text-processing']['threshold']
|
||||
):
|
||||
query.resp_message_chain[-1] = platform_message.MessageChain(
|
||||
await self.strategy_impl.process(
|
||||
str(query.resp_message_chain[-1]), query
|
||||
)
|
||||
await self.strategy_impl.process(str(query.resp_message_chain[-1]), query)
|
||||
)
|
||||
|
||||
return entities.StageProcessResult(
|
||||
result_type=entities.ResultType.CONTINUE, new_query=query
|
||||
)
|
||||
return entities.StageProcessResult(result_type=entities.ResultType.CONTINUE, new_query=query)
|
||||
|
||||
@@ -13,9 +13,7 @@ Forward = platform_message.Forward
|
||||
|
||||
@strategy_model.strategy_class('forward')
|
||||
class ForwardComponentStrategy(strategy_model.LongTextStrategy):
|
||||
async def process(
|
||||
self, message: str, query: core_entities.Query
|
||||
) -> list[platform_message.MessageComponent]:
|
||||
async def process(self, message: str, query: core_entities.Query) -> list[platform_message.MessageComponent]:
|
||||
display = ForwardMessageDiaplay(
|
||||
title='群聊的聊天记录',
|
||||
brief='[聊天记录]',
|
||||
|
||||
@@ -27,18 +27,14 @@ class Text2ImageStrategy(strategy_model.LongTextStrategy):
|
||||
encoding='utf-8',
|
||||
)
|
||||
|
||||
async def process(
|
||||
self, message: str, query: core_entities.Query
|
||||
) -> list[platform_message.MessageComponent]:
|
||||
async def process(self, message: str, query: core_entities.Query) -> list[platform_message.MessageComponent]:
|
||||
img_path = self.text_to_image(
|
||||
text_str=message,
|
||||
save_as='temp/{}.png'.format(int(time.time())),
|
||||
query=query,
|
||||
)
|
||||
|
||||
compressed_path, size = self.compress_image(
|
||||
img_path, outfile='temp/{}_compressed.png'.format(int(time.time()))
|
||||
)
|
||||
compressed_path, size = self.compress_image(img_path, outfile='temp/{}_compressed.png'.format(int(time.time())))
|
||||
|
||||
with open(compressed_path, 'rb') as f:
|
||||
img = f.read()
|
||||
@@ -165,10 +161,7 @@ class Text2ImageStrategy(strategy_model.LongTextStrategy):
|
||||
numbers = self.indexNumber(rest_text)
|
||||
|
||||
for number in numbers:
|
||||
if (
|
||||
number[1] < point < number[1] + len(number[0])
|
||||
and number[1] != 0
|
||||
):
|
||||
if number[1] < point < number[1] + len(number[0]) and number[1] != 0:
|
||||
point = number[1]
|
||||
break
|
||||
|
||||
@@ -181,9 +174,7 @@ class Text2ImageStrategy(strategy_model.LongTextStrategy):
|
||||
else:
|
||||
continue
|
||||
# 准备画布
|
||||
img = Image.new(
|
||||
'RGBA', (width, max(280, len(final_lines) * 35 + 65)), (255, 255, 255, 255)
|
||||
)
|
||||
img = Image.new('RGBA', (width, max(280, len(final_lines) * 35 + 65)), (255, 255, 255, 255))
|
||||
draw = ImageDraw.Draw(img, mode='RGBA')
|
||||
|
||||
self.ap.logger.debug('正在绘制图片...')
|
||||
|
||||
@@ -49,9 +49,7 @@ class LongTextStrategy(metaclass=abc.ABCMeta):
|
||||
pass
|
||||
|
||||
@abc.abstractmethod
|
||||
async def process(
|
||||
self, message: str, query: core_entities.Query
|
||||
) -> list[platform_message.MessageComponent]:
|
||||
async def process(self, message: str, query: core_entities.Query) -> list[platform_message.MessageComponent]:
|
||||
"""处理长文本
|
||||
|
||||
在 platform.json 中配置 long-text-process 字段,只要 文本长度超过了 threshold 就会调用此方法
|
||||
|
||||
@@ -29,12 +29,8 @@ class ConversationMessageTruncator(stage.PipelineStage):
|
||||
else:
|
||||
raise ValueError(f'未知的截断器: {use_method}')
|
||||
|
||||
async def process(
|
||||
self, query: core_entities.Query, stage_inst_name: str
|
||||
) -> entities.StageProcessResult:
|
||||
async def process(self, query: core_entities.Query, stage_inst_name: str) -> entities.StageProcessResult:
|
||||
"""处理"""
|
||||
query = await self.trun.truncate(query)
|
||||
|
||||
return entities.StageProcessResult(
|
||||
result_type=entities.ResultType.CONTINUE, new_query=query
|
||||
)
|
||||
return entities.StageProcessResult(result_type=entities.ResultType.CONTINUE, new_query=query)
|
||||
|
||||
+15
-47
@@ -79,26 +79,20 @@ class RuntimePipeline:
|
||||
query.pipeline_config = self.pipeline_entity.config
|
||||
await self.process_query(query)
|
||||
|
||||
async def _check_output(
|
||||
self, query: entities.Query, result: pipeline_entities.StageProcessResult
|
||||
):
|
||||
async def _check_output(self, query: entities.Query, result: pipeline_entities.StageProcessResult):
|
||||
"""检查输出"""
|
||||
if result.user_notice:
|
||||
# 处理str类型
|
||||
|
||||
if isinstance(result.user_notice, str):
|
||||
result.user_notice = platform_message.MessageChain(
|
||||
platform_message.Plain(result.user_notice)
|
||||
)
|
||||
result.user_notice = platform_message.MessageChain(platform_message.Plain(result.user_notice))
|
||||
elif isinstance(result.user_notice, list):
|
||||
result.user_notice = platform_message.MessageChain(*result.user_notice)
|
||||
|
||||
if query.pipeline_config['output']['misc']['at-sender'] and isinstance(
|
||||
query.message_event, platform_events.GroupMessage
|
||||
):
|
||||
result.user_notice.insert(
|
||||
0, platform_message.At(query.message_event.sender.id)
|
||||
)
|
||||
result.user_notice.insert(0, platform_message.At(query.message_event.sender.id))
|
||||
|
||||
await query.adapter.reply_message(
|
||||
message_source=query.message_event,
|
||||
@@ -150,37 +144,25 @@ class RuntimePipeline:
|
||||
result = await result
|
||||
|
||||
if isinstance(result, pipeline_entities.StageProcessResult): # 直接返回结果
|
||||
self.ap.logger.debug(
|
||||
f'Stage {stage_container.inst_name} processed query {query} res {result}'
|
||||
)
|
||||
self.ap.logger.debug(f'Stage {stage_container.inst_name} processed query {query} res {result}')
|
||||
await self._check_output(query, result)
|
||||
|
||||
if result.result_type == pipeline_entities.ResultType.INTERRUPT:
|
||||
self.ap.logger.debug(
|
||||
f'Stage {stage_container.inst_name} interrupted query {query}'
|
||||
)
|
||||
self.ap.logger.debug(f'Stage {stage_container.inst_name} interrupted query {query}')
|
||||
break
|
||||
elif result.result_type == pipeline_entities.ResultType.CONTINUE:
|
||||
query = result.new_query
|
||||
elif isinstance(result, typing.AsyncGenerator): # 生成器
|
||||
self.ap.logger.debug(
|
||||
f'Stage {stage_container.inst_name} processed query {query} gen'
|
||||
)
|
||||
self.ap.logger.debug(f'Stage {stage_container.inst_name} processed query {query} gen')
|
||||
|
||||
async for sub_result in result:
|
||||
self.ap.logger.debug(
|
||||
f'Stage {stage_container.inst_name} processed query {query} res {sub_result}'
|
||||
)
|
||||
self.ap.logger.debug(f'Stage {stage_container.inst_name} processed query {query} res {sub_result}')
|
||||
await self._check_output(query, sub_result)
|
||||
|
||||
if sub_result.result_type == pipeline_entities.ResultType.INTERRUPT:
|
||||
self.ap.logger.debug(
|
||||
f'Stage {stage_container.inst_name} interrupted query {query}'
|
||||
)
|
||||
self.ap.logger.debug(f'Stage {stage_container.inst_name} interrupted query {query}')
|
||||
break
|
||||
elif (
|
||||
sub_result.result_type == pipeline_entities.ResultType.CONTINUE
|
||||
):
|
||||
elif sub_result.result_type == pipeline_entities.ResultType.CONTINUE:
|
||||
query = sub_result.new_query
|
||||
await self._execute_from_stage(i + 1, query)
|
||||
break
|
||||
@@ -214,12 +196,8 @@ class RuntimePipeline:
|
||||
|
||||
await self._execute_from_stage(0, query)
|
||||
except Exception as e:
|
||||
inst_name = (
|
||||
query.current_stage.inst_name if query.current_stage else 'unknown'
|
||||
)
|
||||
self.ap.logger.error(
|
||||
f'处理请求时出错 query_id={query.query_id} stage={inst_name} : {e}'
|
||||
)
|
||||
inst_name = query.current_stage.inst_name if query.current_stage else 'unknown'
|
||||
self.ap.logger.error(f'处理请求时出错 query_id={query.query_id} stage={inst_name} : {e}')
|
||||
self.ap.logger.debug(f'Traceback: {traceback.format_exc()}')
|
||||
finally:
|
||||
self.ap.logger.debug(f'Query {query} processed')
|
||||
@@ -241,18 +219,14 @@ class PipelineManager:
|
||||
self.pipelines = []
|
||||
|
||||
async def initialize(self):
|
||||
self.stage_dict = {
|
||||
name: cls for name, cls in stage.preregistered_stages.items()
|
||||
}
|
||||
self.stage_dict = {name: cls for name, cls in stage.preregistered_stages.items()}
|
||||
|
||||
await self.load_pipelines_from_db()
|
||||
|
||||
async def load_pipelines_from_db(self):
|
||||
self.ap.logger.info('Loading pipelines from db...')
|
||||
|
||||
result = await self.ap.persistence_mgr.execute_async(
|
||||
sqlalchemy.select(persistence_pipeline.LegacyPipeline)
|
||||
)
|
||||
result = await self.ap.persistence_mgr.execute_async(sqlalchemy.select(persistence_pipeline.LegacyPipeline))
|
||||
|
||||
pipelines = result.all()
|
||||
|
||||
@@ -267,20 +241,14 @@ class PipelineManager:
|
||||
| dict,
|
||||
):
|
||||
if isinstance(pipeline_entity, sqlalchemy.Row):
|
||||
pipeline_entity = persistence_pipeline.LegacyPipeline(
|
||||
**pipeline_entity._mapping
|
||||
)
|
||||
pipeline_entity = persistence_pipeline.LegacyPipeline(**pipeline_entity._mapping)
|
||||
elif isinstance(pipeline_entity, dict):
|
||||
pipeline_entity = persistence_pipeline.LegacyPipeline(**pipeline_entity)
|
||||
|
||||
# initialize stage containers according to pipeline_entity.stages
|
||||
stage_containers: list[StageInstContainer] = []
|
||||
for stage_name in pipeline_entity.stages:
|
||||
stage_containers.append(
|
||||
StageInstContainer(
|
||||
inst_name=stage_name, inst=self.stage_dict[stage_name](self.ap)
|
||||
)
|
||||
)
|
||||
stage_containers.append(StageInstContainer(inst_name=stage_name, inst=self.stage_dict[stage_name](self.ap)))
|
||||
|
||||
for stage_container in stage_containers:
|
||||
await stage_container.inst.initialize(pipeline_entity.config)
|
||||
|
||||
@@ -44,9 +44,7 @@ class PreProcessor(stage.PipelineStage):
|
||||
query.use_llm_model = conversation.use_llm_model
|
||||
|
||||
query.use_funcs = (
|
||||
conversation.use_funcs
|
||||
if query.use_llm_model.model_entity.abilities.__contains__('tool_call')
|
||||
else None
|
||||
conversation.use_funcs if query.use_llm_model.model_entity.abilities.__contains__('tool_call') else None
|
||||
)
|
||||
|
||||
query.variables = {
|
||||
@@ -59,10 +57,9 @@ class PreProcessor(stage.PipelineStage):
|
||||
|
||||
# Check if this model supports vision, if not, remove all images
|
||||
# TODO this checking should be performed in runner, and in this stage, the image should be reserved
|
||||
if (
|
||||
query.pipeline_config['ai']['runner']['runner'] == 'local-agent'
|
||||
and not query.use_llm_model.model_entity.abilities.__contains__('vision')
|
||||
):
|
||||
if query.pipeline_config['ai']['runner'][
|
||||
'runner'
|
||||
] == 'local-agent' and not query.use_llm_model.model_entity.abilities.__contains__('vision'):
|
||||
for msg in query.messages:
|
||||
if isinstance(msg.content, list):
|
||||
for me in msg.content:
|
||||
@@ -78,14 +75,11 @@ class PreProcessor(stage.PipelineStage):
|
||||
content_list.append(llm_entities.ContentElement.from_text(me.text))
|
||||
plain_text += me.text
|
||||
elif isinstance(me, platform_message.Image):
|
||||
if (
|
||||
query.pipeline_config['ai']['runner']['runner'] != 'local-agent'
|
||||
or query.use_llm_model.model_entity.abilities.__contains__('vision')
|
||||
):
|
||||
if query.pipeline_config['ai']['runner'][
|
||||
'runner'
|
||||
] != 'local-agent' or query.use_llm_model.model_entity.abilities.__contains__('vision'):
|
||||
if me.base64 is not None:
|
||||
content_list.append(
|
||||
llm_entities.ContentElement.from_image_base64(me.base64)
|
||||
)
|
||||
content_list.append(llm_entities.ContentElement.from_image_base64(me.base64))
|
||||
|
||||
query.variables['user_message_text'] = plain_text
|
||||
|
||||
@@ -104,6 +98,4 @@ class PreProcessor(stage.PipelineStage):
|
||||
query.prompt.messages = event_ctx.event.default_prompt
|
||||
query.messages = event_ctx.event.prompt
|
||||
|
||||
return entities.StageProcessResult(
|
||||
result_type=entities.ResultType.CONTINUE, new_query=query
|
||||
)
|
||||
return entities.StageProcessResult(result_type=entities.ResultType.CONTINUE, new_query=query)
|
||||
|
||||
@@ -49,13 +49,9 @@ class ChatMessageHandler(handler.MessageHandler):
|
||||
|
||||
query.resp_messages.append(mc)
|
||||
|
||||
yield entities.StageProcessResult(
|
||||
result_type=entities.ResultType.CONTINUE, new_query=query
|
||||
)
|
||||
yield entities.StageProcessResult(result_type=entities.ResultType.CONTINUE, new_query=query)
|
||||
else:
|
||||
yield entities.StageProcessResult(
|
||||
result_type=entities.ResultType.INTERRUPT, new_query=query
|
||||
)
|
||||
yield entities.StageProcessResult(result_type=entities.ResultType.INTERRUPT, new_query=query)
|
||||
else:
|
||||
if event_ctx.event.alter is not None:
|
||||
# if isinstance(event_ctx.event, str): # 现在暂时不考虑多模态alter
|
||||
@@ -69,34 +65,24 @@ class ChatMessageHandler(handler.MessageHandler):
|
||||
runner = r(self.ap, query.pipeline_config)
|
||||
break
|
||||
else:
|
||||
raise ValueError(
|
||||
f'未找到请求运行器: {query.pipeline_config["ai"]["runner"]["runner"]}'
|
||||
)
|
||||
raise ValueError(f'未找到请求运行器: {query.pipeline_config["ai"]["runner"]["runner"]}')
|
||||
|
||||
async for result in runner.run(query):
|
||||
query.resp_messages.append(result)
|
||||
|
||||
self.ap.logger.info(
|
||||
f'对话({query.query_id})响应: {self.cut_str(result.readable_str())}'
|
||||
)
|
||||
self.ap.logger.info(f'对话({query.query_id})响应: {self.cut_str(result.readable_str())}')
|
||||
|
||||
if result.content is not None:
|
||||
text_length += len(result.content)
|
||||
|
||||
yield entities.StageProcessResult(
|
||||
result_type=entities.ResultType.CONTINUE, new_query=query
|
||||
)
|
||||
yield entities.StageProcessResult(result_type=entities.ResultType.CONTINUE, new_query=query)
|
||||
|
||||
query.session.using_conversation.messages.append(query.user_message)
|
||||
query.session.using_conversation.messages.extend(query.resp_messages)
|
||||
except Exception as e:
|
||||
self.ap.logger.error(
|
||||
f'对话({query.query_id})请求失败: {type(e).__name__} {str(e)}'
|
||||
)
|
||||
self.ap.logger.error(f'对话({query.query_id})请求失败: {type(e).__name__} {str(e)}')
|
||||
|
||||
hide_exception_info = query.pipeline_config['output']['misc'][
|
||||
'hide-exception'
|
||||
]
|
||||
hide_exception_info = query.pipeline_config['output']['misc']['hide-exception']
|
||||
|
||||
yield entities.StageProcessResult(
|
||||
result_type=entities.ResultType.INTERRUPT,
|
||||
|
||||
@@ -21,10 +21,7 @@ class CommandHandler(handler.MessageHandler):
|
||||
|
||||
privilege = 1
|
||||
|
||||
if (
|
||||
f'{query.launcher_type.value}_{query.launcher_id}'
|
||||
in self.ap.instance_config.data['admins']
|
||||
):
|
||||
if f'{query.launcher_type.value}_{query.launcher_id}' in self.ap.instance_config.data['admins']:
|
||||
privilege = 2
|
||||
|
||||
spt = command_text.split(' ')
|
||||
@@ -54,25 +51,17 @@ class CommandHandler(handler.MessageHandler):
|
||||
|
||||
query.resp_messages.append(mc)
|
||||
|
||||
yield entities.StageProcessResult(
|
||||
result_type=entities.ResultType.CONTINUE, new_query=query
|
||||
)
|
||||
yield entities.StageProcessResult(result_type=entities.ResultType.CONTINUE, new_query=query)
|
||||
else:
|
||||
yield entities.StageProcessResult(
|
||||
result_type=entities.ResultType.INTERRUPT, new_query=query
|
||||
)
|
||||
yield entities.StageProcessResult(result_type=entities.ResultType.INTERRUPT, new_query=query)
|
||||
|
||||
else:
|
||||
if event_ctx.event.alter is not None:
|
||||
query.message_chain = platform_message.MessageChain(
|
||||
[platform_message.Plain(event_ctx.event.alter)]
|
||||
)
|
||||
query.message_chain = platform_message.MessageChain([platform_message.Plain(event_ctx.event.alter)])
|
||||
|
||||
session = await self.ap.sess_mgr.get_session(query)
|
||||
|
||||
async for ret in self.ap.cmd_mgr.execute(
|
||||
command_text=command_text, query=query, session=session
|
||||
):
|
||||
async for ret in self.ap.cmd_mgr.execute(command_text=command_text, query=query, session=session):
|
||||
if ret.error is not None:
|
||||
query.resp_messages.append(
|
||||
llm_entities.Message(
|
||||
@@ -81,13 +70,9 @@ class CommandHandler(handler.MessageHandler):
|
||||
)
|
||||
)
|
||||
|
||||
self.ap.logger.info(
|
||||
f'命令({query.query_id})报错: {self.cut_str(str(ret.error))}'
|
||||
)
|
||||
self.ap.logger.info(f'命令({query.query_id})报错: {self.cut_str(str(ret.error))}')
|
||||
|
||||
yield entities.StageProcessResult(
|
||||
result_type=entities.ResultType.CONTINUE, new_query=query
|
||||
)
|
||||
yield entities.StageProcessResult(result_type=entities.ResultType.CONTINUE, new_query=query)
|
||||
elif ret.text is not None or ret.image_url is not None:
|
||||
content: list[llm_entities.ContentElement] = []
|
||||
|
||||
@@ -95,9 +80,7 @@ class CommandHandler(handler.MessageHandler):
|
||||
content.append(llm_entities.ContentElement.from_text(ret.text))
|
||||
|
||||
if ret.image_url is not None:
|
||||
content.append(
|
||||
llm_entities.ContentElement.from_image_url(ret.image_url)
|
||||
)
|
||||
content.append(llm_entities.ContentElement.from_image_url(ret.image_url))
|
||||
|
||||
query.resp_messages.append(
|
||||
llm_entities.Message(
|
||||
@@ -108,10 +91,6 @@ class CommandHandler(handler.MessageHandler):
|
||||
|
||||
self.ap.logger.info(f'命令返回: {self.cut_str(str(content[0]))}')
|
||||
|
||||
yield entities.StageProcessResult(
|
||||
result_type=entities.ResultType.CONTINUE, new_query=query
|
||||
)
|
||||
yield entities.StageProcessResult(result_type=entities.ResultType.CONTINUE, new_query=query)
|
||||
else:
|
||||
yield entities.StageProcessResult(
|
||||
result_type=entities.ResultType.INTERRUPT, new_query=query
|
||||
)
|
||||
yield entities.StageProcessResult(result_type=entities.ResultType.INTERRUPT, new_query=query)
|
||||
|
||||
@@ -72,9 +72,7 @@ class FixedWindowAlgo(algo.ReteLimitAlgo):
|
||||
if count >= limitation:
|
||||
if query.pipeline_config['safety']['rate-limit']['strategy'] == 'drop':
|
||||
return False
|
||||
elif (
|
||||
query.pipeline_config['safety']['rate-limit']['strategy'] == 'wait'
|
||||
):
|
||||
elif query.pipeline_config['safety']['rate-limit']['strategy'] == 'wait':
|
||||
# 等待下一窗口
|
||||
await asyncio.sleep(window_size - time.time() % window_size)
|
||||
|
||||
|
||||
@@ -15,9 +15,7 @@ from ...core import entities as core_entities
|
||||
class SendResponseBackStage(stage.PipelineStage):
|
||||
"""发送响应消息"""
|
||||
|
||||
async def process(
|
||||
self, query: core_entities.Query, stage_inst_name: str
|
||||
) -> entities.StageProcessResult:
|
||||
async def process(self, query: core_entities.Query, stage_inst_name: str) -> entities.StageProcessResult:
|
||||
"""处理"""
|
||||
|
||||
random_range = (
|
||||
@@ -34,9 +32,7 @@ class SendResponseBackStage(stage.PipelineStage):
|
||||
if query.pipeline_config['output']['misc']['at-sender'] and isinstance(
|
||||
query.message_event, platform_events.GroupMessage
|
||||
):
|
||||
query.resp_message_chain[-1].insert(
|
||||
0, platform_message.At(query.message_event.sender.id)
|
||||
)
|
||||
query.resp_message_chain[-1].insert(0, platform_message.At(query.message_event.sender.id))
|
||||
|
||||
quote_origin = query.pipeline_config['output']['misc']['quote-origin']
|
||||
|
||||
@@ -46,6 +42,4 @@ class SendResponseBackStage(stage.PipelineStage):
|
||||
quote_origin=quote_origin,
|
||||
)
|
||||
|
||||
return entities.StageProcessResult(
|
||||
result_type=entities.ResultType.CONTINUE, new_query=query
|
||||
)
|
||||
return entities.StageProcessResult(result_type=entities.ResultType.CONTINUE, new_query=query)
|
||||
|
||||
@@ -32,13 +32,9 @@ class GroupRespondRuleCheckStage(stage.PipelineStage):
|
||||
await rule_inst.initialize()
|
||||
self.rule_matchers.append(rule_inst)
|
||||
|
||||
async def process(
|
||||
self, query: core_entities.Query, stage_inst_name: str
|
||||
) -> entities.StageProcessResult:
|
||||
async def process(self, query: core_entities.Query, stage_inst_name: str) -> entities.StageProcessResult:
|
||||
if query.launcher_type.value != 'group': # 只处理群消息
|
||||
return entities.StageProcessResult(
|
||||
result_type=entities.ResultType.CONTINUE, new_query=query
|
||||
)
|
||||
return entities.StageProcessResult(result_type=entities.ResultType.CONTINUE, new_query=query)
|
||||
|
||||
rules = query.pipeline_config['trigger']['group-respond-rules']
|
||||
|
||||
@@ -49,9 +45,7 @@ class GroupRespondRuleCheckStage(stage.PipelineStage):
|
||||
# use_rule = rules[str(query.launcher_id)]
|
||||
|
||||
for rule_matcher in self.rule_matchers: # 任意一个匹配就放行
|
||||
res = await rule_matcher.match(
|
||||
str(query.message_chain), query.message_chain, use_rule, query
|
||||
)
|
||||
res = await rule_matcher.match(str(query.message_chain), query.message_chain, use_rule, query)
|
||||
if res.matching:
|
||||
query.message_chain = res.replacement
|
||||
|
||||
@@ -60,6 +54,4 @@ class GroupRespondRuleCheckStage(stage.PipelineStage):
|
||||
new_query=query,
|
||||
)
|
||||
|
||||
return entities.StageProcessResult(
|
||||
result_type=entities.ResultType.INTERRUPT, new_query=query
|
||||
)
|
||||
return entities.StageProcessResult(result_type=entities.ResultType.INTERRUPT, new_query=query)
|
||||
|
||||
@@ -16,10 +16,7 @@ class AtBotRule(rule_model.GroupRespondRule):
|
||||
rule_dict: dict,
|
||||
query: core_entities.Query,
|
||||
) -> entities.RuleJudgeResult:
|
||||
if (
|
||||
message_chain.has(platform_message.At(query.adapter.bot_account_id))
|
||||
and rule_dict['at']
|
||||
):
|
||||
if message_chain.has(platform_message.At(query.adapter.bot_account_id)) and rule_dict['at']:
|
||||
message_chain.remove(platform_message.At(query.adapter.bot_account_id))
|
||||
|
||||
if message_chain.has(
|
||||
|
||||
@@ -18,6 +18,4 @@ class RandomRespRule(rule_model.GroupRespondRule):
|
||||
) -> entities.RuleJudgeResult:
|
||||
random_rate = rule_dict['random']
|
||||
|
||||
return entities.RuleJudgeResult(
|
||||
matching=random.random() < random_rate, replacement=message_chain
|
||||
)
|
||||
return entities.RuleJudgeResult(matching=random.random() < random_rate, replacement=message_chain)
|
||||
|
||||
@@ -34,29 +34,19 @@ class ResponseWrapper(stage.PipelineStage):
|
||||
if isinstance(query.resp_messages[-1], platform_message.MessageChain):
|
||||
query.resp_message_chain.append(query.resp_messages[-1])
|
||||
|
||||
yield entities.StageProcessResult(
|
||||
result_type=entities.ResultType.CONTINUE, new_query=query
|
||||
)
|
||||
yield entities.StageProcessResult(result_type=entities.ResultType.CONTINUE, new_query=query)
|
||||
|
||||
else:
|
||||
if query.resp_messages[-1].role == 'command':
|
||||
query.resp_message_chain.append(
|
||||
query.resp_messages[-1].get_content_platform_message_chain(
|
||||
prefix_text='[bot] '
|
||||
)
|
||||
query.resp_messages[-1].get_content_platform_message_chain(prefix_text='[bot] ')
|
||||
)
|
||||
|
||||
yield entities.StageProcessResult(
|
||||
result_type=entities.ResultType.CONTINUE, new_query=query
|
||||
)
|
||||
yield entities.StageProcessResult(result_type=entities.ResultType.CONTINUE, new_query=query)
|
||||
elif query.resp_messages[-1].role == 'plugin':
|
||||
query.resp_message_chain.append(
|
||||
query.resp_messages[-1].get_content_platform_message_chain()
|
||||
)
|
||||
query.resp_message_chain.append(query.resp_messages[-1].get_content_platform_message_chain())
|
||||
|
||||
yield entities.StageProcessResult(
|
||||
result_type=entities.ResultType.CONTINUE, new_query=query
|
||||
)
|
||||
yield entities.StageProcessResult(result_type=entities.ResultType.CONTINUE, new_query=query)
|
||||
else:
|
||||
if query.resp_messages[-1].role == 'assistant':
|
||||
result = query.resp_messages[-1]
|
||||
@@ -77,9 +67,7 @@ class ResponseWrapper(stage.PipelineStage):
|
||||
prefix='',
|
||||
response_text=reply_text,
|
||||
finish_reason='stop',
|
||||
funcs_called=[
|
||||
fc.function.name for fc in result.tool_calls
|
||||
]
|
||||
funcs_called=[fc.function.name for fc in result.tool_calls]
|
||||
if result.tool_calls is not None
|
||||
else [],
|
||||
query=query,
|
||||
@@ -92,36 +80,26 @@ class ResponseWrapper(stage.PipelineStage):
|
||||
)
|
||||
else:
|
||||
if event_ctx.event.reply is not None:
|
||||
query.resp_message_chain.append(
|
||||
platform_message.MessageChain(event_ctx.event.reply)
|
||||
)
|
||||
query.resp_message_chain.append(platform_message.MessageChain(event_ctx.event.reply))
|
||||
|
||||
else:
|
||||
query.resp_message_chain.append(
|
||||
result.get_content_platform_message_chain()
|
||||
)
|
||||
query.resp_message_chain.append(result.get_content_platform_message_chain())
|
||||
|
||||
yield entities.StageProcessResult(
|
||||
result_type=entities.ResultType.CONTINUE,
|
||||
new_query=query,
|
||||
)
|
||||
|
||||
if (
|
||||
result.tool_calls is not None and len(result.tool_calls) > 0
|
||||
): # 有函数调用
|
||||
if result.tool_calls is not None and len(result.tool_calls) > 0: # 有函数调用
|
||||
function_names = [tc.function.name for tc in result.tool_calls]
|
||||
|
||||
reply_text = f'调用函数 {".".join(function_names)}...'
|
||||
|
||||
query.resp_message_chain.append(
|
||||
platform_message.MessageChain(
|
||||
[platform_message.Plain(reply_text)]
|
||||
)
|
||||
platform_message.MessageChain([platform_message.Plain(reply_text)])
|
||||
)
|
||||
|
||||
if query.pipeline_config['output']['misc'][
|
||||
'track-function-calls'
|
||||
]:
|
||||
if query.pipeline_config['output']['misc']['track-function-calls']:
|
||||
event_ctx = await self.ap.plugin_mgr.emit_event(
|
||||
event=events.NormalMessageResponded(
|
||||
launcher_type=query.launcher_type.value,
|
||||
@@ -131,9 +109,7 @@ class ResponseWrapper(stage.PipelineStage):
|
||||
prefix='',
|
||||
response_text=reply_text,
|
||||
finish_reason='stop',
|
||||
funcs_called=[
|
||||
fc.function.name for fc in result.tool_calls
|
||||
]
|
||||
funcs_called=[fc.function.name for fc in result.tool_calls]
|
||||
if result.tool_calls is not None
|
||||
else [],
|
||||
query=query,
|
||||
@@ -148,16 +124,12 @@ class ResponseWrapper(stage.PipelineStage):
|
||||
else:
|
||||
if event_ctx.event.reply is not None:
|
||||
query.resp_message_chain.append(
|
||||
platform_message.MessageChain(
|
||||
event_ctx.event.reply
|
||||
)
|
||||
platform_message.MessageChain(event_ctx.event.reply)
|
||||
)
|
||||
|
||||
else:
|
||||
query.resp_message_chain.append(
|
||||
platform_message.MessageChain(
|
||||
[platform_message.Plain(reply_text)]
|
||||
)
|
||||
platform_message.MessageChain([platform_message.Plain(reply_text)])
|
||||
)
|
||||
|
||||
yield entities.StageProcessResult(
|
||||
|
||||
Reference in New Issue
Block a user