diff --git a/config-template.py b/config-template.py index fec0a7ef..5e6996ed 100644 --- a/config-template.py +++ b/config-template.py @@ -133,7 +133,7 @@ prompt_submit_length = 1024 completion_api_params = { "model": "gpt-3.5-turbo", "temperature": 0.9, # 数值越低得到的回答越理性,取值范围[0, 1] - "max_tokens": 512, # 每次获取OpenAI接口响应的文字量上限, 不高于4096 + "max_tokens": 1024, # 每次获取OpenAI接口响应的文字量上限, 不高于4096 "top_p": 1, # 生成的文本的文本与要求的符合度, 取值范围[0, 1] "frequency_penalty": 0.2, "presence_penalty": 1.0, @@ -154,13 +154,18 @@ include_image_description = True # 消息处理的超时时间,单位为秒 process_message_timeout = 30 -# [暂未实现] 群内会话是否启用多对象名称 -# 若不启用,群内会话的prompt只使用user_name和bot_name -multi_subject = False - # 回复消息时是否显示[GPT]前缀 show_prefix = False +# 应用长消息处理策略的阈值 +# 当回复消息长度超过此值时,将使用长消息处理策略 +blob_message_threshold = 256 + +# 长消息处理策略 +# - "image": 将长消息转换为图片发送 +# - "forward": 将长消息转换为转发消息组件发送 +blob_message_strategy = "forward" + # 消息处理超时重试次数 retry_times = 3 diff --git a/pkg/qqbot/blob.py b/pkg/qqbot/blob.py new file mode 100644 index 00000000..c6edff2e --- /dev/null +++ b/pkg/qqbot/blob.py @@ -0,0 +1,105 @@ +# 长消息处理相关 +import logging +import os +import time +import base64 + +import config +from mirai.models.message import MessageComponent, MessageChain, Image +from mirai.models.message import ForwardMessageNode +from mirai.models.base import MiraiBaseModel +from typing import List +import pkg.utils.context as context +import pkg.utils.text2img as text2img + + +class ForwardMessageDiaplay(MiraiBaseModel): + title: str = "群聊的聊天记录" + brief: str = "[聊天记录]" + source: str = "聊天记录" + preview: List[str] = [] + summary: str = "查看x条转发消息" + + +class Forward(MessageComponent): + """合并转发。""" + type: str = "Forward" + """消息组件类型。""" + display: ForwardMessageDiaplay + """显示信息""" + node_list: List[ForwardMessageNode] + """转发消息节点列表。""" + def __init__(self, *args, **kwargs): + if len(args) == 1: + self.node_list = args[0] + super().__init__(**kwargs) + super().__init__(*args, **kwargs) + + def __str__(self): + return '[聊天记录]' + + +def text_to_image(text: str) -> MessageComponent: + """将文本转换成图片""" + # 检查temp文件夹是否存在 + if not os.path.exists('temp'): + os.mkdir('temp') + img_path = text2img.text_to_image(text_str=text, save_as='temp/{}.png'.format(int(time.time()))) + + compressed_path, size = text2img.compress_image(img_path, outfile="temp/{}_compressed.png".format(int(time.time()))) + # 读取图片,转换成base64 + with open(compressed_path, 'rb') as f: + img = f.read() + + b64 = base64.b64encode(img) + + # 删除图片 + os.remove(img_path) + + # 判断compressed_path是否存在 + if os.path.exists(compressed_path): + os.remove(compressed_path) + # 返回图片 + return Image(base64=b64.decode('utf-8')) + + +def check_text(text: str) -> list: + """检查文本是否为长消息,并转换成该使用的消息链组件""" + if not hasattr(config, 'blob_message_threshold'): + return [text] + + if len(text) > config.blob_message_threshold: + if not hasattr(config, 'blob_message_strategy'): + raise AttributeError('未定义长消息处理策略') + + # logging.info("长消息: {}".format(text)) + if config.blob_message_strategy == 'image': + # 转换成图片 + return [text_to_image(text)] + elif config.blob_message_strategy == 'forward': + # 敏感词屏蔽 + text = context.get_qqbot_manager().reply_filter.process(text) + + # 包装转发消息 + display = ForwardMessageDiaplay( + title='群聊的聊天记录', + brief='[聊天记录]', + source='聊天记录', + preview=["bot: "+text], + summary="查看1条转发消息" + ) + + node = ForwardMessageNode( + sender_id=config.mirai_http_api_config['qq'], + sender_name='bot', + message_chain=MessageChain([text]) + ) + + forward = Forward( + display=display, + node_list=[node] + ) + + return [forward] + else: + return [text] \ No newline at end of file diff --git a/pkg/qqbot/message.py b/pkg/qqbot/message.py index 864fda20..e6106df1 100644 --- a/pkg/qqbot/message.py +++ b/pkg/qqbot/message.py @@ -7,6 +7,7 @@ import pkg.openai.session import pkg.plugin.host as plugin_host import pkg.plugin.models as plugin_models +import pkg.qqbot.blob as blob def handle_exception(notify_admin: str = "", set_reply: str = "") -> list: @@ -63,7 +64,7 @@ def process_normal_message(text_message: str, mgr, config, launcher_type: str, reply = event.get_return_value("reply") if not event.is_prevented_default(): - reply = [prefix + text] + reply = blob.check_text(prefix + text) except openai.error.APIConnectionError as e: err_msg = str(e) if err_msg.__contains__('Error communicating with OpenAI'): diff --git a/pkg/qqbot/process.py b/pkg/qqbot/process.py index c1e39af3..0f34ed1b 100644 --- a/pkg/qqbot/process.py +++ b/pkg/qqbot/process.py @@ -153,7 +153,7 @@ def process_message(launcher_type: str, launcher_id: int, text_message: str, mes "..." if len(reply[0]) > 100 else ""))) reply = [mgr.reply_filter.process(reply[0])] else: - logging.info("回复[{}]图片消息:{}".format(session_name, reply)) + logging.info("回复[{}]消息".format(session_name)) finally: processing.remove(session_name) diff --git a/pkg/utils/text2img.py b/pkg/utils/text2img.py new file mode 100644 index 00000000..d5ee88b5 --- /dev/null +++ b/pkg/utils/text2img.py @@ -0,0 +1,164 @@ +from PIL import Image, ImageDraw, ImageFont +import re +import os + +text_render_font = ImageFont.truetype("res/simhei.ttf", 32, encoding="utf-8") + + +def indexNumber(path=''): + """ + 查找字符串中数字所在串中的位置 + :param path:目标字符串 + :return:: : [['1', 16], ['2', 35], ['1', 51]] + """ + kv = [] + nums = [] + beforeDatas = re.findall('[\d]+', path) + for num in beforeDatas: + indexV = [] + times = path.count(num) + if times > 1: + if num not in nums: + indexs = re.finditer(num, path) + for index in indexs: + iV = [] + i = index.span()[0] + iV.append(num) + iV.append(i) + kv.append(iV) + nums.append(num) + else: + index = path.find(num) + indexV.append(num) + indexV.append(index) + kv.append(indexV) + # 根据数字位置排序 + indexSort = [] + resultIndex = [] + for vi in kv: + indexSort.append(vi[1]) + indexSort.sort() + for i in indexSort: + for v in kv: + if i == v[1]: + resultIndex.append(v) + return resultIndex + + +def get_size(file): + # 获取文件大小:KB + size = os.path.getsize(file) + return size / 1024 + + +def get_outfile(infile, outfile): + if outfile: + return outfile + dir, suffix = os.path.splitext(infile) + outfile = '{}-out{}'.format(dir, suffix) + return outfile + + +def compress_image(infile, outfile='', kb=100, step=20, quality=90): + """不改变图片尺寸压缩到指定大小 + :param infile: 压缩源文件 + :param outfile: 压缩文件保存地址 + :param mb: 压缩目标,KB + :param step: 每次调整的压缩比率 + :param quality: 初始压缩比率 + :return: 压缩文件地址,压缩文件大小 + """ + o_size = get_size(infile) + if o_size <= kb: + return infile, o_size + outfile = get_outfile(infile, outfile) + while o_size > kb: + im = Image.open(infile) + im.save(outfile, quality=quality) + if quality - step < 0: + break + quality -= step + o_size = get_size(outfile) + return outfile, get_size(outfile) + + +def text_to_image(text_str: str, save_as="temp.png", width=800): + global text_render_font + + text_str = text_str.replace("\t", " ") + + # 分行 + lines = text_str.split('\n') + + # 计算并分割 + final_lines = [] + + text_width = width-80 + for line in lines: + # 如果长了就分割 + line_width = text_render_font.getlength(line) + if line_width < text_width: + final_lines.append(line) + continue + else: + rest_text = line + while True: + # 分割最前面的一行 + point = int(len(rest_text) * (text_width / line_width)) + + # 检查断点是否在数字中间 + numbers = indexNumber(rest_text) + + for number in numbers: + if number[1] < point < number[1] + len(number[0]) and number[1] != 0: + point = number[1] + break + + final_lines.append(rest_text[:point]) + rest_text = rest_text[point:] + line_width = text_render_font.getlength(rest_text) + if line_width < text_width: + final_lines.append(rest_text) + break + else: + continue + # 准备画布 + img = Image.new('RGBA', (width, max(280, len(final_lines) * 35 + 45)), (255, 255, 255, 255)) + draw = ImageDraw.Draw(img, mode='RGBA') + + + # 绘制正文 + line_number = 0 + offset_x = 20 + offset_y = 30 + for final_line in final_lines: + draw.text((offset_x, offset_y + 35 * line_number), final_line, fill=(0, 0, 0), font=text_render_font) + # 遍历此行,检查是否有emoji + idx_in_line = 0 + for ch in final_line: + # if self.is_emoji(ch): + # emoji_img_valid = ensure_emoji(hex(ord(ch))[2:]) + # if emoji_img_valid: # emoji图像可用,绘制到指定位置 + # emoji_image = Image.open("emojis/{}.png".format(hex(ord(ch))[2:]), mode='r').convert('RGBA') + # emoji_image = emoji_image.resize((32, 32)) + + # x, y = emoji_image.size + + # final_emoji_img = Image.new('RGBA', emoji_image.size, (255, 255, 255)) + # final_emoji_img.paste(emoji_image, (0, 0, x, y), emoji_image) + + # img.paste(final_emoji_img, box=(int(offset_x + idx_in_line * 32), offset_y + 35 * line_number)) + + # 检查字符占位宽 + char_code = ord(ch) + if char_code >= 127: + idx_in_line += 1 + else: + idx_in_line += 0.5 + + line_number += 1 + + + img.save(save_as) + + return save_as diff --git a/requirements.txt b/requirements.txt index 838279ea..628307bc 100644 --- a/requirements.txt +++ b/requirements.txt @@ -6,4 +6,5 @@ colorlog~=6.6.0 yiri-mirai~=0.2.6.1 websockets~=10.4 urllib3~=1.26.10 -func_timeout~=4.3.5 \ No newline at end of file +func_timeout~=4.3.5 +Pillow \ No newline at end of file diff --git a/res/simhei.ttf b/res/simhei.ttf new file mode 100644 index 00000000..5bd4687e Binary files /dev/null and b/res/simhei.ttf differ