refactor: 请求处理控制流基础架构

This commit is contained in:
RockChinQ
2024-01-26 15:51:49 +08:00
parent a064c24f60
commit 8d084427d2
55 changed files with 1430 additions and 146 deletions
View File
+57
View File
@@ -0,0 +1,57 @@
from __future__ import annotations
import os
import traceback
from PIL import Image, ImageDraw, ImageFont
from mirai.models.message import MessageComponent, Plain, MessageChain
from ...core import app
from . import strategy
from .strategies import image, forward
from .. import stage, entities, stagemgr
from ...core import entities as core_entities
from ...config import manager as cfg_mgr
@stage.stage_class("LongTextProcessStage")
class LongTextProcessStage(stage.PipelineStage):
strategy_impl: strategy.LongTextStrategy
async def initialize(self):
config = self.ap.cfg_mgr.data
if self.ap.cfg_mgr.data['blob_message_strategy'] == 'image':
use_font = config['font_path']
try:
# 检查是否存在
if not os.path.exists(use_font):
# 若是windows系统,使用微软雅黑
if os.name == "nt":
use_font = "C:/Windows/Fonts/msyh.ttc"
if not os.path.exists(use_font):
self.ap.logger.warn("未找到字体文件,且无法使用Windows自带字体,更换为转发消息组件以发送长消息,您可以在config.py中调整相关设置。")
config['blob_message_strategy'] = "forward"
else:
self.ap.logger.info("使用Windows自带字体:" + use_font)
self.ap.cfg_mgr.data['font_path'] = use_font
else:
self.ap.logger.warn("未找到字体文件,且无法使用系统自带字体,更换为转发消息组件以发送长消息,您可以在config.py中调整相关设置。")
self.ap.cfg_mgr.data['blob_message_strategy'] = "forward"
except:
traceback.print_exc()
self.ap.logger.error("加载字体文件失败({}),更换为转发消息组件以发送长消息,您可以在config.py中调整相关设置。".format(use_font))
self.ap.cfg_mgr.data['blob_message_strategy'] = "forward"
if self.ap.cfg_mgr.data['blob_message_strategy'] == 'image':
self.strategy_impl = image.Text2ImageStrategy(self.ap)
elif self.ap.cfg_mgr.data['blob_message_strategy'] == 'forward':
self.strategy_impl = forward.ForwardComponentStrategy(self.ap)
await self.strategy_impl.initialize()
async def process(self, query: core_entities.Query, stage_inst_name: str) -> entities.StageProcessResult:
if len(str(query.resp_message_chain)) > self.ap.cfg_mgr.data['blob_message_threshold']:
query.message_chain = MessageChain(await self.strategy_impl.process(str(query.resp_message_chain)))
return entities.StageProcessResult(
result_type=entities.ResultType.CONTINUE,
new_query=query
)
@@ -0,0 +1,62 @@
# 转发消息组件
from __future__ import annotations
import typing
from mirai.models import MessageChain
from mirai.models.message import MessageComponent, ForwardMessageNode
from mirai.models.base import MiraiBaseModel
from .. import strategy as strategy_model
class ForwardMessageDiaplay(MiraiBaseModel):
title: str = "群聊的聊天记录"
brief: str = "[聊天记录]"
source: str = "聊天记录"
preview: typing.List[str] = []
summary: str = "查看x条转发消息"
class Forward(MessageComponent):
"""合并转发。"""
type: str = "Forward"
"""消息组件类型。"""
display: ForwardMessageDiaplay
"""显示信息"""
node_list: typing.List[ForwardMessageNode]
"""转发消息节点列表。"""
def __init__(self, *args, **kwargs):
if len(args) == 1:
self.node_list = args[0]
super().__init__(**kwargs)
super().__init__(*args, **kwargs)
def __str__(self):
return '[聊天记录]'
class ForwardComponentStrategy(strategy_model.LongTextStrategy):
async def process(self, message: str) -> list[MessageComponent]:
display = ForwardMessageDiaplay(
title="群聊的聊天记录",
brief="[聊天记录]",
source="聊天记录",
preview=["QQ用户: "+message],
summary="查看1条转发消息"
)
node_list = [
ForwardMessageNode(
sender_id=self.ap.im_mgr.bot_account_id,
sender_name='QQ用户',
message_chain=MessageChain([message])
)
]
forward = Forward(
display=display,
node_list=node_list
)
return [forward]
+197
View File
@@ -0,0 +1,197 @@
from __future__ import annotations
import typing
import os
import base64
import time
import re
from PIL import Image, ImageDraw, ImageFont
from mirai.models import MessageChain, Image as ImageComponent
from mirai.models.message import MessageComponent
from .. import strategy as strategy_model
class Text2ImageStrategy(strategy_model.LongTextStrategy):
text_render_font: ImageFont.FreeTypeFont
async def initialize(self):
self.text_render_font = ImageFont.truetype(self.ap.cfg_mgr.data['font_path'], 32, encoding="utf-8")
async def process(self, message: str) -> list[MessageComponent]:
img_path = self.text_to_image(
text_str=message,
save_as='temp/{}.png'.format(int(time.time()))
)
compressed_path, size = self.compress_image(
img_path,
outfile="temp/{}_compressed.png".format(int(time.time()))
)
with open(compressed_path, 'rb') as f:
img = f.read()
b64 = base64.b64encode(img)
# 删除图片
os.remove(img_path)
if os.path.exists(compressed_path):
os.remove(compressed_path)
return [
ImageComponent(
base64=b64.decode('utf-8'),
)
]
def indexNumber(self, path=''):
"""
查找字符串中数字所在串中的位置
:param path:目标字符串
:return:<class 'list'>: <class 'list'>: [['1', 16], ['2', 35], ['1', 51]]
"""
kv = []
nums = []
beforeDatas = re.findall('[\d]+', path)
for num in beforeDatas:
indexV = []
times = path.count(num)
if times > 1:
if num not in nums:
indexs = re.finditer(num, path)
for index in indexs:
iV = []
i = index.span()[0]
iV.append(num)
iV.append(i)
kv.append(iV)
nums.append(num)
else:
index = path.find(num)
indexV.append(num)
indexV.append(index)
kv.append(indexV)
# 根据数字位置排序
indexSort = []
resultIndex = []
for vi in kv:
indexSort.append(vi[1])
indexSort.sort()
for i in indexSort:
for v in kv:
if i == v[1]:
resultIndex.append(v)
return resultIndex
def get_size(self, file):
# 获取文件大小:KB
size = os.path.getsize(file)
return size / 1024
def get_outfile(self, infile, outfile):
if outfile:
return outfile
dir, suffix = os.path.splitext(infile)
outfile = '{}-out{}'.format(dir, suffix)
return outfile
def compress_image(self, infile, outfile='', kb=100, step=20, quality=90):
"""不改变图片尺寸压缩到指定大小
:param infile: 压缩源文件
:param outfile: 压缩文件保存地址
:param mb: 压缩目标,KB
:param step: 每次调整的压缩比率
:param quality: 初始压缩比率
:return: 压缩文件地址,压缩文件大小
"""
o_size = self.get_size(infile)
if o_size <= kb:
return infile, o_size
outfile = self.get_outfile(infile, outfile)
while o_size > kb:
im = Image.open(infile)
im.save(outfile, quality=quality)
if quality - step < 0:
break
quality -= step
o_size = self.get_size(outfile)
return outfile, self.get_size(outfile)
def text_to_image(self, text_str: str, save_as="temp.png", width=800):
text_str = text_str.replace("\t", " ")
# 分行
lines = text_str.split('\n')
# 计算并分割
final_lines = []
text_width = width-80
self.ap.logger.debug("lines: {}, text_width: {}".format(lines, text_width))
for line in lines:
# 如果长了就分割
line_width = self.text_render_font.getlength(line)
self.ap.logger.debug("line_width: {}".format(line_width))
if line_width < text_width:
final_lines.append(line)
continue
else:
rest_text = line
while True:
# 分割最前面的一行
point = int(len(rest_text) * (text_width / line_width))
# 检查断点是否在数字中间
numbers = self.indexNumber(rest_text)
for number in numbers:
if number[1] < point < number[1] + len(number[0]) and number[1] != 0:
point = number[1]
break
final_lines.append(rest_text[:point])
rest_text = rest_text[point:]
line_width = self.text_render_font.getlength(rest_text)
if line_width < text_width:
final_lines.append(rest_text)
break
else:
continue
# 准备画布
img = Image.new('RGBA', (width, max(280, len(final_lines) * 35 + 65)), (255, 255, 255, 255))
draw = ImageDraw.Draw(img, mode='RGBA')
self.ap.logger.debug("正在绘制图片...")
# 绘制正文
line_number = 0
offset_x = 20
offset_y = 30
for final_line in final_lines:
draw.text((offset_x, offset_y + 35 * line_number), final_line, fill=(0, 0, 0), font=self.text_render_font)
# 遍历此行,检查是否有emoji
idx_in_line = 0
for ch in final_line:
# 检查字符占位宽
char_code = ord(ch)
if char_code >= 127:
idx_in_line += 1
else:
idx_in_line += 0.5
line_number += 1
self.ap.logger.debug("正在保存图片...")
img.save(save_as)
return save_as
+22
View File
@@ -0,0 +1,22 @@
from __future__ import annotations
import abc
import typing
import mirai
from mirai.models.message import MessageComponent
from ...core import app
class LongTextStrategy(metaclass=abc.ABCMeta):
ap: app.Application
def __init__(self, ap: app.Application):
self.ap = ap
async def initialize(self):
pass
@abc.abstractmethod
async def process(self, message: str) -> list[MessageComponent]:
return []