refactor: 将 yirimirai 的组件集成进 platform 包

This commit is contained in:
RockChinQ
2024-09-26 00:23:03 +08:00
parent ee0d6dcdae
commit 1c4a700d92
36 changed files with 1580 additions and 342 deletions
+9 -6
View File
@@ -1,8 +1,8 @@
from __future__ import annotations
import mirai
import mirai.models
import mirai.models.message
# import mirai
# import mirai.models
# import mirai.models.message
from ...core import app
@@ -12,6 +12,9 @@ from ...config import manager as cfg_mgr
from . import filter as filter_model, entities as filter_entities
from .filters import cntignore, banwords, baiduexamine
from ...provider import entities as llm_entities
from ...platform.types import message as platform_message
from ...platform.types import events as platform_events
from ...platform.types import entities as platform_entities
@stage.stage_class('PostContentFilterStage')
@@ -89,8 +92,8 @@ class ContentFilterStage(stage.PipelineStage):
elif result.level == filter_entities.ResultLevel.PASS: # 传到下一个
message = result.replacement
query.message_chain = mirai.MessageChain(
mirai.Plain(message)
query.message_chain = platform_message.MessageChain(
platform_message.Plain(message)
)
return entities.StageProcessResult(
@@ -148,7 +151,7 @@ class ContentFilterStage(stage.PipelineStage):
contain_non_text = False
text_components = [mirai.Plain, mirai.models.message.Source]
text_components = [platform_message.Plain, platform_message.Source]
for me in query.message_chain:
if type(me) not in text_components:
+5 -4
View File
@@ -4,11 +4,12 @@ import asyncio
import typing
import traceback
import mirai
# import mirai
from ..core import app, entities
from . import entities as pipeline_entities
from ..plugin import events
from ..platform.types import message as platform_message
class Controller:
@@ -73,11 +74,11 @@ class Controller:
# 处理str类型
if isinstance(result.user_notice, str):
result.user_notice = mirai.MessageChain(
mirai.Plain(result.user_notice)
result.user_notice = platform_message.MessageChain(
platform_message.Plain(result.user_notice)
)
elif isinstance(result.user_notice, list):
result.user_notice = mirai.MessageChain(
result.user_notice = platform_message.MessageChain(
*result.user_notice
)
+3 -3
View File
@@ -4,8 +4,8 @@ import enum
import typing
import pydantic
import mirai
import mirai.models.message as mirai_message
# import mirai
from ..platform.types import message as platform_message
from ..core import entities
@@ -25,7 +25,7 @@ class StageProcessResult(pydantic.BaseModel):
new_query: entities.Query
user_notice: typing.Optional[typing.Union[str, list[mirai_message.MessageComponent], mirai.MessageChain, None]] = []
user_notice: typing.Optional[typing.Union[str, list[platform_message.MessageComponent], platform_message.MessageChain, None]] = []
"""只要设置了就会发送给用户"""
# TODO delete
+4 -3
View File
@@ -3,7 +3,7 @@ import os
import traceback
from PIL import Image, ImageDraw, ImageFont
from mirai.models.message import MessageComponent, Plain, MessageChain
# from mirai.models.message import MessageComponent, Plain, MessageChain
from ...core import app
from . import strategy
@@ -11,6 +11,7 @@ from .strategies import image, forward
from .. import stage, entities, stagemgr
from ...core import entities as core_entities
from ...config import manager as cfg_mgr
from ...platform.types import message as platform_message
@stage.stage_class("LongTextProcessStage")
@@ -63,14 +64,14 @@ class LongTextProcessStage(stage.PipelineStage):
contains_non_plain = False
for msg in query.resp_message_chain[-1]:
if not isinstance(msg, Plain):
if not isinstance(msg, platform_message.Plain):
contains_non_plain = True
break
if contains_non_plain:
self.ap.logger.debug("消息中包含非 Plain 组件,跳过长消息处理。")
elif len(str(query.resp_message_chain[-1])) > self.ap.platform_cfg.data['long-text-process']['threshold']:
query.resp_message_chain[-1] = MessageChain(await self.strategy_impl.process(str(query.resp_message_chain[-1]), query))
query.resp_message_chain[-1] = platform_message.MessageChain(await self.strategy_impl.process(str(query.resp_message_chain[-1]), query))
return entities.StageProcessResult(
result_type=entities.ResultType.CONTINUE,
+11 -9
View File
@@ -2,15 +2,17 @@
from __future__ import annotations
import typing
from mirai.models import MessageChain
from mirai.models.message import MessageComponent, ForwardMessageNode
from mirai.models.base import MiraiBaseModel
# from mirai.models import MessageChain
# from mirai.models.message import MessageComponent, ForwardMessageNode
# from mirai.models.base import MiraiBaseModel
import pydantic
from .. import strategy as strategy_model
from ....core import entities as core_entities
from ....platform.types import message as platform_message
class ForwardMessageDiaplay(MiraiBaseModel):
class ForwardMessageDiaplay(pydantic.BaseModel):
title: str = "群聊的聊天记录"
brief: str = "[聊天记录]"
source: str = "聊天记录"
@@ -18,13 +20,13 @@ class ForwardMessageDiaplay(MiraiBaseModel):
summary: str = "查看x条转发消息"
class Forward(MessageComponent):
class Forward(platform_message.MessageComponent):
"""合并转发。"""
type: str = "Forward"
"""消息组件类型。"""
display: ForwardMessageDiaplay
"""显示信息"""
node_list: typing.List[ForwardMessageNode]
node_list: typing.List[platform_message.ForwardMessageNode]
"""转发消息节点列表。"""
def __init__(self, *args, **kwargs):
if len(args) == 1:
@@ -39,7 +41,7 @@ class Forward(MessageComponent):
@strategy_model.strategy_class("forward")
class ForwardComponentStrategy(strategy_model.LongTextStrategy):
async def process(self, message: str, query: core_entities.Query) -> list[MessageComponent]:
async def process(self, message: str, query: core_entities.Query) -> list[platform_message.MessageComponent]:
display = ForwardMessageDiaplay(
title="群聊的聊天记录",
brief="[聊天记录]",
@@ -49,10 +51,10 @@ class ForwardComponentStrategy(strategy_model.LongTextStrategy):
)
node_list = [
ForwardMessageNode(
platform_message.ForwardMessageNode(
sender_id=query.adapter.bot_account_id,
sender_name='QQ用户',
message_chain=MessageChain([message])
message_chain=platform_message.MessageChain([message])
)
]
+5 -4
View File
@@ -8,8 +8,9 @@ import re
from PIL import Image, ImageDraw, ImageFont
from mirai.models import MessageChain, Image as ImageComponent
from mirai.models.message import MessageComponent
# from mirai.models import MessageChain, Image as ImageComponent
# from mirai.models.message import MessageComponent
from ....platform.types import message as platform_message
from .. import strategy as strategy_model
from ....core import entities as core_entities
@@ -23,7 +24,7 @@ class Text2ImageStrategy(strategy_model.LongTextStrategy):
async def initialize(self):
self.text_render_font = ImageFont.truetype(self.ap.platform_cfg.data['long-text-process']['font-path'], 32, encoding="utf-8")
async def process(self, message: str, query: core_entities.Query) -> list[MessageComponent]:
async def process(self, message: str, query: core_entities.Query) -> list[platform_message.MessageComponent]:
img_path = self.text_to_image(
text_str=message,
save_as='temp/{}.png'.format(int(time.time()))
@@ -46,7 +47,7 @@ class Text2ImageStrategy(strategy_model.LongTextStrategy):
os.remove(compressed_path)
return [
ImageComponent(
platform_message.Image(
base64=b64.decode('utf-8'),
)
]
+5 -4
View File
@@ -2,11 +2,12 @@ from __future__ import annotations
import abc
import typing
import mirai
from mirai.models.message import MessageComponent
# import mirai
# from mirai.models.message import MessageComponent
from ...core import app
from ...core import entities as core_entities
from ...platform.types import message as platform_message
preregistered_strategies: list[typing.Type[LongTextStrategy]] = []
@@ -51,7 +52,7 @@ class LongTextStrategy(metaclass=abc.ABCMeta):
pass
@abc.abstractmethod
async def process(self, message: str, query: core_entities.Query) -> list[MessageComponent]:
async def process(self, message: str, query: core_entities.Query) -> list[platform_message.MessageComponent]:
"""处理长文本
在 platform.json 中配置 long-text-process 字段,只要 文本长度超过了 threshold 就会调用此方法
@@ -61,6 +62,6 @@ class LongTextStrategy(metaclass=abc.ABCMeta):
query (core_entities.Query): 此次请求的上下文对象
Returns:
list[mirai.models.messages.MessageComponent]: 转换后的 YiriMirai 消息组件列表
list[platform_message.MessageComponent]: 转换后的 平台 消息组件列表
"""
return []
+5 -3
View File
@@ -2,10 +2,12 @@ from __future__ import annotations
import asyncio
import mirai
# import mirai
from ..core import entities
from ..platform import adapter as msadapter
from ..platform.types import message as platform_message
from ..platform.types import events as platform_events
class QueryPool:
@@ -30,8 +32,8 @@ class QueryPool:
launcher_type: entities.LauncherTypes,
launcher_id: int,
sender_id: int,
message_event: mirai.MessageEvent,
message_chain: mirai.MessageChain,
message_event: platform_events.MessageEvent,
message_chain: platform_message.MessageChain,
adapter: msadapter.MessageSourceAdapter
) -> entities.Query:
async with self.condition:
+4 -3
View File
@@ -1,11 +1,12 @@
from __future__ import annotations
import mirai
# import mirai
from .. import stage, entities, stagemgr
from ...core import entities as core_entities
from ...provider import entities as llm_entities
from ...plugin import events
from ...platform.types import message as platform_message
@stage.stage_class("PreProcessor")
@@ -55,11 +56,11 @@ class PreProcessor(stage.PipelineStage):
content_list = []
for me in query.message_chain:
if isinstance(me, mirai.Plain):
if isinstance(me, platform_message.Plain):
content_list.append(
llm_entities.ContentElement.from_text(me.text)
)
elif isinstance(me, mirai.Image):
elif isinstance(me, platform_message.Image):
if self.ap.provider_cfg.data['enable-vision'] and query.use_model.vision_supported:
if me.url is not None:
content_list.append(
+4 -2
View File
@@ -5,7 +5,7 @@ import time
import traceback
import json
import mirai
# import mirai
from .. import handler
from ... import entities
@@ -13,6 +13,8 @@ from ....core import entities as core_entities
from ....provider import entities as llm_entities, runnermgr
from ....plugin import events
from ....platform.types import message as platform_message
class ChatMessageHandler(handler.MessageHandler):
@@ -40,7 +42,7 @@ class ChatMessageHandler(handler.MessageHandler):
if event_ctx.is_prevented_default():
if event_ctx.event.reply is not None:
mc = mirai.MessageChain(event_ctx.event.reply)
mc = platform_message.MessageChain(event_ctx.event.reply)
query.resp_messages.append(mc)
+5 -4
View File
@@ -1,13 +1,14 @@
from __future__ import annotations
import typing
import mirai
# import mirai
from .. import handler
from ... import entities
from ....core import entities as core_entities
from ....provider import entities as llm_entities
from ....plugin import events
from ....platform.types import message as platform_message
class CommandHandler(handler.MessageHandler):
@@ -46,7 +47,7 @@ class CommandHandler(handler.MessageHandler):
if event_ctx.is_prevented_default():
if event_ctx.event.reply is not None:
mc = mirai.MessageChain(event_ctx.event.reply)
mc = platform_message.MessageChain(event_ctx.event.reply)
query.resp_messages.append(mc)
@@ -63,8 +64,8 @@ class CommandHandler(handler.MessageHandler):
else:
if event_ctx.event.alter is not None:
query.message_chain = mirai.MessageChain([
mirai.Plain(event_ctx.event.alter)
query.message_chain = platform_message.MessageChain([
platform_message.Plain(event_ctx.event.alter)
])
session = await self.ap.sess_mgr.get_session(query)
+1 -1
View File
@@ -3,7 +3,7 @@ from __future__ import annotations
import random
import asyncio
import mirai
# import mirai
from ...core import app
+4 -2
View File
@@ -1,9 +1,11 @@
import pydantic
import mirai
# import mirai
from ...platform.types import message as platform_message
class RuleJudgeResult(pydantic.BaseModel):
matching: bool = False
replacement: mirai.MessageChain = None
replacement: platform_message.MessageChain = None
+1 -1
View File
@@ -1,6 +1,6 @@
from __future__ import annotations
import mirai
# import mirai
from ...core import app
from . import entities as rule_entities, rule
+4 -2
View File
@@ -2,11 +2,13 @@ from __future__ import annotations
import abc
import typing
import mirai
# import mirai
from ...core import app, entities as core_entities
from . import entities
from ...platform.types import message as platform_message
preregisetered_rules: list[typing.Type[GroupRespondRule]] = []
@@ -35,7 +37,7 @@ class GroupRespondRule(metaclass=abc.ABCMeta):
async def match(
self,
message_text: str,
message_chain: mirai.MessageChain,
message_chain: platform_message.MessageChain,
rule_dict: dict,
query: core_entities.Query
) -> entities.RuleJudgeResult:
+7 -6
View File
@@ -1,10 +1,11 @@
from __future__ import annotations
import mirai
# import mirai
from .. import rule as rule_model
from .. import entities
from ....core import entities as core_entities
from ....platform.types import message as platform_message
@rule_model.rule_class("at-bot")
@@ -13,16 +14,16 @@ class AtBotRule(rule_model.GroupRespondRule):
async def match(
self,
message_text: str,
message_chain: mirai.MessageChain,
message_chain: platform_message.MessageChain,
rule_dict: dict,
query: core_entities.Query
) -> entities.RuleJudgeResult:
if message_chain.has(mirai.At(query.adapter.bot_account_id)) and rule_dict['at']:
message_chain.remove(mirai.At(query.adapter.bot_account_id))
if message_chain.has(platform_message.At(query.adapter.bot_account_id)) and rule_dict['at']:
message_chain.remove(platform_message.At(query.adapter.bot_account_id))
if message_chain.has(mirai.At(query.adapter.bot_account_id)): # 回复消息时会at两次,检查并删除重复的
message_chain.remove(mirai.At(query.adapter.bot_account_id))
if message_chain.has(platform_message.At(query.adapter.bot_account_id)): # 回复消息时会at两次,检查并删除重复的
message_chain.remove(platform_message.At(query.adapter.bot_account_id))
return entities.RuleJudgeResult(
matching=True,
+4 -3
View File
@@ -1,8 +1,9 @@
import mirai
# import mirai
from .. import rule as rule_model
from .. import entities
from ....core import entities as core_entities
from ....platform.types import message as platform_message
@rule_model.rule_class("prefix")
@@ -11,7 +12,7 @@ class PrefixRule(rule_model.GroupRespondRule):
async def match(
self,
message_text: str,
message_chain: mirai.MessageChain,
message_chain: platform_message.MessageChain,
rule_dict: dict,
query: core_entities.Query
) -> entities.RuleJudgeResult:
@@ -22,7 +23,7 @@ class PrefixRule(rule_model.GroupRespondRule):
# 查找第一个plain元素
for me in message_chain:
if isinstance(me, mirai.Plain):
if isinstance(me, platform_message.Plain):
me.text = me.text[len(prefix):]
return entities.RuleJudgeResult(
+3 -2
View File
@@ -1,10 +1,11 @@
import random
import mirai
# import mirai
from .. import rule as rule_model
from .. import entities
from ....core import entities as core_entities
from ....platform.types import message as platform_message
@rule_model.rule_class("random")
@@ -13,7 +14,7 @@ class RandomRespRule(rule_model.GroupRespondRule):
async def match(
self,
message_text: str,
message_chain: mirai.MessageChain,
message_chain: platform_message.MessageChain,
rule_dict: dict,
query: core_entities.Query
) -> entities.RuleJudgeResult:
+3 -2
View File
@@ -1,10 +1,11 @@
import re
import mirai
# import mirai
from .. import rule as rule_model
from .. import entities
from ....core import entities as core_entities
from ....platform.types import message as platform_message
@rule_model.rule_class("regexp")
@@ -13,7 +14,7 @@ class RegExpRule(rule_model.GroupRespondRule):
async def match(
self,
message_text: str,
message_chain: mirai.MessageChain,
message_chain: platform_message.MessageChain,
rule_dict: dict,
query: core_entities.Query
) -> entities.RuleJudgeResult:
+7 -6
View File
@@ -2,7 +2,7 @@ from __future__ import annotations
import typing
import mirai
# import mirai
from ...core import app, entities as core_entities
from .. import entities
@@ -10,6 +10,7 @@ from .. import stage, entities, stagemgr
from ...core import entities as core_entities
from ...config import manager as cfg_mgr
from ...plugin import events
from ...platform.types import message as platform_message
@stage.stage_class("ResponseWrapper")
@@ -34,7 +35,7 @@ class ResponseWrapper(stage.PipelineStage):
"""
# 如果 resp_messages[-1] 已经是 MessageChain 了
if isinstance(query.resp_messages[-1], mirai.MessageChain):
if isinstance(query.resp_messages[-1], platform_message.MessageChain):
query.resp_message_chain.append(query.resp_messages[-1])
yield entities.StageProcessResult(
@@ -96,7 +97,7 @@ class ResponseWrapper(stage.PipelineStage):
else:
if event_ctx.event.reply is not None:
query.resp_message_chain.append(mirai.MessageChain(event_ctx.event.reply))
query.resp_message_chain.append(platform_message.MessageChain(event_ctx.event.reply))
else:
@@ -113,7 +114,7 @@ class ResponseWrapper(stage.PipelineStage):
reply_text = f'调用函数 {".".join(function_names)}...'
query.resp_message_chain.append(mirai.MessageChain([mirai.Plain(reply_text)]))
query.resp_message_chain.append(platform_message.MessageChain([platform_message.Plain(reply_text)]))
if self.ap.platform_cfg.data['track-function-calls']:
@@ -139,11 +140,11 @@ class ResponseWrapper(stage.PipelineStage):
else:
if event_ctx.event.reply is not None:
query.resp_message_chain.append(mirai.MessageChain(event_ctx.event.reply))
query.resp_message_chain.append(platform_message.MessageChain(event_ctx.event.reply))
else:
query.resp_message_chain.append(mirai.MessageChain([mirai.Plain(reply_text)]))
query.resp_message_chain.append(platform_message.MessageChain([platform_message.Plain(reply_text)]))
yield entities.StageProcessResult(
result_type=entities.ResultType.CONTINUE,