mirror of
https://github.com/langbot-app/LangBot.git
synced 2026-06-12 16:56:02 +00:00
style: introduce ruff as linter and formatter (#1356)
* style: remove necessary imports * style: fix F841 * style: fix F401 * style: fix F811 * style: fix E402 * style: fix E721 * style: fix E722 * style: fix E722 * style: fix F541 * style: ruff format * style: all passed * style: add ruff in deps * style: more ignores in ruff.toml * style: add pre-commit
This commit is contained in:
committed by
GitHub
parent
09e70d70e9
commit
209f16af76
@@ -1,15 +1,13 @@
|
||||
from __future__ import annotations
|
||||
import re
|
||||
|
||||
from .. import stage, entities
|
||||
from ...core import entities as core_entities
|
||||
from ...config import manager as cfg_mgr
|
||||
|
||||
|
||||
@stage.stage_class('BanSessionCheckStage')
|
||||
class BanSessionCheckStage(stage.PipelineStage):
|
||||
"""访问控制处理阶段
|
||||
|
||||
|
||||
仅检查query中群号或个人号是否在访问控制列表中。
|
||||
"""
|
||||
|
||||
@@ -17,26 +15,24 @@ class BanSessionCheckStage(stage.PipelineStage):
|
||||
pass
|
||||
|
||||
async def process(
|
||||
self,
|
||||
query: core_entities.Query,
|
||||
stage_inst_name: str
|
||||
self, query: core_entities.Query, stage_inst_name: str
|
||||
) -> entities.StageProcessResult:
|
||||
|
||||
found = False
|
||||
|
||||
mode = query.pipeline_config['trigger']['access-control']['mode']
|
||||
|
||||
sess_list = query.pipeline_config['trigger']['access-control'][mode]
|
||||
|
||||
if (query.launcher_type.value == 'group' and 'group_*' in sess_list) \
|
||||
or (query.launcher_type.value == 'person' and 'person_*' in sess_list):
|
||||
if (query.launcher_type.value == 'group' and 'group_*' in sess_list) or (
|
||||
query.launcher_type.value == 'person' and 'person_*' in sess_list
|
||||
):
|
||||
found = True
|
||||
else:
|
||||
for sess in sess_list:
|
||||
if sess == f"{query.launcher_type.value}_{query.launcher_id}":
|
||||
if sess == f'{query.launcher_type.value}_{query.launcher_id}':
|
||||
found = True
|
||||
break
|
||||
|
||||
|
||||
ctn = False
|
||||
|
||||
if mode == 'whitelist':
|
||||
@@ -45,7 +41,11 @@ class BanSessionCheckStage(stage.PipelineStage):
|
||||
ctn = not found
|
||||
|
||||
return entities.StageProcessResult(
|
||||
result_type=entities.ResultType.CONTINUE if ctn else entities.ResultType.INTERRUPT,
|
||||
result_type=entities.ResultType.CONTINUE
|
||||
if ctn
|
||||
else entities.ResultType.INTERRUPT,
|
||||
new_query=query,
|
||||
console_notice=f'根据访问控制忽略消息: {query.launcher_type.value}_{query.launcher_id}' if not ctn else ''
|
||||
console_notice=f'根据访问控制忽略消息: {query.launcher_type.value}_{query.launcher_id}'
|
||||
if not ctn
|
||||
else '',
|
||||
)
|
||||
|
||||
@@ -4,20 +4,21 @@ from ...core import app
|
||||
|
||||
from .. import stage, entities
|
||||
from ...core import entities as core_entities
|
||||
from ...config import manager as cfg_mgr
|
||||
from . import filter as filter_model, entities as filter_entities
|
||||
from .filters import cntignore, banwords, baiduexamine
|
||||
from ...provider import entities as llm_entities
|
||||
from ...platform.types import message as platform_message
|
||||
from ...platform.types import events as platform_events
|
||||
from ...platform.types import entities as platform_entities
|
||||
from ...utils import importutil
|
||||
|
||||
from . import filters
|
||||
|
||||
importutil.import_modules_in_pkg(filters)
|
||||
|
||||
|
||||
@stage.stage_class('PostContentFilterStage')
|
||||
@stage.stage_class('PreContentFilterStage')
|
||||
class ContentFilterStage(stage.PipelineStage):
|
||||
"""内容过滤阶段
|
||||
|
||||
|
||||
前置:
|
||||
检查消息是否符合规则,不符合则拦截。
|
||||
改写:
|
||||
@@ -36,13 +37,12 @@ class ContentFilterStage(stage.PipelineStage):
|
||||
super().__init__(ap)
|
||||
|
||||
async def initialize(self, pipeline_config: dict):
|
||||
|
||||
filters_required = [
|
||||
"content-ignore",
|
||||
'content-ignore',
|
||||
]
|
||||
|
||||
if pipeline_config['safety']['content-filter']['check-sensitive-words']:
|
||||
filters_required.append("ban-word-filter")
|
||||
filters_required.append('ban-word-filter')
|
||||
|
||||
# TODO revert it
|
||||
# if self.ap.pipeline_cfg.data['baidu-cloud-examine']['enable']:
|
||||
@@ -50,9 +50,7 @@ class ContentFilterStage(stage.PipelineStage):
|
||||
|
||||
for filter in filter_model.preregistered_filters:
|
||||
if filter.name in filters_required:
|
||||
self.filter_chain.append(
|
||||
filter(self.ap)
|
||||
)
|
||||
self.filter_chain.append(filter(self.ap))
|
||||
|
||||
for filter in self.filter_chain:
|
||||
await filter.initialize()
|
||||
@@ -68,8 +66,7 @@ class ContentFilterStage(stage.PipelineStage):
|
||||
|
||||
if query.pipeline_config['safety']['content-filter']['scope'] == 'output-msg':
|
||||
return entities.StageProcessResult(
|
||||
result_type=entities.ResultType.CONTINUE,
|
||||
new_query=query
|
||||
result_type=entities.ResultType.CONTINUE, new_query=query
|
||||
)
|
||||
else:
|
||||
for filter in self.filter_chain:
|
||||
@@ -78,26 +75,25 @@ class ContentFilterStage(stage.PipelineStage):
|
||||
|
||||
if result.level in [
|
||||
filter_entities.ResultLevel.BLOCK,
|
||||
filter_entities.ResultLevel.MASKED
|
||||
filter_entities.ResultLevel.MASKED,
|
||||
]:
|
||||
return entities.StageProcessResult(
|
||||
result_type=entities.ResultType.INTERRUPT,
|
||||
new_query=query,
|
||||
user_notice=result.user_notice,
|
||||
console_notice=result.console_notice
|
||||
console_notice=result.console_notice,
|
||||
)
|
||||
elif result.level == filter_entities.ResultLevel.PASS: # 传到下一个
|
||||
message = result.replacement
|
||||
|
||||
|
||||
query.message_chain = platform_message.MessageChain(
|
||||
platform_message.Plain(message)
|
||||
)
|
||||
|
||||
return entities.StageProcessResult(
|
||||
result_type=entities.ResultType.CONTINUE,
|
||||
new_query=query
|
||||
result_type=entities.ResultType.CONTINUE, new_query=query
|
||||
)
|
||||
|
||||
|
||||
async def _post_process(
|
||||
self,
|
||||
message: str,
|
||||
@@ -108,8 +104,7 @@ class ContentFilterStage(stage.PipelineStage):
|
||||
"""
|
||||
if query.pipeline_config['safety']['content-filter']['scope'] == 'income-msg':
|
||||
return entities.StageProcessResult(
|
||||
result_type=entities.ResultType.CONTINUE,
|
||||
new_query=query
|
||||
result_type=entities.ResultType.CONTINUE, new_query=query
|
||||
)
|
||||
else:
|
||||
message = message.strip()
|
||||
@@ -122,30 +117,25 @@ class ContentFilterStage(stage.PipelineStage):
|
||||
result_type=entities.ResultType.INTERRUPT,
|
||||
new_query=query,
|
||||
user_notice=result.user_notice,
|
||||
console_notice=result.console_notice
|
||||
console_notice=result.console_notice,
|
||||
)
|
||||
elif result.level in [
|
||||
filter_entities.ResultLevel.PASS,
|
||||
filter_entities.ResultLevel.MASKED
|
||||
filter_entities.ResultLevel.MASKED,
|
||||
]:
|
||||
message = result.replacement
|
||||
|
||||
query.resp_messages[-1].content = message
|
||||
|
||||
return entities.StageProcessResult(
|
||||
result_type=entities.ResultType.CONTINUE,
|
||||
new_query=query
|
||||
result_type=entities.ResultType.CONTINUE, new_query=query
|
||||
)
|
||||
|
||||
async def process(
|
||||
self,
|
||||
query: core_entities.Query,
|
||||
stage_inst_name: str
|
||||
self, query: core_entities.Query, stage_inst_name: str
|
||||
) -> entities.StageProcessResult:
|
||||
"""处理
|
||||
"""
|
||||
"""处理"""
|
||||
if stage_inst_name == 'PreContentFilterStage':
|
||||
|
||||
contain_non_text = False
|
||||
|
||||
text_components = [platform_message.Plain, platform_message.Source]
|
||||
@@ -156,28 +146,24 @@ class ContentFilterStage(stage.PipelineStage):
|
||||
break
|
||||
|
||||
if contain_non_text:
|
||||
self.ap.logger.debug(f"消息中包含非文本消息,跳过内容过滤器检查。")
|
||||
self.ap.logger.debug('消息中包含非文本消息,跳过内容过滤器检查。')
|
||||
return entities.StageProcessResult(
|
||||
result_type=entities.ResultType.CONTINUE,
|
||||
new_query=query
|
||||
result_type=entities.ResultType.CONTINUE, new_query=query
|
||||
)
|
||||
|
||||
return await self._pre_process(
|
||||
str(query.message_chain).strip(),
|
||||
query
|
||||
)
|
||||
return await self._pre_process(str(query.message_chain).strip(), query)
|
||||
elif stage_inst_name == 'PostContentFilterStage':
|
||||
# 仅处理 query.resp_messages[-1].content 是 str 的情况
|
||||
if isinstance(query.resp_messages[-1], llm_entities.Message) and isinstance(query.resp_messages[-1].content, str):
|
||||
return await self._post_process(
|
||||
query.resp_messages[-1].content,
|
||||
query
|
||||
)
|
||||
if isinstance(query.resp_messages[-1], llm_entities.Message) and isinstance(
|
||||
query.resp_messages[-1].content, str
|
||||
):
|
||||
return await self._post_process(query.resp_messages[-1].content, query)
|
||||
else:
|
||||
self.ap.logger.debug(f"resp_messages[-1] 不是 Message 类型或 query.resp_messages[-1].content 不是 str 类型,跳过内容过滤器检查。")
|
||||
self.ap.logger.debug(
|
||||
'resp_messages[-1] 不是 Message 类型或 query.resp_messages[-1].content 不是 str 类型,跳过内容过滤器检查。'
|
||||
)
|
||||
return entities.StageProcessResult(
|
||||
result_type=entities.ResultType.CONTINUE,
|
||||
new_query=query
|
||||
result_type=entities.ResultType.CONTINUE, new_query=query
|
||||
)
|
||||
else:
|
||||
raise ValueError(f'未知的 stage_inst_name: {stage_inst_name}')
|
||||
|
||||
@@ -1,14 +1,11 @@
|
||||
|
||||
import typing
|
||||
import enum
|
||||
|
||||
import pydantic.v1 as pydantic
|
||||
|
||||
from ...provider import entities as llm_entities
|
||||
|
||||
|
||||
class ResultLevel(enum.Enum):
|
||||
"""结果等级"""
|
||||
|
||||
PASS = enum.auto()
|
||||
"""通过"""
|
||||
|
||||
@@ -24,6 +21,7 @@ class ResultLevel(enum.Enum):
|
||||
|
||||
class EnableStage(enum.Enum):
|
||||
"""启用阶段"""
|
||||
|
||||
PRE = enum.auto()
|
||||
"""预处理"""
|
||||
|
||||
@@ -55,14 +53,15 @@ class FilterResult(pydantic.BaseModel):
|
||||
|
||||
class ManagerResultLevel(enum.Enum):
|
||||
"""处理器结果等级"""
|
||||
|
||||
CONTINUE = enum.auto()
|
||||
"""继续"""
|
||||
|
||||
INTERRUPT = enum.auto()
|
||||
"""中断"""
|
||||
|
||||
class FilterManagerResult(pydantic.BaseModel):
|
||||
|
||||
class FilterManagerResult(pydantic.BaseModel):
|
||||
level: ManagerResultLevel
|
||||
|
||||
replacement: str
|
||||
|
||||
@@ -5,14 +5,13 @@ import typing
|
||||
|
||||
from ...core import app, entities as core_entities
|
||||
from . import entities
|
||||
from ...provider import entities as llm_entities
|
||||
|
||||
|
||||
preregistered_filters: list[typing.Type[ContentFilter]] = []
|
||||
|
||||
|
||||
def filter_class(
|
||||
name: str
|
||||
name: str,
|
||||
) -> typing.Callable[[typing.Type[ContentFilter]], typing.Type[ContentFilter]]:
|
||||
"""内容过滤器类装饰器
|
||||
|
||||
@@ -22,6 +21,7 @@ def filter_class(
|
||||
Returns:
|
||||
typing.Callable[[typing.Type[ContentFilter]], typing.Type[ContentFilter]]: 装饰器
|
||||
"""
|
||||
|
||||
def decorator(cls: typing.Type[ContentFilter]) -> typing.Type[ContentFilter]:
|
||||
assert issubclass(cls, ContentFilter)
|
||||
|
||||
@@ -53,23 +53,21 @@ class ContentFilter(metaclass=abc.ABCMeta):
|
||||
entity.EnableStage.PRE: 消息请求AI前,此时需要检查的内容是用户的输入消息。
|
||||
entity.EnableStage.POST: 消息请求AI后,此时需要检查的内容是AI的回复消息。
|
||||
"""
|
||||
return [
|
||||
entities.EnableStage.PRE,
|
||||
entities.EnableStage.POST
|
||||
]
|
||||
return [entities.EnableStage.PRE, entities.EnableStage.POST]
|
||||
|
||||
async def initialize(self):
|
||||
"""初始化过滤器
|
||||
"""
|
||||
"""初始化过滤器"""
|
||||
pass
|
||||
|
||||
@abc.abstractmethod
|
||||
async def process(self, query: core_entities.Query, message: str=None, image_url=None) -> entities.FilterResult:
|
||||
async def process(
|
||||
self, query: core_entities.Query, message: str = None, image_url=None
|
||||
) -> entities.FilterResult:
|
||||
"""处理消息
|
||||
|
||||
分为前后阶段,具体取决于 enable_stages 的值。
|
||||
对于内容过滤器来说,不需要考虑消息所处的阶段,只需要检查消息内容即可。
|
||||
|
||||
|
||||
Args:
|
||||
message (str): 需要检查的内容
|
||||
image_url (str): 要检查的图片的 URL
|
||||
|
||||
@@ -7,11 +7,11 @@ from .. import filter as filter_model
|
||||
from ....core import entities as core_entities
|
||||
|
||||
|
||||
BAIDU_EXAMINE_URL = "https://aip.baidubce.com/rest/2.0/solution/v1/text_censor/v2/user_defined?access_token={}"
|
||||
BAIDU_EXAMINE_TOKEN_URL = "https://aip.baidubce.com/oauth/2.0/token"
|
||||
BAIDU_EXAMINE_URL = 'https://aip.baidubce.com/rest/2.0/solution/v1/text_censor/v2/user_defined?access_token={}'
|
||||
BAIDU_EXAMINE_TOKEN_URL = 'https://aip.baidubce.com/oauth/2.0/token'
|
||||
|
||||
|
||||
@filter_model.filter_class("baidu-cloud-examine")
|
||||
@filter_model.filter_class('baidu-cloud-examine')
|
||||
class BaiduCloudExamine(filter_model.ContentFilter):
|
||||
"""百度云内容审核"""
|
||||
|
||||
@@ -20,44 +20,52 @@ class BaiduCloudExamine(filter_model.ContentFilter):
|
||||
async with session.post(
|
||||
BAIDU_EXAMINE_TOKEN_URL,
|
||||
params={
|
||||
"grant_type": "client_credentials",
|
||||
"client_id": self.ap.pipeline_cfg.data['baidu-cloud-examine']['api-key'],
|
||||
"client_secret": self.ap.pipeline_cfg.data['baidu-cloud-examine']['api-secret']
|
||||
}
|
||||
'grant_type': 'client_credentials',
|
||||
'client_id': self.ap.pipeline_cfg.data['baidu-cloud-examine'][
|
||||
'api-key'
|
||||
],
|
||||
'client_secret': self.ap.pipeline_cfg.data['baidu-cloud-examine'][
|
||||
'api-secret'
|
||||
],
|
||||
},
|
||||
) as resp:
|
||||
return (await resp.json())['access_token']
|
||||
|
||||
async def process(self, query: core_entities.Query, message: str) -> entities.FilterResult:
|
||||
|
||||
async def process(
|
||||
self, query: core_entities.Query, message: str
|
||||
) -> entities.FilterResult:
|
||||
async with aiohttp.ClientSession() as session:
|
||||
async with session.post(
|
||||
BAIDU_EXAMINE_URL.format(await self._get_token()),
|
||||
headers={'Content-Type': 'application/x-www-form-urlencoded', 'Accept': 'application/json'},
|
||||
data=f"text={message}".encode('utf-8')
|
||||
headers={
|
||||
'Content-Type': 'application/x-www-form-urlencoded',
|
||||
'Accept': 'application/json',
|
||||
},
|
||||
data=f'text={message}'.encode('utf-8'),
|
||||
) as resp:
|
||||
result = await resp.json()
|
||||
|
||||
if "error_code" in result:
|
||||
if 'error_code' in result:
|
||||
return entities.FilterResult(
|
||||
level=entities.ResultLevel.BLOCK,
|
||||
replacement=message,
|
||||
user_notice='',
|
||||
console_notice=f"百度云判定出错,错误信息:{result['error_msg']}"
|
||||
console_notice=f'百度云判定出错,错误信息:{result["error_msg"]}',
|
||||
)
|
||||
else:
|
||||
conclusion = result["conclusion"]
|
||||
conclusion = result['conclusion']
|
||||
|
||||
if conclusion in ("合规"):
|
||||
if conclusion in ('合规'):
|
||||
return entities.FilterResult(
|
||||
level=entities.ResultLevel.PASS,
|
||||
replacement=message,
|
||||
user_notice='',
|
||||
console_notice=f"百度云判定结果:{conclusion}"
|
||||
console_notice=f'百度云判定结果:{conclusion}',
|
||||
)
|
||||
else:
|
||||
return entities.FilterResult(
|
||||
level=entities.ResultLevel.BLOCK,
|
||||
replacement=message,
|
||||
user_notice="消息中存在不合适的内容, 请修改",
|
||||
console_notice=f"百度云判定结果:{conclusion}"
|
||||
user_notice='消息中存在不合适的内容, 请修改',
|
||||
console_notice=f'百度云判定结果:{conclusion}',
|
||||
)
|
||||
|
||||
@@ -6,14 +6,16 @@ from .. import entities
|
||||
from ....core import entities as core_entities
|
||||
|
||||
|
||||
@filter_model.filter_class("ban-word-filter")
|
||||
@filter_model.filter_class('ban-word-filter')
|
||||
class BanWordFilter(filter_model.ContentFilter):
|
||||
"""根据内容过滤"""
|
||||
|
||||
async def initialize(self):
|
||||
pass
|
||||
|
||||
async def process(self, query: core_entities.Query, message: str) -> entities.FilterResult:
|
||||
async def process(
|
||||
self, query: core_entities.Query, message: str
|
||||
) -> entities.FilterResult:
|
||||
found = False
|
||||
|
||||
for word in self.ap.sensitive_meta.data['words']:
|
||||
@@ -23,9 +25,10 @@ class BanWordFilter(filter_model.ContentFilter):
|
||||
found = True
|
||||
|
||||
for i in range(len(match)):
|
||||
if self.ap.sensitive_meta.data['mask_word'] == "":
|
||||
if self.ap.sensitive_meta.data['mask_word'] == '':
|
||||
message = message.replace(
|
||||
match[i], self.ap.sensitive_meta.data['mask'] * len(match[i])
|
||||
match[i],
|
||||
self.ap.sensitive_meta.data['mask'] * len(match[i]),
|
||||
)
|
||||
else:
|
||||
message = message.replace(
|
||||
@@ -36,5 +39,5 @@ class BanWordFilter(filter_model.ContentFilter):
|
||||
level=entities.ResultLevel.MASKED if found else entities.ResultLevel.PASS,
|
||||
replacement=message,
|
||||
user_notice='消息中存在不合适的内容, 请修改' if found else '',
|
||||
console_notice=''
|
||||
)
|
||||
console_notice='',
|
||||
)
|
||||
|
||||
@@ -6,7 +6,7 @@ from .. import filter as filter_model
|
||||
from ....core import entities as core_entities
|
||||
|
||||
|
||||
@filter_model.filter_class("content-ignore")
|
||||
@filter_model.filter_class('content-ignore')
|
||||
class ContentIgnore(filter_model.ContentFilter):
|
||||
"""根据内容忽略消息"""
|
||||
|
||||
@@ -16,7 +16,9 @@ class ContentIgnore(filter_model.ContentFilter):
|
||||
entities.EnableStage.PRE,
|
||||
]
|
||||
|
||||
async def process(self, query: core_entities.Query, message: str) -> entities.FilterResult:
|
||||
async def process(
|
||||
self, query: core_entities.Query, message: str
|
||||
) -> entities.FilterResult:
|
||||
if 'prefix' in query.pipeline_config['trigger']['ignore-rules']:
|
||||
for rule in query.pipeline_config['trigger']['ignore-rules']['prefix']:
|
||||
if message.startswith(rule):
|
||||
@@ -24,9 +26,9 @@ class ContentIgnore(filter_model.ContentFilter):
|
||||
level=entities.ResultLevel.BLOCK,
|
||||
replacement='',
|
||||
user_notice='',
|
||||
console_notice='根据 ignore_rules 中的 prefix 规则,忽略消息'
|
||||
console_notice='根据 ignore_rules 中的 prefix 规则,忽略消息',
|
||||
)
|
||||
|
||||
|
||||
if 'regexp' in query.pipeline_config['trigger']['ignore-rules']:
|
||||
for rule in query.pipeline_config['trigger']['ignore-rules']['regexp']:
|
||||
if re.search(rule, message):
|
||||
@@ -34,12 +36,12 @@ class ContentIgnore(filter_model.ContentFilter):
|
||||
level=entities.ResultLevel.BLOCK,
|
||||
replacement='',
|
||||
user_notice='',
|
||||
console_notice='根据 ignore_rules 中的 regexp 规则,忽略消息'
|
||||
console_notice='根据 ignore_rules 中的 regexp 规则,忽略消息',
|
||||
)
|
||||
|
||||
return entities.FilterResult(
|
||||
level=entities.ResultLevel.PASS,
|
||||
replacement=message,
|
||||
user_notice='',
|
||||
console_notice=''
|
||||
)
|
||||
console_notice='',
|
||||
)
|
||||
|
||||
@@ -1,18 +1,14 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import typing
|
||||
import traceback
|
||||
|
||||
from ..core import app, entities
|
||||
from . import entities as pipeline_entities
|
||||
from ..plugin import events
|
||||
from ..platform.types import message as platform_message
|
||||
|
||||
|
||||
class Controller:
|
||||
"""总控制器
|
||||
"""
|
||||
"""总控制器"""
|
||||
|
||||
ap: app.Application
|
||||
|
||||
semaphore: asyncio.Semaphore = None
|
||||
@@ -20,11 +16,12 @@ class Controller:
|
||||
|
||||
def __init__(self, ap: app.Application):
|
||||
self.ap = ap
|
||||
self.semaphore = asyncio.Semaphore(self.ap.instance_config.data['concurrency']['pipeline'])
|
||||
self.semaphore = asyncio.Semaphore(
|
||||
self.ap.instance_config.data['concurrency']['pipeline']
|
||||
)
|
||||
|
||||
async def consumer(self):
|
||||
"""事件处理循环
|
||||
"""
|
||||
"""事件处理循环"""
|
||||
try:
|
||||
while True:
|
||||
selected_query: entities.Query = None
|
||||
@@ -35,7 +32,9 @@ class Controller:
|
||||
|
||||
for query in queries:
|
||||
session = await self.ap.sess_mgr.get_session(query)
|
||||
self.ap.logger.debug(f"Checking query {query} session {session}")
|
||||
self.ap.logger.debug(
|
||||
f'Checking query {query} session {session}'
|
||||
)
|
||||
|
||||
if not session.semaphore.locked():
|
||||
selected_query = query
|
||||
@@ -56,30 +55,40 @@ class Controller:
|
||||
# find pipeline
|
||||
# Here firstly find the bot, then find the pipeline, in case the bot adapter's config is not the latest one.
|
||||
# Like aiocqhttp, once a client is connected, even the adapter was updated and restarted, the existing client connection will not be affected.
|
||||
bot = await self.ap.platform_mgr.get_bot_by_uuid(selected_query.bot_uuid)
|
||||
bot = await self.ap.platform_mgr.get_bot_by_uuid(
|
||||
selected_query.bot_uuid
|
||||
)
|
||||
if bot:
|
||||
pipeline = await self.ap.pipeline_mgr.get_pipeline_by_uuid(bot.bot_entity.use_pipeline_uuid)
|
||||
pipeline = (
|
||||
await self.ap.pipeline_mgr.get_pipeline_by_uuid(
|
||||
bot.bot_entity.use_pipeline_uuid
|
||||
)
|
||||
)
|
||||
if pipeline:
|
||||
await pipeline.run(selected_query)
|
||||
|
||||
|
||||
async with self.ap.query_pool:
|
||||
(await self.ap.sess_mgr.get_session(selected_query)).semaphore.release()
|
||||
(
|
||||
await self.ap.sess_mgr.get_session(selected_query)
|
||||
).semaphore.release()
|
||||
# 通知其他协程,有新的请求可以处理了
|
||||
self.ap.query_pool.condition.notify_all()
|
||||
|
||||
self.ap.task_mgr.create_task(
|
||||
_process_query(selected_query),
|
||||
kind="query",
|
||||
name=f"query-{selected_query.query_id}",
|
||||
scopes=[entities.LifecycleControlScope.APPLICATION, entities.LifecycleControlScope.PLATFORM],
|
||||
kind='query',
|
||||
name=f'query-{selected_query.query_id}',
|
||||
scopes=[
|
||||
entities.LifecycleControlScope.APPLICATION,
|
||||
entities.LifecycleControlScope.PLATFORM,
|
||||
],
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
# traceback.print_exc()
|
||||
self.ap.logger.error(f"控制器循环出错: {e}")
|
||||
self.ap.logger.error(f"Traceback: {traceback.format_exc()}")
|
||||
self.ap.logger.error(f'控制器循环出错: {e}')
|
||||
self.ap.logger.error(f'Traceback: {traceback.format_exc()}')
|
||||
|
||||
async def run(self):
|
||||
"""运行控制器
|
||||
"""
|
||||
"""运行控制器"""
|
||||
await self.consumer()
|
||||
|
||||
@@ -10,7 +10,6 @@ from ..core import entities
|
||||
|
||||
|
||||
class ResultType(enum.Enum):
|
||||
|
||||
CONTINUE = enum.auto()
|
||||
"""继续流水线"""
|
||||
|
||||
@@ -19,12 +18,18 @@ class ResultType(enum.Enum):
|
||||
|
||||
|
||||
class StageProcessResult(pydantic.BaseModel):
|
||||
|
||||
result_type: ResultType
|
||||
|
||||
new_query: entities.Query
|
||||
|
||||
user_notice: typing.Optional[typing.Union[str, list[platform_message.MessageComponent], platform_message.MessageChain, None]] = []
|
||||
user_notice: typing.Optional[
|
||||
typing.Union[
|
||||
str,
|
||||
list[platform_message.MessageComponent],
|
||||
platform_message.MessageChain,
|
||||
None,
|
||||
]
|
||||
] = []
|
||||
"""只要设置了就会发送给用户"""
|
||||
|
||||
console_notice: typing.Optional[str] = ''
|
||||
|
||||
@@ -2,18 +2,19 @@ from __future__ import annotations
|
||||
import os
|
||||
import traceback
|
||||
|
||||
from PIL import Image, ImageDraw, ImageFont
|
||||
|
||||
from ...core import app
|
||||
from . import strategy
|
||||
from .strategies import image, forward
|
||||
from .. import stage, entities
|
||||
from ...core import entities as core_entities
|
||||
from ...config import manager as cfg_mgr
|
||||
from ...platform.types import message as platform_message
|
||||
from ...utils import importutil
|
||||
|
||||
from . import strategies
|
||||
|
||||
importutil.import_modules_in_pkg(strategies)
|
||||
|
||||
|
||||
@stage.stage_class("LongTextProcessStage")
|
||||
@stage.stage_class('LongTextProcessStage')
|
||||
class LongTextProcessStage(stage.PipelineStage):
|
||||
"""长消息处理阶段
|
||||
|
||||
@@ -31,34 +32,48 @@ class LongTextProcessStage(stage.PipelineStage):
|
||||
# 检查是否存在
|
||||
if not os.path.exists(use_font):
|
||||
# 若是windows系统,使用微软雅黑
|
||||
if os.name == "nt":
|
||||
use_font = "C:/Windows/Fonts/msyh.ttc"
|
||||
if os.name == 'nt':
|
||||
use_font = 'C:/Windows/Fonts/msyh.ttc'
|
||||
if not os.path.exists(use_font):
|
||||
self.ap.logger.warn("未找到字体文件,且无法使用Windows自带字体,更换为转发消息组件以发送长消息,您可以在配置文件中调整相关设置。")
|
||||
config['blob_message_strategy'] = "forward"
|
||||
self.ap.logger.warn(
|
||||
'未找到字体文件,且无法使用Windows自带字体,更换为转发消息组件以发送长消息,您可以在配置文件中调整相关设置。'
|
||||
)
|
||||
config['blob_message_strategy'] = 'forward'
|
||||
else:
|
||||
self.ap.logger.info("使用Windows自带字体:" + use_font)
|
||||
self.ap.logger.info('使用Windows自带字体:' + use_font)
|
||||
config['font-path'] = use_font
|
||||
else:
|
||||
self.ap.logger.warn("未找到字体文件,且无法使用系统自带字体,更换为转发消息组件以发送长消息,您可以在配置文件中调整相关设置。")
|
||||
self.ap.logger.warn(
|
||||
'未找到字体文件,且无法使用系统自带字体,更换为转发消息组件以发送长消息,您可以在配置文件中调整相关设置。'
|
||||
)
|
||||
|
||||
pipeline_config['output']['long-text-processing']['strategy'] = "forward"
|
||||
except:
|
||||
pipeline_config['output']['long-text-processing'][
|
||||
'strategy'
|
||||
] = 'forward'
|
||||
except Exception:
|
||||
traceback.print_exc()
|
||||
self.ap.logger.error("加载字体文件失败({}),更换为转发消息组件以发送长消息,您可以在配置文件中调整相关设置。".format(use_font))
|
||||
self.ap.logger.error(
|
||||
'加载字体文件失败({}),更换为转发消息组件以发送长消息,您可以在配置文件中调整相关设置。'.format(
|
||||
use_font
|
||||
)
|
||||
)
|
||||
|
||||
pipeline_config['output']['long-text-processing']['strategy'] = "forward"
|
||||
pipeline_config['output']['long-text-processing']['strategy'] = (
|
||||
'forward'
|
||||
)
|
||||
|
||||
for strategy_cls in strategy.preregistered_strategies:
|
||||
if strategy_cls.name == config['strategy']:
|
||||
self.strategy_impl = strategy_cls(self.ap)
|
||||
break
|
||||
else:
|
||||
raise ValueError(f"未找到名为 {config['strategy']} 的长消息处理策略")
|
||||
raise ValueError(f'未找到名为 {config["strategy"]} 的长消息处理策略')
|
||||
|
||||
await self.strategy_impl.initialize()
|
||||
|
||||
async def process(self, query: core_entities.Query, stage_inst_name: str) -> entities.StageProcessResult:
|
||||
|
||||
async def process(
|
||||
self, query: core_entities.Query, stage_inst_name: str
|
||||
) -> entities.StageProcessResult:
|
||||
# 检查是否包含非 Plain 组件
|
||||
contains_non_plain = False
|
||||
|
||||
@@ -66,13 +81,19 @@ class LongTextProcessStage(stage.PipelineStage):
|
||||
if not isinstance(msg, platform_message.Plain):
|
||||
contains_non_plain = True
|
||||
break
|
||||
|
||||
|
||||
if contains_non_plain:
|
||||
self.ap.logger.debug("消息中包含非 Plain 组件,跳过长消息处理。")
|
||||
elif len(str(query.resp_message_chain[-1])) > query.pipeline_config['output']['long-text-processing']['threshold']:
|
||||
query.resp_message_chain[-1] = platform_message.MessageChain(await self.strategy_impl.process(str(query.resp_message_chain[-1]), query))
|
||||
self.ap.logger.debug('消息中包含非 Plain 组件,跳过长消息处理。')
|
||||
elif (
|
||||
len(str(query.resp_message_chain[-1]))
|
||||
> query.pipeline_config['output']['long-text-processing']['threshold']
|
||||
):
|
||||
query.resp_message_chain[-1] = platform_message.MessageChain(
|
||||
await self.strategy_impl.process(
|
||||
str(query.resp_message_chain[-1]), query
|
||||
)
|
||||
)
|
||||
|
||||
return entities.StageProcessResult(
|
||||
result_type=entities.ResultType.CONTINUE,
|
||||
new_query=query
|
||||
result_type=entities.ResultType.CONTINUE, new_query=query
|
||||
)
|
||||
|
||||
@@ -1,8 +1,6 @@
|
||||
# 转发消息组件
|
||||
from __future__ import annotations
|
||||
import typing
|
||||
|
||||
import pydantic.v1 as pydantic
|
||||
|
||||
from .. import strategy as strategy_model
|
||||
from ....core import entities as core_entities
|
||||
@@ -13,29 +11,27 @@ ForwardMessageDiaplay = platform_message.ForwardMessageDiaplay
|
||||
Forward = platform_message.Forward
|
||||
|
||||
|
||||
@strategy_model.strategy_class("forward")
|
||||
@strategy_model.strategy_class('forward')
|
||||
class ForwardComponentStrategy(strategy_model.LongTextStrategy):
|
||||
|
||||
async def process(self, message: str, query: core_entities.Query) -> list[platform_message.MessageComponent]:
|
||||
async def process(
|
||||
self, message: str, query: core_entities.Query
|
||||
) -> list[platform_message.MessageComponent]:
|
||||
display = ForwardMessageDiaplay(
|
||||
title="群聊的聊天记录",
|
||||
brief="[聊天记录]",
|
||||
source="聊天记录",
|
||||
preview=["QQ用户: "+message],
|
||||
summary="查看1条转发消息"
|
||||
title='群聊的聊天记录',
|
||||
brief='[聊天记录]',
|
||||
source='聊天记录',
|
||||
preview=['QQ用户: ' + message],
|
||||
summary='查看1条转发消息',
|
||||
)
|
||||
|
||||
node_list = [
|
||||
platform_message.ForwardMessageNode(
|
||||
sender_id=query.adapter.bot_account_id,
|
||||
sender_name='QQ用户',
|
||||
message_chain=platform_message.MessageChain([message])
|
||||
message_chain=platform_message.MessageChain([message]),
|
||||
)
|
||||
]
|
||||
|
||||
forward = Forward(
|
||||
display=display,
|
||||
node_list=node_list
|
||||
)
|
||||
forward = Forward(display=display, node_list=node_list)
|
||||
|
||||
return [forward]
|
||||
|
||||
@@ -1,6 +1,5 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import typing
|
||||
import os
|
||||
import base64
|
||||
import time
|
||||
@@ -15,26 +14,30 @@ from .. import strategy as strategy_model
|
||||
from ....core import entities as core_entities
|
||||
|
||||
|
||||
@strategy_model.strategy_class("image")
|
||||
@strategy_model.strategy_class('image')
|
||||
class Text2ImageStrategy(strategy_model.LongTextStrategy):
|
||||
|
||||
async def initialize(self):
|
||||
pass
|
||||
|
||||
@functools.lru_cache(maxsize=16)
|
||||
def get_font(self, query: core_entities.Query):
|
||||
return ImageFont.truetype(query.pipeline_config['output']['long-text-processing']['font-path'], 32, encoding="utf-8")
|
||||
|
||||
async def process(self, message: str, query: core_entities.Query) -> list[platform_message.MessageComponent]:
|
||||
return ImageFont.truetype(
|
||||
query.pipeline_config['output']['long-text-processing']['font-path'],
|
||||
32,
|
||||
encoding='utf-8',
|
||||
)
|
||||
|
||||
async def process(
|
||||
self, message: str, query: core_entities.Query
|
||||
) -> list[platform_message.MessageComponent]:
|
||||
img_path = self.text_to_image(
|
||||
text_str=message,
|
||||
save_as='temp/{}.png'.format(int(time.time())),
|
||||
query=query
|
||||
query=query,
|
||||
)
|
||||
|
||||
compressed_path, size = self.compress_image(
|
||||
img_path,
|
||||
outfile="temp/{}_compressed.png".format(int(time.time()))
|
||||
img_path, outfile='temp/{}_compressed.png'.format(int(time.time()))
|
||||
)
|
||||
|
||||
with open(compressed_path, 'rb') as f:
|
||||
@@ -93,13 +96,11 @@ class Text2ImageStrategy(strategy_model.LongTextStrategy):
|
||||
resultIndex.append(v)
|
||||
return resultIndex
|
||||
|
||||
|
||||
def get_size(self, file):
|
||||
# 获取文件大小:KB
|
||||
size = os.path.getsize(file)
|
||||
return size / 1024
|
||||
|
||||
|
||||
def get_outfile(self, infile, outfile):
|
||||
if outfile:
|
||||
return outfile
|
||||
@@ -107,7 +108,6 @@ class Text2ImageStrategy(strategy_model.LongTextStrategy):
|
||||
outfile = '{}-out{}'.format(dir, suffix)
|
||||
return outfile
|
||||
|
||||
|
||||
def compress_image(self, infile, outfile='', kb=100, step=20, quality=90):
|
||||
"""不改变图片尺寸压缩到指定大小
|
||||
:param infile: 压缩源文件
|
||||
@@ -130,24 +130,28 @@ class Text2ImageStrategy(strategy_model.LongTextStrategy):
|
||||
o_size = self.get_size(outfile)
|
||||
return outfile, self.get_size(outfile)
|
||||
|
||||
def text_to_image(
|
||||
self,
|
||||
text_str: str,
|
||||
save_as='temp.png',
|
||||
width=800,
|
||||
query: core_entities.Query = None,
|
||||
):
|
||||
text_str = text_str.replace('\t', ' ')
|
||||
|
||||
def text_to_image(self, text_str: str, save_as="temp.png", width=800, query: core_entities.Query = None):
|
||||
|
||||
text_str = text_str.replace("\t", " ")
|
||||
|
||||
# 分行
|
||||
lines = text_str.split('\n')
|
||||
|
||||
# 计算并分割
|
||||
final_lines = []
|
||||
|
||||
text_width = width-80
|
||||
text_width = width - 80
|
||||
|
||||
self.ap.logger.debug("lines: {}, text_width: {}".format(lines, text_width))
|
||||
self.ap.logger.debug('lines: {}, text_width: {}'.format(lines, text_width))
|
||||
for line in lines:
|
||||
# 如果长了就分割
|
||||
line_width = self.get_font(query).getlength(line)
|
||||
self.ap.logger.debug("line_width: {}".format(line_width))
|
||||
self.ap.logger.debug('line_width: {}'.format(line_width))
|
||||
if line_width < text_width:
|
||||
final_lines.append(line)
|
||||
continue
|
||||
@@ -161,7 +165,10 @@ class Text2ImageStrategy(strategy_model.LongTextStrategy):
|
||||
numbers = self.indexNumber(rest_text)
|
||||
|
||||
for number in numbers:
|
||||
if number[1] < point < number[1] + len(number[0]) and number[1] != 0:
|
||||
if (
|
||||
number[1] < point < number[1] + len(number[0])
|
||||
and number[1] != 0
|
||||
):
|
||||
point = number[1]
|
||||
break
|
||||
|
||||
@@ -174,16 +181,23 @@ class Text2ImageStrategy(strategy_model.LongTextStrategy):
|
||||
else:
|
||||
continue
|
||||
# 准备画布
|
||||
img = Image.new('RGBA', (width, max(280, len(final_lines) * 35 + 65)), (255, 255, 255, 255))
|
||||
img = Image.new(
|
||||
'RGBA', (width, max(280, len(final_lines) * 35 + 65)), (255, 255, 255, 255)
|
||||
)
|
||||
draw = ImageDraw.Draw(img, mode='RGBA')
|
||||
|
||||
self.ap.logger.debug("正在绘制图片...")
|
||||
self.ap.logger.debug('正在绘制图片...')
|
||||
# 绘制正文
|
||||
line_number = 0
|
||||
offset_x = 20
|
||||
offset_y = 30
|
||||
for final_line in final_lines:
|
||||
draw.text((offset_x, offset_y + 35 * line_number), final_line, fill=(0, 0, 0), font=self.text_render_font)
|
||||
draw.text(
|
||||
(offset_x, offset_y + 35 * line_number),
|
||||
final_line,
|
||||
fill=(0, 0, 0),
|
||||
font=self.text_render_font,
|
||||
)
|
||||
# 遍历此行,检查是否有emoji
|
||||
idx_in_line = 0
|
||||
for ch in final_line:
|
||||
@@ -196,7 +210,7 @@ class Text2ImageStrategy(strategy_model.LongTextStrategy):
|
||||
|
||||
line_number += 1
|
||||
|
||||
self.ap.logger.debug("正在保存图片...")
|
||||
self.ap.logger.debug('正在保存图片...')
|
||||
img.save(save_as)
|
||||
|
||||
return save_as
|
||||
|
||||
@@ -12,7 +12,7 @@ preregistered_strategies: list[typing.Type[LongTextStrategy]] = []
|
||||
|
||||
|
||||
def strategy_class(
|
||||
name: str
|
||||
name: str,
|
||||
) -> typing.Callable[[typing.Type[LongTextStrategy]], typing.Type[LongTextStrategy]]:
|
||||
"""长文本处理策略类装饰器
|
||||
|
||||
@@ -36,8 +36,7 @@ def strategy_class(
|
||||
|
||||
|
||||
class LongTextStrategy(metaclass=abc.ABCMeta):
|
||||
"""长文本处理策略抽象类
|
||||
"""
|
||||
"""长文本处理策略抽象类"""
|
||||
|
||||
name: str
|
||||
|
||||
@@ -45,12 +44,14 @@ class LongTextStrategy(metaclass=abc.ABCMeta):
|
||||
|
||||
def __init__(self, ap: app.Application):
|
||||
self.ap = ap
|
||||
|
||||
|
||||
async def initialize(self):
|
||||
pass
|
||||
|
||||
|
||||
@abc.abstractmethod
|
||||
async def process(self, message: str, query: core_entities.Query) -> list[platform_message.MessageComponent]:
|
||||
async def process(
|
||||
self, message: str, query: core_entities.Query
|
||||
) -> list[platform_message.MessageComponent]:
|
||||
"""处理长文本
|
||||
|
||||
在 platform.json 中配置 long-text-process 字段,只要 文本长度超过了 threshold 就会调用此方法
|
||||
|
||||
@@ -3,33 +3,38 @@ from __future__ import annotations
|
||||
from .. import stage, entities
|
||||
from ...core import entities as core_entities
|
||||
from . import truncator
|
||||
from .truncators import round
|
||||
from ...utils import importutil
|
||||
|
||||
from . import truncators
|
||||
|
||||
importutil.import_modules_in_pkg(truncators)
|
||||
|
||||
|
||||
@stage.stage_class("ConversationMessageTruncator")
|
||||
@stage.stage_class('ConversationMessageTruncator')
|
||||
class ConversationMessageTruncator(stage.PipelineStage):
|
||||
"""会话消息截断器
|
||||
|
||||
用于截断会话消息链,以适应平台消息长度限制。
|
||||
"""
|
||||
|
||||
trun: truncator.Truncator
|
||||
|
||||
async def initialize(self, pipeline_config: dict):
|
||||
use_method = "round"
|
||||
use_method = 'round'
|
||||
|
||||
for trun in truncator.preregistered_truncators:
|
||||
if trun.name == use_method:
|
||||
self.trun = trun(self.ap)
|
||||
break
|
||||
else:
|
||||
raise ValueError(f"未知的截断器: {use_method}")
|
||||
raise ValueError(f'未知的截断器: {use_method}')
|
||||
|
||||
async def process(self, query: core_entities.Query, stage_inst_name: str) -> entities.StageProcessResult:
|
||||
"""处理
|
||||
"""
|
||||
async def process(
|
||||
self, query: core_entities.Query, stage_inst_name: str
|
||||
) -> entities.StageProcessResult:
|
||||
"""处理"""
|
||||
query = await self.trun.truncate(query)
|
||||
|
||||
return entities.StageProcessResult(
|
||||
result_type=entities.ResultType.CONTINUE,
|
||||
new_query=query
|
||||
)
|
||||
result_type=entities.ResultType.CONTINUE, new_query=query
|
||||
)
|
||||
|
||||
@@ -10,7 +10,7 @@ preregistered_truncators: list[typing.Type[Truncator]] = []
|
||||
|
||||
|
||||
def truncator_class(
|
||||
name: str
|
||||
name: str,
|
||||
) -> typing.Callable[[typing.Type[Truncator]], typing.Type[Truncator]]:
|
||||
"""截断器类装饰器
|
||||
|
||||
@@ -20,6 +20,7 @@ def truncator_class(
|
||||
Returns:
|
||||
typing.Callable[[typing.Type[Truncator]], typing.Type[Truncator]]: 装饰器
|
||||
"""
|
||||
|
||||
def decorator(cls: typing.Type[Truncator]) -> typing.Type[Truncator]:
|
||||
assert issubclass(cls, Truncator)
|
||||
|
||||
@@ -33,13 +34,12 @@ def truncator_class(
|
||||
|
||||
|
||||
class Truncator(abc.ABC):
|
||||
"""消息截断器基类
|
||||
"""
|
||||
"""消息截断器基类"""
|
||||
|
||||
name: str
|
||||
|
||||
ap: app.Application
|
||||
|
||||
|
||||
def __init__(self, ap: app.Application):
|
||||
self.ap = ap
|
||||
|
||||
|
||||
@@ -4,14 +4,12 @@ from .. import truncator
|
||||
from ....core import entities as core_entities
|
||||
|
||||
|
||||
@truncator.truncator_class("round")
|
||||
@truncator.truncator_class('round')
|
||||
class RoundTruncator(truncator.Truncator):
|
||||
"""前文回合数阶段器
|
||||
"""
|
||||
"""前文回合数阶段器"""
|
||||
|
||||
async def truncate(self, query: core_entities.Query) -> core_entities.Query:
|
||||
"""截断
|
||||
"""
|
||||
"""截断"""
|
||||
max_round = query.pipeline_config['ai']['local-agent']['max-round']
|
||||
|
||||
temp_messages = []
|
||||
@@ -26,7 +24,7 @@ class RoundTruncator(truncator.Truncator):
|
||||
current_round += 1
|
||||
else:
|
||||
break
|
||||
|
||||
|
||||
query.messages = temp_messages[::-1]
|
||||
|
||||
return query
|
||||
|
||||
@@ -11,22 +11,39 @@ from ..entity.persistence import pipeline as persistence_pipeline
|
||||
from . import stage
|
||||
from ..platform.types import message as platform_message, events as platform_events
|
||||
from ..plugin import events
|
||||
from ..utils import importutil
|
||||
|
||||
from .resprule import resprule
|
||||
from .bansess import bansess
|
||||
from .cntfilter import cntfilter
|
||||
from .process import process
|
||||
from .longtext import longtext
|
||||
from .respback import respback
|
||||
from .wrapper import wrapper
|
||||
from .preproc import preproc
|
||||
from .ratelimit import ratelimit
|
||||
from .msgtrun import msgtrun
|
||||
from . import (
|
||||
resprule,
|
||||
bansess,
|
||||
cntfilter,
|
||||
process,
|
||||
longtext,
|
||||
respback,
|
||||
wrapper,
|
||||
preproc,
|
||||
ratelimit,
|
||||
msgtrun,
|
||||
)
|
||||
|
||||
importutil.import_modules_in_pkgs(
|
||||
[
|
||||
resprule,
|
||||
bansess,
|
||||
cntfilter,
|
||||
process,
|
||||
longtext,
|
||||
respback,
|
||||
wrapper,
|
||||
preproc,
|
||||
ratelimit,
|
||||
msgtrun,
|
||||
]
|
||||
)
|
||||
|
||||
|
||||
class StageInstContainer():
|
||||
"""阶段实例容器
|
||||
"""
|
||||
class StageInstContainer:
|
||||
"""阶段实例容器"""
|
||||
|
||||
inst_name: str
|
||||
|
||||
@@ -48,7 +65,12 @@ class RuntimePipeline:
|
||||
stage_containers: list[StageInstContainer]
|
||||
"""阶段实例容器"""
|
||||
|
||||
def __init__(self, ap: app.Application, pipeline_entity: persistence_pipeline.LegacyPipeline, stage_containers: list[StageInstContainer]):
|
||||
def __init__(
|
||||
self,
|
||||
ap: app.Application,
|
||||
pipeline_entity: persistence_pipeline.LegacyPipeline,
|
||||
stage_containers: list[StageInstContainer],
|
||||
):
|
||||
self.ap = ap
|
||||
self.pipeline_entity = pipeline_entity
|
||||
self.stage_containers = stage_containers
|
||||
@@ -57,9 +79,10 @@ class RuntimePipeline:
|
||||
query.pipeline_config = self.pipeline_entity.config
|
||||
await self.process_query(query)
|
||||
|
||||
async def _check_output(self, query: entities.Query, result: pipeline_entities.StageProcessResult):
|
||||
"""检查输出
|
||||
"""
|
||||
async def _check_output(
|
||||
self, query: entities.Query, result: pipeline_entities.StageProcessResult
|
||||
):
|
||||
"""检查输出"""
|
||||
if result.user_notice:
|
||||
# 处理str类型
|
||||
|
||||
@@ -68,22 +91,19 @@ class RuntimePipeline:
|
||||
platform_message.Plain(result.user_notice)
|
||||
)
|
||||
elif isinstance(result.user_notice, list):
|
||||
result.user_notice = platform_message.MessageChain(
|
||||
*result.user_notice
|
||||
)
|
||||
result.user_notice = platform_message.MessageChain(*result.user_notice)
|
||||
|
||||
if query.pipeline_config['output']['misc']['at-sender'] and isinstance(query.message_event, platform_events.GroupMessage):
|
||||
if query.pipeline_config['output']['misc']['at-sender'] and isinstance(
|
||||
query.message_event, platform_events.GroupMessage
|
||||
):
|
||||
result.user_notice.insert(
|
||||
0,
|
||||
platform_message.At(
|
||||
query.message_event.sender.id
|
||||
)
|
||||
0, platform_message.At(query.message_event.sender.id)
|
||||
)
|
||||
|
||||
await query.adapter.reply_message(
|
||||
message_source=query.message_event,
|
||||
message=result.user_notice,
|
||||
quote_origin=query.pipeline_config['output']['misc']['quote-origin']
|
||||
quote_origin=query.pipeline_config['output']['misc']['quote-origin'],
|
||||
)
|
||||
if result.debug_notice:
|
||||
self.ap.logger.debug(result.debug_notice)
|
||||
@@ -123,32 +143,44 @@ class RuntimePipeline:
|
||||
stage_container = self.stage_containers[i]
|
||||
|
||||
query.current_stage = stage_container # 标记到 Query 对象里
|
||||
|
||||
|
||||
result = stage_container.inst.process(query, stage_container.inst_name)
|
||||
|
||||
if isinstance(result, typing.Coroutine):
|
||||
result = await result
|
||||
|
||||
if isinstance(result, pipeline_entities.StageProcessResult): # 直接返回结果
|
||||
self.ap.logger.debug(f"Stage {stage_container.inst_name} processed query {query} res {result}")
|
||||
self.ap.logger.debug(
|
||||
f'Stage {stage_container.inst_name} processed query {query} res {result}'
|
||||
)
|
||||
await self._check_output(query, result)
|
||||
|
||||
if result.result_type == pipeline_entities.ResultType.INTERRUPT:
|
||||
self.ap.logger.debug(f"Stage {stage_container.inst_name} interrupted query {query}")
|
||||
self.ap.logger.debug(
|
||||
f'Stage {stage_container.inst_name} interrupted query {query}'
|
||||
)
|
||||
break
|
||||
elif result.result_type == pipeline_entities.ResultType.CONTINUE:
|
||||
query = result.new_query
|
||||
elif isinstance(result, typing.AsyncGenerator): # 生成器
|
||||
self.ap.logger.debug(f"Stage {stage_container.inst_name} processed query {query} gen")
|
||||
self.ap.logger.debug(
|
||||
f'Stage {stage_container.inst_name} processed query {query} gen'
|
||||
)
|
||||
|
||||
async for sub_result in result:
|
||||
self.ap.logger.debug(f"Stage {stage_container.inst_name} processed query {query} res {sub_result}")
|
||||
self.ap.logger.debug(
|
||||
f'Stage {stage_container.inst_name} processed query {query} res {sub_result}'
|
||||
)
|
||||
await self._check_output(query, sub_result)
|
||||
|
||||
if sub_result.result_type == pipeline_entities.ResultType.INTERRUPT:
|
||||
self.ap.logger.debug(f"Stage {stage_container.inst_name} interrupted query {query}")
|
||||
self.ap.logger.debug(
|
||||
f'Stage {stage_container.inst_name} interrupted query {query}'
|
||||
)
|
||||
break
|
||||
elif sub_result.result_type == pipeline_entities.ResultType.CONTINUE:
|
||||
elif (
|
||||
sub_result.result_type == pipeline_entities.ResultType.CONTINUE
|
||||
):
|
||||
query = sub_result.new_query
|
||||
await self._execute_from_stage(i + 1, query)
|
||||
break
|
||||
@@ -156,12 +188,14 @@ class RuntimePipeline:
|
||||
i += 1
|
||||
|
||||
async def process_query(self, query: entities.Query):
|
||||
"""处理请求
|
||||
"""
|
||||
"""处理请求"""
|
||||
try:
|
||||
|
||||
# ======== 触发 MessageReceived 事件 ========
|
||||
event_type = events.PersonMessageReceived if query.launcher_type == entities.LauncherTypes.PERSON else events.GroupMessageReceived
|
||||
event_type = (
|
||||
events.PersonMessageReceived
|
||||
if query.launcher_type == entities.LauncherTypes.PERSON
|
||||
else events.GroupMessageReceived
|
||||
)
|
||||
|
||||
event_ctx = await self.ap.plugin_mgr.emit_event(
|
||||
event=event_type(
|
||||
@@ -169,22 +203,26 @@ class RuntimePipeline:
|
||||
launcher_id=query.launcher_id,
|
||||
sender_id=query.sender_id,
|
||||
message_chain=query.message_chain,
|
||||
query=query
|
||||
query=query,
|
||||
)
|
||||
)
|
||||
|
||||
if event_ctx.is_prevented_default():
|
||||
return
|
||||
|
||||
self.ap.logger.debug(f"Processing query {query}")
|
||||
|
||||
self.ap.logger.debug(f'Processing query {query}')
|
||||
|
||||
await self._execute_from_stage(0, query)
|
||||
except Exception as e:
|
||||
inst_name = query.current_stage.inst_name if query.current_stage else 'unknown'
|
||||
self.ap.logger.error(f"处理请求时出错 query_id={query.query_id} stage={inst_name} : {e}")
|
||||
self.ap.logger.debug(f"Traceback: {traceback.format_exc()}")
|
||||
inst_name = (
|
||||
query.current_stage.inst_name if query.current_stage else 'unknown'
|
||||
)
|
||||
self.ap.logger.error(
|
||||
f'处理请求时出错 query_id={query.query_id} stage={inst_name} : {e}'
|
||||
)
|
||||
self.ap.logger.debug(f'Traceback: {traceback.format_exc()}')
|
||||
finally:
|
||||
self.ap.logger.debug(f"Query {query} processed")
|
||||
self.ap.logger.debug(f'Query {query} processed')
|
||||
|
||||
|
||||
class PipelineManager:
|
||||
@@ -203,7 +241,9 @@ class PipelineManager:
|
||||
self.pipelines = []
|
||||
|
||||
async def initialize(self):
|
||||
self.stage_dict = {name: cls for name, cls in stage.preregistered_stages.items()}
|
||||
self.stage_dict = {
|
||||
name: cls for name, cls in stage.preregistered_stages.items()
|
||||
}
|
||||
|
||||
await self.load_pipelines_from_db()
|
||||
|
||||
@@ -220,24 +260,31 @@ class PipelineManager:
|
||||
for pipeline in pipelines:
|
||||
await self.load_pipeline(pipeline)
|
||||
|
||||
async def load_pipeline(self, pipeline_entity: persistence_pipeline.LegacyPipeline | sqlalchemy.Row[persistence_pipeline.LegacyPipeline] | dict):
|
||||
|
||||
async def load_pipeline(
|
||||
self,
|
||||
pipeline_entity: persistence_pipeline.LegacyPipeline
|
||||
| sqlalchemy.Row[persistence_pipeline.LegacyPipeline]
|
||||
| dict,
|
||||
):
|
||||
if isinstance(pipeline_entity, sqlalchemy.Row):
|
||||
pipeline_entity = persistence_pipeline.LegacyPipeline(**pipeline_entity._mapping)
|
||||
pipeline_entity = persistence_pipeline.LegacyPipeline(
|
||||
**pipeline_entity._mapping
|
||||
)
|
||||
elif isinstance(pipeline_entity, dict):
|
||||
pipeline_entity = persistence_pipeline.LegacyPipeline(**pipeline_entity)
|
||||
|
||||
# initialize stage containers according to pipeline_entity.stages
|
||||
stage_containers: list[StageInstContainer] = []
|
||||
for stage_name in pipeline_entity.stages:
|
||||
stage_containers.append(StageInstContainer(
|
||||
inst_name=stage_name,
|
||||
inst=self.stage_dict[stage_name](self.ap)
|
||||
))
|
||||
stage_containers.append(
|
||||
StageInstContainer(
|
||||
inst_name=stage_name, inst=self.stage_dict[stage_name](self.ap)
|
||||
)
|
||||
)
|
||||
|
||||
for stage_container in stage_containers:
|
||||
await stage_container.inst.initialize(pipeline_entity.config)
|
||||
|
||||
|
||||
runtime_pipeline = RuntimePipeline(self.ap, pipeline_entity, stage_containers)
|
||||
self.pipelines.append(runtime_pipeline)
|
||||
|
||||
@@ -251,4 +298,4 @@ class PipelineManager:
|
||||
for pipeline in self.pipelines:
|
||||
if pipeline.pipeline_entity.uuid == uuid:
|
||||
self.pipelines.remove(pipeline)
|
||||
return
|
||||
return
|
||||
|
||||
@@ -47,7 +47,7 @@ class QueryPool:
|
||||
message_chain=message_chain,
|
||||
resp_messages=[],
|
||||
resp_message_chain=[],
|
||||
adapter=adapter
|
||||
adapter=adapter,
|
||||
)
|
||||
self.queries.append(query)
|
||||
self.query_id_counter += 1
|
||||
|
||||
@@ -9,7 +9,7 @@ from ...plugin import events
|
||||
from ...platform.types import message as platform_message
|
||||
|
||||
|
||||
@stage.stage_class("PreProcessor")
|
||||
@stage.stage_class('PreProcessor')
|
||||
class PreProcessor(stage.PipelineStage):
|
||||
"""请求预处理阶段
|
||||
|
||||
@@ -29,11 +29,12 @@ class PreProcessor(stage.PipelineStage):
|
||||
query: core_entities.Query,
|
||||
stage_inst_name: str,
|
||||
) -> entities.StageProcessResult:
|
||||
"""处理
|
||||
"""
|
||||
"""处理"""
|
||||
session = await self.ap.sess_mgr.get_session(query)
|
||||
|
||||
conversation = await self.ap.sess_mgr.get_conversation(query, session, query.pipeline_config['ai']['local-agent']['prompt'])
|
||||
conversation = await self.ap.sess_mgr.get_conversation(
|
||||
query, session, query.pipeline_config['ai']['local-agent']['prompt']
|
||||
)
|
||||
|
||||
# 设置query
|
||||
query.session = session
|
||||
@@ -42,17 +43,26 @@ class PreProcessor(stage.PipelineStage):
|
||||
|
||||
query.use_llm_model = conversation.use_llm_model
|
||||
|
||||
query.use_funcs = conversation.use_funcs if query.use_llm_model.model_entity.abilities.__contains__('tool_call') else None
|
||||
query.use_funcs = (
|
||||
conversation.use_funcs
|
||||
if query.use_llm_model.model_entity.abilities.__contains__('tool_call')
|
||||
else None
|
||||
)
|
||||
|
||||
query.variables = {
|
||||
"session_id": f"{query.session.launcher_type.value}_{query.session.launcher_id}",
|
||||
"conversation_id": conversation.uuid,
|
||||
"msg_create_time": int(query.message_event.time) if query.message_event.time else int(datetime.datetime.now().timestamp()),
|
||||
'session_id': f'{query.session.launcher_type.value}_{query.session.launcher_id}',
|
||||
'conversation_id': conversation.uuid,
|
||||
'msg_create_time': int(query.message_event.time)
|
||||
if query.message_event.time
|
||||
else int(datetime.datetime.now().timestamp()),
|
||||
}
|
||||
|
||||
# Check if this model supports vision, if not, remove all images
|
||||
# TODO this checking should be performed in runner, and in this stage, the image should be reserved
|
||||
if query.pipeline_config['ai']['runner']['runner'] == 'local-agent' and not query.use_llm_model.model_entity.abilities.__contains__('vision'):
|
||||
if (
|
||||
query.pipeline_config['ai']['runner']['runner'] == 'local-agent'
|
||||
and not query.use_llm_model.model_entity.abilities.__contains__('vision')
|
||||
):
|
||||
for msg in query.messages:
|
||||
if isinstance(msg.content, list):
|
||||
for me in msg.content:
|
||||
@@ -61,16 +71,17 @@ class PreProcessor(stage.PipelineStage):
|
||||
|
||||
content_list = []
|
||||
|
||||
plain_text = ""
|
||||
plain_text = ''
|
||||
|
||||
for me in query.message_chain:
|
||||
if isinstance(me, platform_message.Plain):
|
||||
content_list.append(
|
||||
llm_entities.ContentElement.from_text(me.text)
|
||||
)
|
||||
content_list.append(llm_entities.ContentElement.from_text(me.text))
|
||||
plain_text += me.text
|
||||
elif isinstance(me, platform_message.Image):
|
||||
if query.pipeline_config['ai']['runner']['runner'] != 'local-agent' or query.use_llm_model.model_entity.abilities.__contains__('vision'):
|
||||
if (
|
||||
query.pipeline_config['ai']['runner']['runner'] != 'local-agent'
|
||||
or query.use_llm_model.model_entity.abilities.__contains__('vision')
|
||||
):
|
||||
if me.base64 is not None:
|
||||
content_list.append(
|
||||
llm_entities.ContentElement.from_image_base64(me.base64)
|
||||
@@ -78,10 +89,7 @@ class PreProcessor(stage.PipelineStage):
|
||||
|
||||
query.variables['user_message_text'] = plain_text
|
||||
|
||||
query.user_message = llm_entities.Message(
|
||||
role='user',
|
||||
content=content_list
|
||||
)
|
||||
query.user_message = llm_entities.Message(role='user', content=content_list)
|
||||
# =========== 触发事件 PromptPreProcessing
|
||||
|
||||
event_ctx = await self.ap.plugin_mgr.emit_event(
|
||||
@@ -89,7 +97,7 @@ class PreProcessor(stage.PipelineStage):
|
||||
session_name=f'{query.session.launcher_type.value}_{query.session.launcher_id}',
|
||||
default_prompt=query.prompt.messages,
|
||||
prompt=query.messages,
|
||||
query=query
|
||||
query=query,
|
||||
)
|
||||
)
|
||||
|
||||
@@ -97,6 +105,5 @@ class PreProcessor(stage.PipelineStage):
|
||||
query.messages = event_ctx.event.prompt
|
||||
|
||||
return entities.StageProcessResult(
|
||||
result_type=entities.ResultType.CONTINUE,
|
||||
new_query=query
|
||||
result_type=entities.ResultType.CONTINUE, new_query=query
|
||||
)
|
||||
|
||||
@@ -8,7 +8,6 @@ from .. import entities
|
||||
|
||||
|
||||
class MessageHandler(metaclass=abc.ABCMeta):
|
||||
|
||||
ap: app.Application
|
||||
|
||||
def __init__(self, ap: app.Application):
|
||||
|
||||
@@ -3,33 +3,36 @@ from __future__ import annotations
|
||||
import typing
|
||||
import time
|
||||
import traceback
|
||||
import json
|
||||
|
||||
|
||||
from .. import handler
|
||||
from ... import entities
|
||||
from ....core import entities as core_entities
|
||||
from ....provider import entities as llm_entities
|
||||
from ....provider import runner as runner_module
|
||||
from ....provider.runners import localagent, difysvapi, dashscopeapi
|
||||
from ....plugin import events
|
||||
|
||||
from ....platform.types import message as platform_message
|
||||
from ....utils import importutil
|
||||
from ....provider import runners
|
||||
|
||||
importutil.import_modules_in_pkg(runners)
|
||||
|
||||
|
||||
class ChatMessageHandler(handler.MessageHandler):
|
||||
|
||||
async def handle(
|
||||
self,
|
||||
query: core_entities.Query,
|
||||
) -> typing.AsyncGenerator[entities.StageProcessResult, None]:
|
||||
"""处理
|
||||
"""
|
||||
"""处理"""
|
||||
# 调API
|
||||
# 生成器
|
||||
|
||||
# 触发插件事件
|
||||
event_class = events.PersonNormalMessageReceived if query.launcher_type == core_entities.LauncherTypes.PERSON else events.GroupNormalMessageReceived
|
||||
event_class = (
|
||||
events.PersonNormalMessageReceived
|
||||
if query.launcher_type == core_entities.LauncherTypes.PERSON
|
||||
else events.GroupNormalMessageReceived
|
||||
)
|
||||
|
||||
event_ctx = await self.ap.plugin_mgr.emit_event(
|
||||
event=event_class(
|
||||
@@ -37,7 +40,7 @@ class ChatMessageHandler(handler.MessageHandler):
|
||||
launcher_id=query.launcher_id,
|
||||
sender_id=query.sender_id,
|
||||
text_message=str(query.message_chain),
|
||||
query=query
|
||||
query=query,
|
||||
)
|
||||
)
|
||||
|
||||
@@ -48,16 +51,13 @@ class ChatMessageHandler(handler.MessageHandler):
|
||||
query.resp_messages.append(mc)
|
||||
|
||||
yield entities.StageProcessResult(
|
||||
result_type=entities.ResultType.CONTINUE,
|
||||
new_query=query
|
||||
result_type=entities.ResultType.CONTINUE, new_query=query
|
||||
)
|
||||
else:
|
||||
yield entities.StageProcessResult(
|
||||
result_type=entities.ResultType.INTERRUPT,
|
||||
new_query=query
|
||||
result_type=entities.ResultType.INTERRUPT, new_query=query
|
||||
)
|
||||
else:
|
||||
|
||||
if event_ctx.event.alter is not None:
|
||||
# if isinstance(event_ctx.event, str): # 现在暂时不考虑多模态alter
|
||||
query.user_message.content = event_ctx.event.alter
|
||||
@@ -67,48 +67,52 @@ class ChatMessageHandler(handler.MessageHandler):
|
||||
start_time = time.time()
|
||||
|
||||
try:
|
||||
|
||||
for r in runner_module.preregistered_runners:
|
||||
if r.name == query.pipeline_config["ai"]["runner"]["runner"]:
|
||||
if r.name == query.pipeline_config['ai']['runner']['runner']:
|
||||
runner = r(self.ap, query.pipeline_config)
|
||||
break
|
||||
else:
|
||||
raise ValueError(f"未找到请求运行器: {query.pipeline_config['ai']['runner']['runner']}")
|
||||
raise ValueError(
|
||||
f'未找到请求运行器: {query.pipeline_config["ai"]["runner"]["runner"]}'
|
||||
)
|
||||
|
||||
async for result in runner.run(query):
|
||||
query.resp_messages.append(result)
|
||||
|
||||
self.ap.logger.info(f'对话({query.query_id})响应: {self.cut_str(result.readable_str())}')
|
||||
self.ap.logger.info(
|
||||
f'对话({query.query_id})响应: {self.cut_str(result.readable_str())}'
|
||||
)
|
||||
|
||||
if result.content is not None:
|
||||
text_length += len(result.content)
|
||||
|
||||
yield entities.StageProcessResult(
|
||||
result_type=entities.ResultType.CONTINUE,
|
||||
new_query=query
|
||||
result_type=entities.ResultType.CONTINUE, new_query=query
|
||||
)
|
||||
|
||||
query.session.using_conversation.messages.append(query.user_message)
|
||||
query.session.using_conversation.messages.extend(query.resp_messages)
|
||||
except Exception as e:
|
||||
|
||||
self.ap.logger.error(f'对话({query.query_id})请求失败: {type(e).__name__} {str(e)}')
|
||||
self.ap.logger.error(
|
||||
f'对话({query.query_id})请求失败: {type(e).__name__} {str(e)}'
|
||||
)
|
||||
|
||||
hide_exception_info = query.pipeline_config['output']['misc']['hide-exception']
|
||||
hide_exception_info = query.pipeline_config['output']['misc'][
|
||||
'hide-exception'
|
||||
]
|
||||
|
||||
yield entities.StageProcessResult(
|
||||
result_type=entities.ResultType.INTERRUPT,
|
||||
new_query=query,
|
||||
user_notice='请求失败' if hide_exception_info else f'{e}',
|
||||
error_notice=f'{e}',
|
||||
debug_notice=traceback.format_exc()
|
||||
debug_notice=traceback.format_exc(),
|
||||
)
|
||||
finally:
|
||||
|
||||
await self.ap.ctr_mgr.usage.post_query_record(
|
||||
session_type=query.session.launcher_type.value,
|
||||
session_id=str(query.session.launcher_id),
|
||||
query_ability_provider="LangBot.Chat",
|
||||
query_ability_provider='LangBot.Chat',
|
||||
usage=text_length,
|
||||
model_name=query.use_model.name,
|
||||
response_seconds=int(time.time() - start_time),
|
||||
|
||||
@@ -11,24 +11,29 @@ from ....platform.types import message as platform_message
|
||||
|
||||
|
||||
class CommandHandler(handler.MessageHandler):
|
||||
|
||||
async def handle(
|
||||
self,
|
||||
query: core_entities.Query,
|
||||
) -> typing.AsyncGenerator[entities.StageProcessResult, None]:
|
||||
"""处理
|
||||
"""
|
||||
"""处理"""
|
||||
|
||||
command_text = str(query.message_chain).strip()[1:]
|
||||
|
||||
privilege = 1
|
||||
|
||||
if f'{query.launcher_type.value}_{query.launcher_id}' in self.ap.instance_config.data['admins']:
|
||||
|
||||
if (
|
||||
f'{query.launcher_type.value}_{query.launcher_id}'
|
||||
in self.ap.instance_config.data['admins']
|
||||
):
|
||||
privilege = 2
|
||||
|
||||
spt = command_text.split(' ')
|
||||
|
||||
event_class = events.PersonCommandSent if query.launcher_type == core_entities.LauncherTypes.PERSON else events.GroupCommandSent
|
||||
event_class = (
|
||||
events.PersonCommandSent
|
||||
if query.launcher_type == core_entities.LauncherTypes.PERSON
|
||||
else events.GroupCommandSent
|
||||
)
|
||||
|
||||
event_ctx = await self.ap.plugin_mgr.emit_event(
|
||||
event=event_class(
|
||||
@@ -38,41 +43,35 @@ class CommandHandler(handler.MessageHandler):
|
||||
command=spt[0],
|
||||
params=spt[1:] if len(spt) > 1 else [],
|
||||
text_message=str(query.message_chain),
|
||||
is_admin=(privilege==2),
|
||||
query=query
|
||||
is_admin=(privilege == 2),
|
||||
query=query,
|
||||
)
|
||||
)
|
||||
|
||||
if event_ctx.is_prevented_default():
|
||||
|
||||
if event_ctx.event.reply is not None:
|
||||
mc = platform_message.MessageChain(event_ctx.event.reply)
|
||||
|
||||
query.resp_messages.append(mc)
|
||||
|
||||
yield entities.StageProcessResult(
|
||||
result_type=entities.ResultType.CONTINUE,
|
||||
new_query=query
|
||||
result_type=entities.ResultType.CONTINUE, new_query=query
|
||||
)
|
||||
else:
|
||||
yield entities.StageProcessResult(
|
||||
result_type=entities.ResultType.INTERRUPT,
|
||||
new_query=query
|
||||
result_type=entities.ResultType.INTERRUPT, new_query=query
|
||||
)
|
||||
|
||||
else:
|
||||
|
||||
if event_ctx.event.alter is not None:
|
||||
query.message_chain = platform_message.MessageChain([
|
||||
platform_message.Plain(event_ctx.event.alter)
|
||||
])
|
||||
query.message_chain = platform_message.MessageChain(
|
||||
[platform_message.Plain(event_ctx.event.alter)]
|
||||
)
|
||||
|
||||
session = await self.ap.sess_mgr.get_session(query)
|
||||
|
||||
async for ret in self.ap.cmd_mgr.execute(
|
||||
command_text=command_text,
|
||||
query=query,
|
||||
session=session
|
||||
command_text=command_text, query=query, session=session
|
||||
):
|
||||
if ret.error is not None:
|
||||
query.resp_messages.append(
|
||||
@@ -82,20 +81,18 @@ class CommandHandler(handler.MessageHandler):
|
||||
)
|
||||
)
|
||||
|
||||
self.ap.logger.info(f'命令({query.query_id})报错: {self.cut_str(str(ret.error))}')
|
||||
self.ap.logger.info(
|
||||
f'命令({query.query_id})报错: {self.cut_str(str(ret.error))}'
|
||||
)
|
||||
|
||||
yield entities.StageProcessResult(
|
||||
result_type=entities.ResultType.CONTINUE,
|
||||
new_query=query
|
||||
result_type=entities.ResultType.CONTINUE, new_query=query
|
||||
)
|
||||
elif ret.text is not None or ret.image_url is not None:
|
||||
|
||||
content: list[llm_entities.ContentElement]= []
|
||||
content: list[llm_entities.ContentElement] = []
|
||||
|
||||
if ret.text is not None:
|
||||
content.append(
|
||||
llm_entities.ContentElement.from_text(ret.text)
|
||||
)
|
||||
content.append(llm_entities.ContentElement.from_text(ret.text))
|
||||
|
||||
if ret.image_url is not None:
|
||||
content.append(
|
||||
@@ -112,11 +109,9 @@ class CommandHandler(handler.MessageHandler):
|
||||
self.ap.logger.info(f'命令返回: {self.cut_str(str(content[0]))}')
|
||||
|
||||
yield entities.StageProcessResult(
|
||||
result_type=entities.ResultType.CONTINUE,
|
||||
new_query=query
|
||||
result_type=entities.ResultType.CONTINUE, new_query=query
|
||||
)
|
||||
else:
|
||||
yield entities.StageProcessResult(
|
||||
result_type=entities.ResultType.INTERRUPT,
|
||||
new_query=query
|
||||
result_type=entities.ResultType.INTERRUPT, new_query=query
|
||||
)
|
||||
|
||||
@@ -1,18 +1,16 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from ...core import app, entities as core_entities
|
||||
from ...core import entities as core_entities
|
||||
from . import handler
|
||||
from .handlers import chat, command
|
||||
from .. import entities
|
||||
from .. import stage, entities
|
||||
from ...core import entities as core_entities
|
||||
from ...config import manager as cfg_mgr
|
||||
from .. import stage
|
||||
|
||||
|
||||
@stage.stage_class("MessageProcessor")
|
||||
@stage.stage_class('MessageProcessor')
|
||||
class Processor(stage.PipelineStage):
|
||||
"""请求实际处理阶段
|
||||
|
||||
|
||||
通过命令处理器和聊天处理器处理消息。
|
||||
|
||||
改写:
|
||||
@@ -35,11 +33,12 @@ class Processor(stage.PipelineStage):
|
||||
query: core_entities.Query,
|
||||
stage_inst_name: str,
|
||||
) -> entities.StageProcessResult:
|
||||
"""处理
|
||||
"""
|
||||
"""处理"""
|
||||
message_text = str(query.message_chain).strip()
|
||||
|
||||
self.ap.logger.info(f"处理 {query.launcher_type.value}_{query.launcher_id} 的请求({query.query_id}): {message_text}")
|
||||
self.ap.logger.info(
|
||||
f'处理 {query.launcher_type.value}_{query.launcher_id} 的请求({query.query_id}): {message_text}'
|
||||
)
|
||||
|
||||
async def generator():
|
||||
cmd_prefix = self.ap.instance_config.data['command']['prefix']
|
||||
@@ -50,5 +49,5 @@ class Processor(stage.PipelineStage):
|
||||
else:
|
||||
async for result in self.chat_handler.handle(query):
|
||||
yield result
|
||||
|
||||
|
||||
return generator()
|
||||
|
||||
@@ -7,19 +7,19 @@ from ...core import app, entities as core_entities
|
||||
|
||||
preregistered_algos: list[typing.Type[ReteLimitAlgo]] = []
|
||||
|
||||
|
||||
def algo_class(name: str):
|
||||
|
||||
def decorator(cls: typing.Type[ReteLimitAlgo]) -> typing.Type[ReteLimitAlgo]:
|
||||
cls.name = name
|
||||
preregistered_algos.append(cls)
|
||||
return cls
|
||||
|
||||
|
||||
return decorator
|
||||
|
||||
|
||||
class ReteLimitAlgo(metaclass=abc.ABCMeta):
|
||||
"""限流算法抽象类"""
|
||||
|
||||
|
||||
name: str = None
|
||||
|
||||
ap: app.Application
|
||||
@@ -31,11 +31,16 @@ class ReteLimitAlgo(metaclass=abc.ABCMeta):
|
||||
pass
|
||||
|
||||
@abc.abstractmethod
|
||||
async def require_access(self, query: core_entities.Query, launcher_type: str, launcher_id: typing.Union[int, str]) -> bool:
|
||||
async def require_access(
|
||||
self,
|
||||
query: core_entities.Query,
|
||||
launcher_type: str,
|
||||
launcher_id: typing.Union[int, str],
|
||||
) -> bool:
|
||||
"""进入处理流程
|
||||
|
||||
这个方法对等待是友好的,意味着算法可以实现在这里等待一段时间以控制速率。
|
||||
|
||||
|
||||
Args:
|
||||
launcher_type (str): 请求者类型 群聊为 group 私聊为 person
|
||||
launcher_id (int): 请求者ID
|
||||
@@ -44,15 +49,19 @@ class ReteLimitAlgo(metaclass=abc.ABCMeta):
|
||||
bool: 是否允许进入处理流程,若返回false,则直接丢弃该请求
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
|
||||
@abc.abstractmethod
|
||||
async def release_access(self, query: core_entities.Query, launcher_type: str, launcher_id: typing.Union[int, str]):
|
||||
async def release_access(
|
||||
self,
|
||||
query: core_entities.Query,
|
||||
launcher_type: str,
|
||||
launcher_id: typing.Union[int, str],
|
||||
):
|
||||
"""退出处理流程
|
||||
|
||||
Args:
|
||||
launcher_type (str): 请求者类型 群聊为 group 私聊为 person
|
||||
launcher_id (int): 请求者ID
|
||||
"""
|
||||
|
||||
|
||||
raise NotImplementedError
|
||||
|
||||
@@ -5,9 +5,9 @@ import typing
|
||||
from .. import algo
|
||||
from ....core import entities as core_entities
|
||||
|
||||
|
||||
# 固定窗口算法
|
||||
class SessionContainer:
|
||||
|
||||
wait_lock: asyncio.Lock
|
||||
|
||||
records: dict[int, int]
|
||||
@@ -18,9 +18,8 @@ class SessionContainer:
|
||||
self.records = {}
|
||||
|
||||
|
||||
@algo.algo_class("fixwin")
|
||||
@algo.algo_class('fixwin')
|
||||
class FixedWindowAlgo(algo.ReteLimitAlgo):
|
||||
|
||||
containers_lock: asyncio.Lock
|
||||
"""访问记录容器锁"""
|
||||
|
||||
@@ -31,7 +30,12 @@ class FixedWindowAlgo(algo.ReteLimitAlgo):
|
||||
self.containers_lock = asyncio.Lock()
|
||||
self.containers = {}
|
||||
|
||||
async def require_access(self, query: core_entities.Query, launcher_type: str, launcher_id: typing.Union[int, str]) -> bool:
|
||||
async def require_access(
|
||||
self,
|
||||
query: core_entities.Query,
|
||||
launcher_type: str,
|
||||
launcher_id: typing.Union[int, str],
|
||||
) -> bool:
|
||||
# 加锁,找容器
|
||||
container: SessionContainer = None
|
||||
|
||||
@@ -46,7 +50,6 @@ class FixedWindowAlgo(algo.ReteLimitAlgo):
|
||||
|
||||
# 等待锁
|
||||
async with container.wait_lock:
|
||||
|
||||
# 获取窗口大小和限制
|
||||
window_size = query.pipeline_config['safety']['rate-limit']['window-length']
|
||||
limitation = query.pipeline_config['safety']['rate-limit']['limitation']
|
||||
@@ -69,13 +72,15 @@ class FixedWindowAlgo(algo.ReteLimitAlgo):
|
||||
if count >= limitation:
|
||||
if query.pipeline_config['safety']['rate-limit']['strategy'] == 'drop':
|
||||
return False
|
||||
elif query.pipeline_config['safety']['rate-limit']['strategy'] == 'wait':
|
||||
elif (
|
||||
query.pipeline_config['safety']['rate-limit']['strategy'] == 'wait'
|
||||
):
|
||||
# 等待下一窗口
|
||||
await asyncio.sleep(window_size - time.time() % window_size)
|
||||
|
||||
|
||||
now = int(time.time())
|
||||
now = now - now % window_size
|
||||
|
||||
|
||||
if now not in container.records:
|
||||
container.records = {}
|
||||
container.records[now] = 1
|
||||
@@ -85,6 +90,11 @@ class FixedWindowAlgo(algo.ReteLimitAlgo):
|
||||
|
||||
# 返回True
|
||||
return True
|
||||
|
||||
async def release_access(self, query: core_entities.Query, launcher_type: str, launcher_id: typing.Union[int, str]):
|
||||
|
||||
async def release_access(
|
||||
self,
|
||||
query: core_entities.Query,
|
||||
launcher_type: str,
|
||||
launcher_id: typing.Union[int, str],
|
||||
):
|
||||
pass
|
||||
|
||||
@@ -4,22 +4,25 @@ import typing
|
||||
|
||||
from .. import entities, stage
|
||||
from . import algo
|
||||
from .algos import fixedwin
|
||||
from ...core import entities as core_entities
|
||||
from ...utils import importutil
|
||||
|
||||
from . import algos
|
||||
|
||||
importutil.import_modules_in_pkg(algos)
|
||||
|
||||
|
||||
@stage.stage_class("RequireRateLimitOccupancy")
|
||||
@stage.stage_class("ReleaseRateLimitOccupancy")
|
||||
@stage.stage_class('RequireRateLimitOccupancy')
|
||||
@stage.stage_class('ReleaseRateLimitOccupancy')
|
||||
class RateLimit(stage.PipelineStage):
|
||||
"""限速器控制阶段
|
||||
|
||||
|
||||
不改写query,只检查是否需要限速。
|
||||
"""
|
||||
|
||||
algo: algo.ReteLimitAlgo
|
||||
|
||||
async def initialize(self, pipeline_config: dict):
|
||||
|
||||
algo_name = 'fixwin'
|
||||
|
||||
algo_class = None
|
||||
@@ -42,9 +45,8 @@ class RateLimit(stage.PipelineStage):
|
||||
entities.StageProcessResult,
|
||||
typing.AsyncGenerator[entities.StageProcessResult, None],
|
||||
]:
|
||||
"""处理
|
||||
"""
|
||||
if stage_inst_name == "RequireRateLimitOccupancy":
|
||||
"""处理"""
|
||||
if stage_inst_name == 'RequireRateLimitOccupancy':
|
||||
if await self.algo.require_access(
|
||||
query,
|
||||
query.launcher_type.value,
|
||||
@@ -58,10 +60,10 @@ class RateLimit(stage.PipelineStage):
|
||||
return entities.StageProcessResult(
|
||||
result_type=entities.ResultType.INTERRUPT,
|
||||
new_query=query,
|
||||
console_notice=f"根据限速规则忽略 {query.launcher_type.value}:{query.launcher_id} 消息",
|
||||
user_notice=f"请求数超过限速器设定值,已丢弃本消息。"
|
||||
console_notice=f'根据限速规则忽略 {query.launcher_type.value}:{query.launcher_id} 消息',
|
||||
user_notice='请求数超过限速器设定值,已丢弃本消息。',
|
||||
)
|
||||
elif stage_inst_name == "ReleaseRateLimitOccupancy":
|
||||
elif stage_inst_name == 'ReleaseRateLimitOccupancy':
|
||||
await self.algo.release_access(
|
||||
query,
|
||||
query.launcher_type.value,
|
||||
|
||||
@@ -4,41 +4,38 @@ import random
|
||||
import asyncio
|
||||
|
||||
|
||||
from ...core import app
|
||||
from ...platform.types import events as platform_events
|
||||
from ...platform.types import message as platform_message
|
||||
|
||||
from .. import stage, entities
|
||||
from ...core import entities as core_entities
|
||||
from ...config import manager as cfg_mgr
|
||||
|
||||
|
||||
@stage.stage_class("SendResponseBackStage")
|
||||
@stage.stage_class('SendResponseBackStage')
|
||||
class SendResponseBackStage(stage.PipelineStage):
|
||||
"""发送响应消息
|
||||
"""
|
||||
"""发送响应消息"""
|
||||
|
||||
async def process(self, query: core_entities.Query, stage_inst_name: str) -> entities.StageProcessResult:
|
||||
"""处理
|
||||
"""
|
||||
async def process(
|
||||
self, query: core_entities.Query, stage_inst_name: str
|
||||
) -> entities.StageProcessResult:
|
||||
"""处理"""
|
||||
|
||||
random_range = (query.pipeline_config['output']['force-delay']['min'], query.pipeline_config['output']['force-delay']['max'])
|
||||
random_range = (
|
||||
query.pipeline_config['output']['force-delay']['min'],
|
||||
query.pipeline_config['output']['force-delay']['max'],
|
||||
)
|
||||
|
||||
random_delay = random.uniform(*random_range)
|
||||
|
||||
self.ap.logger.debug(
|
||||
"根据规则强制延迟回复: %s s",
|
||||
random_delay
|
||||
)
|
||||
self.ap.logger.debug('根据规则强制延迟回复: %s s', random_delay)
|
||||
|
||||
await asyncio.sleep(random_delay)
|
||||
|
||||
if query.pipeline_config['output']['misc']['at-sender'] and isinstance(query.message_event, platform_events.GroupMessage):
|
||||
if query.pipeline_config['output']['misc']['at-sender'] and isinstance(
|
||||
query.message_event, platform_events.GroupMessage
|
||||
):
|
||||
query.resp_message_chain[-1].insert(
|
||||
0,
|
||||
platform_message.At(
|
||||
query.message_event.sender.id
|
||||
)
|
||||
0, platform_message.At(query.message_event.sender.id)
|
||||
)
|
||||
|
||||
quote_origin = query.pipeline_config['output']['misc']['quote-origin']
|
||||
@@ -46,10 +43,9 @@ class SendResponseBackStage(stage.PipelineStage):
|
||||
await query.adapter.reply_message(
|
||||
message_source=query.message_event,
|
||||
message=query.resp_message_chain[-1],
|
||||
quote_origin=quote_origin
|
||||
quote_origin=quote_origin,
|
||||
)
|
||||
|
||||
return entities.StageProcessResult(
|
||||
result_type=entities.ResultType.CONTINUE,
|
||||
new_query=query
|
||||
)
|
||||
result_type=entities.ResultType.CONTINUE, new_query=query
|
||||
)
|
||||
|
||||
@@ -4,7 +4,6 @@ from ...platform.types import message as platform_message
|
||||
|
||||
|
||||
class RuleJudgeResult(pydantic.BaseModel):
|
||||
|
||||
matching: bool = False
|
||||
|
||||
replacement: platform_message.MessageChain = None
|
||||
|
||||
@@ -1,16 +1,18 @@
|
||||
from __future__ import annotations
|
||||
|
||||
|
||||
from ...core import app
|
||||
from . import entities as rule_entities, rule
|
||||
from .rules import atbot, prefix, regexp, random
|
||||
from . import rule
|
||||
|
||||
from .. import stage, entities
|
||||
from ...core import entities as core_entities
|
||||
from ...config import manager as cfg_mgr
|
||||
from ...utils import importutil
|
||||
|
||||
from . import rules
|
||||
|
||||
importutil.import_modules_in_pkg(rules)
|
||||
|
||||
|
||||
@stage.stage_class("GroupRespondRuleCheckStage")
|
||||
@stage.stage_class('GroupRespondRuleCheckStage')
|
||||
class GroupRespondRuleCheckStage(stage.PipelineStage):
|
||||
"""群组响应规则检查器
|
||||
|
||||
@@ -21,8 +23,7 @@ class GroupRespondRuleCheckStage(stage.PipelineStage):
|
||||
"""检查器实例"""
|
||||
|
||||
async def initialize(self, pipeline_config: dict):
|
||||
"""初始化检查器
|
||||
"""
|
||||
"""初始化检查器"""
|
||||
|
||||
self.rule_matchers = []
|
||||
|
||||
@@ -31,12 +32,12 @@ class GroupRespondRuleCheckStage(stage.PipelineStage):
|
||||
await rule_inst.initialize()
|
||||
self.rule_matchers.append(rule_inst)
|
||||
|
||||
async def process(self, query: core_entities.Query, stage_inst_name: str) -> entities.StageProcessResult:
|
||||
|
||||
async def process(
|
||||
self, query: core_entities.Query, stage_inst_name: str
|
||||
) -> entities.StageProcessResult:
|
||||
if query.launcher_type.value != 'group': # 只处理群消息
|
||||
return entities.StageProcessResult(
|
||||
result_type=entities.ResultType.CONTINUE,
|
||||
new_query=query
|
||||
result_type=entities.ResultType.CONTINUE, new_query=query
|
||||
)
|
||||
|
||||
rules = query.pipeline_config['trigger']['group-respond-rules']
|
||||
@@ -48,7 +49,9 @@ class GroupRespondRuleCheckStage(stage.PipelineStage):
|
||||
# use_rule = rules[str(query.launcher_id)]
|
||||
|
||||
for rule_matcher in self.rule_matchers: # 任意一个匹配就放行
|
||||
res = await rule_matcher.match(str(query.message_chain), query.message_chain, use_rule, query)
|
||||
res = await rule_matcher.match(
|
||||
str(query.message_chain), query.message_chain, use_rule, query
|
||||
)
|
||||
if res.matching:
|
||||
query.message_chain = res.replacement
|
||||
|
||||
@@ -56,8 +59,7 @@ class GroupRespondRuleCheckStage(stage.PipelineStage):
|
||||
result_type=entities.ResultType.CONTINUE,
|
||||
new_query=query,
|
||||
)
|
||||
|
||||
|
||||
return entities.StageProcessResult(
|
||||
result_type=entities.ResultType.INTERRUPT,
|
||||
new_query=query
|
||||
result_type=entities.ResultType.INTERRUPT, new_query=query
|
||||
)
|
||||
|
||||
@@ -10,17 +10,19 @@ from ...platform.types import message as platform_message
|
||||
|
||||
preregisetered_rules: list[typing.Type[GroupRespondRule]] = []
|
||||
|
||||
|
||||
def rule_class(name: str):
|
||||
def decorator(cls: typing.Type[GroupRespondRule]) -> typing.Type[GroupRespondRule]:
|
||||
cls.name = name
|
||||
preregisetered_rules.append(cls)
|
||||
return cls
|
||||
|
||||
return decorator
|
||||
|
||||
|
||||
class GroupRespondRule(metaclass=abc.ABCMeta):
|
||||
"""群组响应规则的抽象类
|
||||
"""
|
||||
"""群组响应规则的抽象类"""
|
||||
|
||||
name: str
|
||||
|
||||
ap: app.Application
|
||||
@@ -37,8 +39,7 @@ class GroupRespondRule(metaclass=abc.ABCMeta):
|
||||
message_text: str,
|
||||
message_chain: platform_message.MessageChain,
|
||||
rule_dict: dict,
|
||||
query: core_entities.Query
|
||||
query: core_entities.Query,
|
||||
) -> entities.RuleJudgeResult:
|
||||
"""判断消息是否匹配规则
|
||||
"""
|
||||
"""判断消息是否匹配规则"""
|
||||
raise NotImplementedError
|
||||
|
||||
@@ -7,21 +7,24 @@ from ....core import entities as core_entities
|
||||
from ....platform.types import message as platform_message
|
||||
|
||||
|
||||
@rule_model.rule_class("at-bot")
|
||||
@rule_model.rule_class('at-bot')
|
||||
class AtBotRule(rule_model.GroupRespondRule):
|
||||
|
||||
async def match(
|
||||
self,
|
||||
message_text: str,
|
||||
message_chain: platform_message.MessageChain,
|
||||
rule_dict: dict,
|
||||
query: core_entities.Query
|
||||
query: core_entities.Query,
|
||||
) -> entities.RuleJudgeResult:
|
||||
|
||||
if message_chain.has(platform_message.At(query.adapter.bot_account_id)) and rule_dict['at']:
|
||||
if (
|
||||
message_chain.has(platform_message.At(query.adapter.bot_account_id))
|
||||
and rule_dict['at']
|
||||
):
|
||||
message_chain.remove(platform_message.At(query.adapter.bot_account_id))
|
||||
|
||||
if message_chain.has(platform_message.At(query.adapter.bot_account_id)): # 回复消息时会at两次,检查并删除重复的
|
||||
if message_chain.has(
|
||||
platform_message.At(query.adapter.bot_account_id)
|
||||
): # 回复消息时会at两次,检查并删除重复的
|
||||
message_chain.remove(platform_message.At(query.adapter.bot_account_id))
|
||||
|
||||
return entities.RuleJudgeResult(
|
||||
@@ -29,7 +32,4 @@ class AtBotRule(rule_model.GroupRespondRule):
|
||||
replacement=message_chain,
|
||||
)
|
||||
|
||||
return entities.RuleJudgeResult(
|
||||
matching=False,
|
||||
replacement = message_chain
|
||||
)
|
||||
return entities.RuleJudgeResult(matching=False, replacement=message_chain)
|
||||
|
||||
@@ -1,36 +1,30 @@
|
||||
|
||||
from .. import rule as rule_model
|
||||
from .. import entities
|
||||
from ....core import entities as core_entities
|
||||
from ....platform.types import message as platform_message
|
||||
|
||||
|
||||
@rule_model.rule_class("prefix")
|
||||
@rule_model.rule_class('prefix')
|
||||
class PrefixRule(rule_model.GroupRespondRule):
|
||||
|
||||
async def match(
|
||||
self,
|
||||
message_text: str,
|
||||
message_chain: platform_message.MessageChain,
|
||||
rule_dict: dict,
|
||||
query: core_entities.Query
|
||||
query: core_entities.Query,
|
||||
) -> entities.RuleJudgeResult:
|
||||
prefixes = rule_dict['prefix']
|
||||
|
||||
for prefix in prefixes:
|
||||
if message_text.startswith(prefix):
|
||||
|
||||
# 查找第一个plain元素
|
||||
for me in message_chain:
|
||||
if isinstance(me, platform_message.Plain):
|
||||
me.text = me.text[len(prefix):]
|
||||
me.text = me.text[len(prefix) :]
|
||||
|
||||
return entities.RuleJudgeResult(
|
||||
matching=True,
|
||||
replacement=message_chain,
|
||||
)
|
||||
|
||||
return entities.RuleJudgeResult(
|
||||
matching=False,
|
||||
replacement=message_chain
|
||||
)
|
||||
return entities.RuleJudgeResult(matching=False, replacement=message_chain)
|
||||
|
||||
@@ -7,19 +7,17 @@ from ....core import entities as core_entities
|
||||
from ....platform.types import message as platform_message
|
||||
|
||||
|
||||
@rule_model.rule_class("random")
|
||||
@rule_model.rule_class('random')
|
||||
class RandomRespRule(rule_model.GroupRespondRule):
|
||||
|
||||
async def match(
|
||||
self,
|
||||
message_text: str,
|
||||
message_chain: platform_message.MessageChain,
|
||||
rule_dict: dict,
|
||||
query: core_entities.Query
|
||||
query: core_entities.Query,
|
||||
) -> entities.RuleJudgeResult:
|
||||
random_rate = rule_dict['random']
|
||||
|
||||
|
||||
return entities.RuleJudgeResult(
|
||||
matching=random.random() < random_rate,
|
||||
replacement=message_chain
|
||||
)
|
||||
matching=random.random() < random_rate, replacement=message_chain
|
||||
)
|
||||
|
||||
@@ -7,15 +7,14 @@ from ....core import entities as core_entities
|
||||
from ....platform.types import message as platform_message
|
||||
|
||||
|
||||
@rule_model.rule_class("regexp")
|
||||
@rule_model.rule_class('regexp')
|
||||
class RegExpRule(rule_model.GroupRespondRule):
|
||||
|
||||
async def match(
|
||||
self,
|
||||
message_text: str,
|
||||
message_chain: platform_message.MessageChain,
|
||||
rule_dict: dict,
|
||||
query: core_entities.Query
|
||||
query: core_entities.Query,
|
||||
) -> entities.RuleJudgeResult:
|
||||
regexps = rule_dict['regexp']
|
||||
|
||||
@@ -27,8 +26,5 @@ class RegExpRule(rule_model.GroupRespondRule):
|
||||
matching=True,
|
||||
replacement=message_chain,
|
||||
)
|
||||
|
||||
return entities.RuleJudgeResult(
|
||||
matching=False,
|
||||
replacement=message_chain
|
||||
)
|
||||
|
||||
return entities.RuleJudgeResult(matching=False, replacement=message_chain)
|
||||
|
||||
@@ -11,17 +11,15 @@ preregistered_stages: dict[str, PipelineStage] = {}
|
||||
|
||||
|
||||
def stage_class(name: str):
|
||||
|
||||
def decorator(cls):
|
||||
preregistered_stages[name] = cls
|
||||
return cls
|
||||
|
||||
|
||||
return decorator
|
||||
|
||||
|
||||
class PipelineStage(metaclass=abc.ABCMeta):
|
||||
"""流水线阶段
|
||||
"""
|
||||
"""流水线阶段"""
|
||||
|
||||
ap: app.Application
|
||||
|
||||
@@ -29,8 +27,7 @@ class PipelineStage(metaclass=abc.ABCMeta):
|
||||
self.ap = ap
|
||||
|
||||
async def initialize(self, pipeline_config: dict):
|
||||
"""初始化
|
||||
"""
|
||||
"""初始化"""
|
||||
pass
|
||||
|
||||
@abc.abstractmethod
|
||||
@@ -42,6 +39,5 @@ class PipelineStage(metaclass=abc.ABCMeta):
|
||||
entities.StageProcessResult,
|
||||
typing.AsyncGenerator[entities.StageProcessResult, None],
|
||||
]:
|
||||
"""处理
|
||||
"""
|
||||
"""处理"""
|
||||
raise NotImplementedError
|
||||
|
||||
@@ -3,21 +3,19 @@ from __future__ import annotations
|
||||
import typing
|
||||
|
||||
|
||||
from ...core import app, entities as core_entities
|
||||
from .. import entities
|
||||
from .. import stage, entities
|
||||
from ...core import entities as core_entities
|
||||
from ...config import manager as cfg_mgr
|
||||
from .. import entities
|
||||
from .. import stage
|
||||
from ...plugin import events
|
||||
from ...platform.types import message as platform_message
|
||||
|
||||
|
||||
@stage.stage_class("ResponseWrapper")
|
||||
@stage.stage_class('ResponseWrapper')
|
||||
class ResponseWrapper(stage.PipelineStage):
|
||||
"""回复包装阶段
|
||||
|
||||
把回复的 message 包装成人类识读的形式。
|
||||
|
||||
|
||||
改写:
|
||||
- resp_message_chain
|
||||
"""
|
||||
@@ -30,36 +28,36 @@ class ResponseWrapper(stage.PipelineStage):
|
||||
query: core_entities.Query,
|
||||
stage_inst_name: str,
|
||||
) -> typing.AsyncGenerator[entities.StageProcessResult, None]:
|
||||
"""处理
|
||||
"""
|
||||
"""处理"""
|
||||
|
||||
# 如果 resp_messages[-1] 已经是 MessageChain 了
|
||||
if isinstance(query.resp_messages[-1], platform_message.MessageChain):
|
||||
query.resp_message_chain.append(query.resp_messages[-1])
|
||||
|
||||
yield entities.StageProcessResult(
|
||||
result_type=entities.ResultType.CONTINUE,
|
||||
new_query=query
|
||||
result_type=entities.ResultType.CONTINUE, new_query=query
|
||||
)
|
||||
|
||||
else:
|
||||
|
||||
if query.resp_messages[-1].role == 'command':
|
||||
query.resp_message_chain.append(query.resp_messages[-1].get_content_platform_message_chain(prefix_text='[bot] '))
|
||||
query.resp_message_chain.append(
|
||||
query.resp_messages[-1].get_content_platform_message_chain(
|
||||
prefix_text='[bot] '
|
||||
)
|
||||
)
|
||||
|
||||
yield entities.StageProcessResult(
|
||||
result_type=entities.ResultType.CONTINUE,
|
||||
new_query=query
|
||||
result_type=entities.ResultType.CONTINUE, new_query=query
|
||||
)
|
||||
elif query.resp_messages[-1].role == 'plugin':
|
||||
query.resp_message_chain.append(query.resp_messages[-1].get_content_platform_message_chain())
|
||||
query.resp_message_chain.append(
|
||||
query.resp_messages[-1].get_content_platform_message_chain()
|
||||
)
|
||||
|
||||
yield entities.StageProcessResult(
|
||||
result_type=entities.ResultType.CONTINUE,
|
||||
new_query=query
|
||||
result_type=entities.ResultType.CONTINUE, new_query=query
|
||||
)
|
||||
else:
|
||||
|
||||
if query.resp_messages[-1].role == 'assistant':
|
||||
result = query.resp_messages[-1]
|
||||
session = await self.ap.sess_mgr.get_session(query)
|
||||
@@ -79,39 +77,51 @@ class ResponseWrapper(stage.PipelineStage):
|
||||
prefix='',
|
||||
response_text=reply_text,
|
||||
finish_reason='stop',
|
||||
funcs_called=[fc.function.name for fc in result.tool_calls] if result.tool_calls is not None else [],
|
||||
query=query
|
||||
funcs_called=[
|
||||
fc.function.name for fc in result.tool_calls
|
||||
]
|
||||
if result.tool_calls is not None
|
||||
else [],
|
||||
query=query,
|
||||
)
|
||||
)
|
||||
if event_ctx.is_prevented_default():
|
||||
yield entities.StageProcessResult(
|
||||
result_type=entities.ResultType.INTERRUPT,
|
||||
new_query=query
|
||||
new_query=query,
|
||||
)
|
||||
else:
|
||||
if event_ctx.event.reply is not None:
|
||||
|
||||
query.resp_message_chain.append(platform_message.MessageChain(event_ctx.event.reply))
|
||||
query.resp_message_chain.append(
|
||||
platform_message.MessageChain(event_ctx.event.reply)
|
||||
)
|
||||
|
||||
else:
|
||||
|
||||
query.resp_message_chain.append(result.get_content_platform_message_chain())
|
||||
query.resp_message_chain.append(
|
||||
result.get_content_platform_message_chain()
|
||||
)
|
||||
|
||||
yield entities.StageProcessResult(
|
||||
result_type=entities.ResultType.CONTINUE,
|
||||
new_query=query
|
||||
new_query=query,
|
||||
)
|
||||
|
||||
if result.tool_calls is not None and len(result.tool_calls) > 0: # 有函数调用
|
||||
|
||||
if (
|
||||
result.tool_calls is not None and len(result.tool_calls) > 0
|
||||
): # 有函数调用
|
||||
function_names = [tc.function.name for tc in result.tool_calls]
|
||||
|
||||
reply_text = f'调用函数 {".".join(function_names)}...'
|
||||
|
||||
query.resp_message_chain.append(platform_message.MessageChain([platform_message.Plain(reply_text)]))
|
||||
query.resp_message_chain.append(
|
||||
platform_message.MessageChain(
|
||||
[platform_message.Plain(reply_text)]
|
||||
)
|
||||
)
|
||||
|
||||
if query.pipeline_config['output']['misc']['track-function-calls']:
|
||||
|
||||
if query.pipeline_config['output']['misc'][
|
||||
'track-function-calls'
|
||||
]:
|
||||
event_ctx = await self.ap.plugin_mgr.emit_event(
|
||||
event=events.NormalMessageResponded(
|
||||
launcher_type=query.launcher_type.value,
|
||||
@@ -121,26 +131,36 @@ class ResponseWrapper(stage.PipelineStage):
|
||||
prefix='',
|
||||
response_text=reply_text,
|
||||
finish_reason='stop',
|
||||
funcs_called=[fc.function.name for fc in result.tool_calls] if result.tool_calls is not None else [],
|
||||
query=query
|
||||
funcs_called=[
|
||||
fc.function.name for fc in result.tool_calls
|
||||
]
|
||||
if result.tool_calls is not None
|
||||
else [],
|
||||
query=query,
|
||||
)
|
||||
)
|
||||
|
||||
if event_ctx.is_prevented_default():
|
||||
yield entities.StageProcessResult(
|
||||
result_type=entities.ResultType.INTERRUPT,
|
||||
new_query=query
|
||||
new_query=query,
|
||||
)
|
||||
else:
|
||||
if event_ctx.event.reply is not None:
|
||||
|
||||
query.resp_message_chain.append(platform_message.MessageChain(event_ctx.event.reply))
|
||||
query.resp_message_chain.append(
|
||||
platform_message.MessageChain(
|
||||
event_ctx.event.reply
|
||||
)
|
||||
)
|
||||
|
||||
else:
|
||||
|
||||
query.resp_message_chain.append(platform_message.MessageChain([platform_message.Plain(reply_text)]))
|
||||
query.resp_message_chain.append(
|
||||
platform_message.MessageChain(
|
||||
[platform_message.Plain(reply_text)]
|
||||
)
|
||||
)
|
||||
|
||||
yield entities.StageProcessResult(
|
||||
result_type=entities.ResultType.CONTINUE,
|
||||
new_query=query
|
||||
new_query=query,
|
||||
)
|
||||
|
||||
Reference in New Issue
Block a user