From 377d455ec1ebe020348ee1174f431afbcb1cc549 Mon Sep 17 00:00:00 2001 From: Junyan Qin Date: Sun, 3 Aug 2025 13:08:51 +0800 Subject: [PATCH] perf: ruff format & remove `stream` params in requester --- libs/dingtalk_api/api.py | 30 +- libs/wechatpad_api/__init__.py | 2 +- libs/wechatpad_api/api/chatroom.py | 6 +- libs/wechatpad_api/api/downloadpai.py | 25 +- libs/wechatpad_api/api/friend.py | 5 - libs/wechatpad_api/api/login.py | 60 +-- libs/wechatpad_api/api/message.py | 111 ++--- libs/wechatpad_api/util/http_util.py | 48 +- .../controller/groups/pipelines/webchat.py | 20 +- pkg/core/entities.py | 4 +- pkg/pipeline/cntfilter/cntfilter.py | 2 +- pkg/pipeline/process/handlers/chat.py | 5 +- pkg/pipeline/respback/respback.py | 4 - pkg/platform/adapter.py | 17 +- pkg/platform/sources/aiocqhttp.py | 230 ++++++---- pkg/platform/sources/dingtalk.py | 9 +- pkg/platform/sources/discord.py | 20 +- pkg/platform/sources/lark.py | 13 +- pkg/platform/sources/nakuru.py | 5 +- pkg/platform/sources/officialaccount.py | 4 +- pkg/platform/sources/qqbotpy.py | 2 +- pkg/platform/sources/qqofficial.py | 9 +- pkg/platform/sources/slack.py | 8 +- pkg/platform/sources/telegram.py | 3 +- pkg/platform/sources/webchat.py | 17 +- pkg/platform/sources/wechatpad.py | 431 +++++++----------- pkg/platform/sources/wecom.py | 6 +- pkg/platform/sources/wecomcs.py | 6 +- pkg/platform/types/message.py | 17 +- pkg/provider/entities.py | 4 +- pkg/provider/modelmgr/requester.py | 17 +- pkg/provider/modelmgr/requesters/chatcmpl.py | 78 ++-- .../modelmgr/requesters/giteeaichatcmpl.py | 99 ++-- .../modelmgr/requesters/modelscopechatcmpl.py | 96 ++-- .../modelmgr/requesters/ppiochatcmpl.py | 111 +++-- pkg/provider/runners/difysvapi.py | 4 +- pkg/provider/runners/localagent.py | 3 +- pkg/utils/image.py | 8 +- pkg/utils/importutil.py | 2 +- 39 files changed, 685 insertions(+), 856 deletions(-) diff --git a/libs/dingtalk_api/api.py b/libs/dingtalk_api/api.py index d1c7065f..3d483a3a 100644 --- a/libs/dingtalk_api/api.py +++ b/libs/dingtalk_api/api.py @@ -3,7 +3,6 @@ import json import time from typing import Callable import dingtalk_stream # type: ignore -from dingtalk_stream import AckMessage, ChatbotHandler, CallbackHandler, CallbackMessage, ChatbotMessage, AICardReplier from .EchoHandler import EchoTextHandler from .dingtalkevent import DingTalkEvent import httpx @@ -254,24 +253,23 @@ class DingTalkClient: await self.logger.error(f'failed to send proactive massage to group: {traceback.format_exc()}') raise Exception(f'failed to send proactive massage to group: {traceback.format_exc()}') - async def create_and_card(self, temp_card_id: str, incoming_message: dingtalk_stream.ChatbotMessage,quote_origin:bool=False): - content_key = "content" - card_data = {content_key: ""} + async def create_and_card( + self, temp_card_id: str, incoming_message: dingtalk_stream.ChatbotMessage, quote_origin: bool = False + ): + content_key = 'content' + card_data = {content_key: ''} - card_instance = dingtalk_stream.AICardReplier( - self.client, incoming_message - ) + card_instance = dingtalk_stream.AICardReplier(self.client, incoming_message) # print(card_instance) # 先投放卡片: https://open.dingtalk.com/document/orgapp/create-and-deliver-cards card_instance_id = await card_instance.async_create_and_deliver_card( - temp_card_id, card_data, + temp_card_id, + card_data, ) - return card_instance,card_instance_id + return card_instance, card_instance_id - async def send_card_message(self, - card_instance, - card_instance_id: str,content: str,is_final: bool): - content_key = "content" + async def send_card_message(self, card_instance, card_instance_id: str, content: str, is_final: bool): + content_key = 'content' try: await card_instance.async_streaming( card_instance_id, @@ -286,16 +284,12 @@ class DingTalkClient: await card_instance.async_streaming( card_instance_id, content_key=content_key, - content_value="", + content_value='', append=False, finished=is_final, failed=True, ) - - - - async def start(self): """启动 WebSocket 连接,监听消息""" await self.client.start() diff --git a/libs/wechatpad_api/__init__.py b/libs/wechatpad_api/__init__.py index 23c23fb2..9ac533f7 100644 --- a/libs/wechatpad_api/__init__.py +++ b/libs/wechatpad_api/__init__.py @@ -1 +1 @@ -from .client import WeChatPadClient \ No newline at end of file +from .client import WeChatPadClient as WeChatPadClient diff --git a/libs/wechatpad_api/api/chatroom.py b/libs/wechatpad_api/api/chatroom.py index a7af207c..2d9281a2 100644 --- a/libs/wechatpad_api/api/chatroom.py +++ b/libs/wechatpad_api/api/chatroom.py @@ -1,4 +1,4 @@ -from libs.wechatpad_api.util.http_util import async_request, post_json +from libs.wechatpad_api.util.http_util import post_json class ChatRoomApi: @@ -7,8 +7,6 @@ class ChatRoomApi: self.token = token def get_chatroom_member_detail(self, chatroom_name): - params = { - "ChatRoomName": chatroom_name - } + params = {'ChatRoomName': chatroom_name} url = self.base_url + '/group/GetChatroomMemberDetail' return post_json(url, token=self.token, data=params) diff --git a/libs/wechatpad_api/api/downloadpai.py b/libs/wechatpad_api/api/downloadpai.py index a82a5674..2d45fac6 100644 --- a/libs/wechatpad_api/api/downloadpai.py +++ b/libs/wechatpad_api/api/downloadpai.py @@ -1,32 +1,23 @@ -from libs.wechatpad_api.util.http_util import async_request, post_json +from libs.wechatpad_api.util.http_util import post_json import httpx import base64 + class DownloadApi: def __init__(self, base_url, token): self.base_url = base_url self.token = token def send_download(self, aeskey, file_type, file_url): - json_data = { - "AesKey": aeskey, - "FileType": file_type, - "FileURL": file_url - } - url = self.base_url + "/message/SendCdnDownload" + json_data = {'AesKey': aeskey, 'FileType': file_type, 'FileURL': file_url} + url = self.base_url + '/message/SendCdnDownload' return post_json(url, token=self.token, data=json_data) - def get_msg_voice(self,buf_id, length, new_msgid): - json_data = { - "Bufid": buf_id, - "Length": length, - "NewMsgId": new_msgid, - "ToUserName": "" - } - url = self.base_url + "/message/GetMsgVoice" + def get_msg_voice(self, buf_id, length, new_msgid): + json_data = {'Bufid': buf_id, 'Length': length, 'NewMsgId': new_msgid, 'ToUserName': ''} + url = self.base_url + '/message/GetMsgVoice' return post_json(url, token=self.token, data=json_data) - async def download_url_to_base64(self, download_url): async with httpx.AsyncClient() as client: response = await client.get(download_url) @@ -36,4 +27,4 @@ class DownloadApi: base64_str = base64.b64encode(file_bytes).decode('utf-8') # 返回字符串格式 return base64_str else: - raise Exception('获取文件失败') \ No newline at end of file + raise Exception('获取文件失败') diff --git a/libs/wechatpad_api/api/friend.py b/libs/wechatpad_api/api/friend.py index 00701a5d..a7a448aa 100644 --- a/libs/wechatpad_api/api/friend.py +++ b/libs/wechatpad_api/api/friend.py @@ -1,11 +1,6 @@ -from libs.wechatpad_api.util.http_util import post_json,async_request -from typing import List, Dict, Any, Optional - - class FriendApi: """联系人API类,处理所有与联系人相关的操作""" def __init__(self, base_url: str, token: str): self.base_url = base_url self.token = token - diff --git a/libs/wechatpad_api/api/login.py b/libs/wechatpad_api/api/login.py index 142a3c85..4aa4ae8d 100644 --- a/libs/wechatpad_api/api/login.py +++ b/libs/wechatpad_api/api/login.py @@ -1,37 +1,34 @@ -from libs.wechatpad_api.util.http_util import async_request,post_json,get_json +from libs.wechatpad_api.util.http_util import post_json, get_json class LoginApi: def __init__(self, base_url: str, token: str = None, admin_key: str = None): - ''' + """ Args: base_url: 原始路径 token: token admin_key: 管理员key - ''' + """ self.base_url = base_url self.token = token # self.admin_key = admin_key - def get_token(self, admin_key, day: int=365): + def get_token(self, admin_key, day: int = 365): # 获取普通token - url = f"{self.base_url}/admin/GenAuthKey1" - json_data = { - "Count": 1, - "Days": day - } + url = f'{self.base_url}/admin/GenAuthKey1' + json_data = {'Count': 1, 'Days': day} return post_json(base_url=url, token=admin_key, data=json_data) - def get_login_qr(self, Proxy: str = ""): - ''' + def get_login_qr(self, Proxy: str = ''): + """ Args: Proxy:异地使用时代理 Returns:json数据 - ''' + """ """ { @@ -49,54 +46,37 @@ class LoginApi: } """ - #获取登录二维码 - url = f"{self.base_url}/login/GetLoginQrCodeNew" + # 获取登录二维码 + url = f'{self.base_url}/login/GetLoginQrCodeNew' check = False - if Proxy != "": + if Proxy != '': check = True - json_data = { - "Check": check, - "Proxy": Proxy - } + json_data = {'Check': check, 'Proxy': Proxy} return post_json(base_url=url, token=self.token, data=json_data) - def get_login_status(self): # 获取登录状态 url = f'{self.base_url}/login/GetLoginStatus' return get_json(base_url=url, token=self.token) - - def logout(self): # 退出登录 url = f'{self.base_url}/login/LogOut' return post_json(base_url=url, token=self.token) - - - - def wake_up_login(self, Proxy: str = ""): + def wake_up_login(self, Proxy: str = ''): # 唤醒登录 url = f'{self.base_url}/login/WakeUpLogin' check = False - if Proxy != "": + if Proxy != '': check = True - json_data = { - "Check": check, - "Proxy": "" - } + json_data = {'Check': check, 'Proxy': ''} return post_json(base_url=url, token=self.token, data=json_data) - - - def login(self,admin_key): + def login(self, admin_key): login_status = self.get_login_status() - if login_status["Code"] == 300 and login_status["Text"] == "你已退出微信": - print("token已经失效,重新获取") + if login_status['Code'] == 300 and login_status['Text'] == '你已退出微信': + print('token已经失效,重新获取') token_data = self.get_token(admin_key) - self.token = token_data["Data"][0] - - - + self.token = token_data['Data'][0] diff --git a/libs/wechatpad_api/api/message.py b/libs/wechatpad_api/api/message.py index 2089ce96..cca76313 100644 --- a/libs/wechatpad_api/api/message.py +++ b/libs/wechatpad_api/api/message.py @@ -1,5 +1,4 @@ - -from libs.wechatpad_api.util.http_util import async_request, post_json +from libs.wechatpad_api.util.http_util import post_json class MessageApi: @@ -7,8 +6,8 @@ class MessageApi: self.base_url = base_url self.token = token - def post_text(self, to_wxid, content, ats: list= []): - ''' + def post_text(self, to_wxid, content, ats: list = []): + """ Args: app_id: 微信id @@ -18,106 +17,64 @@ class MessageApi: Returns: - ''' - url = self.base_url + "/message/SendTextMessage" + """ + url = self.base_url + '/message/SendTextMessage' """发送文字消息""" json_data = { - "MsgItem": [ - { - "AtWxIDList": ats, - "ImageContent": "", - "MsgType": 0, - "TextContent": content, - "ToUserName": to_wxid - } - ] - } - return post_json(base_url=url, token=self.token, data=json_data) + 'MsgItem': [ + {'AtWxIDList': ats, 'ImageContent': '', 'MsgType': 0, 'TextContent': content, 'ToUserName': to_wxid} + ] + } + return post_json(base_url=url, token=self.token, data=json_data) - - - - def post_image(self, to_wxid, img_url, ats: list= []): + def post_image(self, to_wxid, img_url, ats: list = []): """发送图片消息""" # 这里好像可以尝试发送多个暂时未测试 json_data = { - "MsgItem": [ - { - "AtWxIDList": ats, - "ImageContent": img_url, - "MsgType": 0, - "TextContent": '', - "ToUserName": to_wxid - } + 'MsgItem': [ + {'AtWxIDList': ats, 'ImageContent': img_url, 'MsgType': 0, 'TextContent': '', 'ToUserName': to_wxid} ] } - url = self.base_url + "/message/SendImageMessage" + url = self.base_url + '/message/SendImageMessage' return post_json(base_url=url, token=self.token, data=json_data) def post_voice(self, to_wxid, voice_data, voice_forma, voice_duration): """发送语音消息""" json_data = { - "ToUserName": to_wxid, - "VoiceData": voice_data, - "VoiceFormat": voice_forma, - "VoiceSecond": voice_duration + 'ToUserName': to_wxid, + 'VoiceData': voice_data, + 'VoiceFormat': voice_forma, + 'VoiceSecond': voice_duration, } - url = self.base_url + "/message/SendVoice" + url = self.base_url + '/message/SendVoice' return post_json(base_url=url, token=self.token, data=json_data) - - - - def post_name_card(self, alias, to_wxid, nick_name, name_card_wxid, flag): """发送名片消息""" param = { - "CardAlias": alias, - "CardFlag": flag, - "CardNickName": nick_name, - "CardWxId": name_card_wxid, - "ToUserName": to_wxid + 'CardAlias': alias, + 'CardFlag': flag, + 'CardNickName': nick_name, + 'CardWxId': name_card_wxid, + 'ToUserName': to_wxid, } - url = f"{self.base_url}/message/ShareCardMessage" + url = f'{self.base_url}/message/ShareCardMessage' return post_json(base_url=url, token=self.token, data=param) - def post_emoji(self, to_wxid, emoji_md5, emoji_size:int=0): + def post_emoji(self, to_wxid, emoji_md5, emoji_size: int = 0): """发送emoji消息""" - json_data = { - "EmojiList": [ - { - "EmojiMd5": emoji_md5, - "EmojiSize": emoji_size, - "ToUserName": to_wxid - } - ] - } - url = f"{self.base_url}/message/SendEmojiMessage" + json_data = {'EmojiList': [{'EmojiMd5': emoji_md5, 'EmojiSize': emoji_size, 'ToUserName': to_wxid}]} + url = f'{self.base_url}/message/SendEmojiMessage' return post_json(base_url=url, token=self.token, data=json_data) - def post_app_msg(self, to_wxid,xml_data, contenttype:int=0): + def post_app_msg(self, to_wxid, xml_data, contenttype: int = 0): """发送appmsg消息""" - json_data = { - "AppList": [ - { - "ContentType": contenttype, - "ContentXML": xml_data, - "ToUserName": to_wxid - } - ] - } - url = f"{self.base_url}/message/SendAppMessage" + json_data = {'AppList': [{'ContentType': contenttype, 'ContentXML': xml_data, 'ToUserName': to_wxid}]} + url = f'{self.base_url}/message/SendAppMessage' return post_json(base_url=url, token=self.token, data=json_data) - - def revoke_msg(self, to_wxid, msg_id, new_msg_id, create_time): """撤回消息""" - param = { - "ClientMsgId": msg_id, - "CreateTime": create_time, - "NewMsgId": new_msg_id, - "ToUserName": to_wxid - } - url = f"{self.base_url}/message/RevokeMsg" - return post_json(base_url=url, token=self.token, data=param) \ No newline at end of file + param = {'ClientMsgId': msg_id, 'CreateTime': create_time, 'NewMsgId': new_msg_id, 'ToUserName': to_wxid} + url = f'{self.base_url}/message/RevokeMsg' + return post_json(base_url=url, token=self.token, data=param) diff --git a/libs/wechatpad_api/util/http_util.py b/libs/wechatpad_api/util/http_util.py index 754003e9..447c29df 100644 --- a/libs/wechatpad_api/util/http_util.py +++ b/libs/wechatpad_api/util/http_util.py @@ -1,10 +1,9 @@ import requests +import aiohttp + def post_json(base_url, token, data=None): - headers = { - 'Content-Type': 'application/json' - } - + headers = {'Content-Type': 'application/json'} url = base_url + f'?key={token}' @@ -18,14 +17,12 @@ def post_json(base_url, token, data=None): else: raise RuntimeError(response.text) except Exception as e: - print(f"http请求失败, url={url}, exception={e}") + print(f'http请求失败, url={url}, exception={e}') raise RuntimeError(str(e)) -def get_json(base_url, token): - headers = { - 'Content-Type': 'application/json' - } +def get_json(base_url, token): + headers = {'Content-Type': 'application/json'} url = base_url + f'?key={token}' @@ -39,21 +36,18 @@ def get_json(base_url, token): else: raise RuntimeError(response.text) except Exception as e: - print(f"http请求失败, url={url}, exception={e}") + print(f'http请求失败, url={url}, exception={e}') raise RuntimeError(str(e)) -import aiohttp -import asyncio - async def async_request( - base_url: str, - token_key: str, - method: str = 'POST', - params: dict = None, - # headers: dict = None, - data: dict = None, - json: dict = None + base_url: str, + token_key: str, + method: str = 'POST', + params: dict = None, + # headers: dict = None, + data: dict = None, + json: dict = None, ): """ 通用异步请求函数 @@ -67,18 +61,11 @@ async def async_request( :param json: JSON数据 :return: 响应文本 """ - headers = { - 'Content-Type': 'application/json' - } - url = f"{base_url}?key={token_key}" + headers = {'Content-Type': 'application/json'} + url = f'{base_url}?key={token_key}' async with aiohttp.ClientSession() as session: async with session.request( - method=method, - url=url, - params=params, - headers=headers, - data=data, - json=json + method=method, url=url, params=params, headers=headers, data=data, json=json ) as response: response.raise_for_status() # 如果状态码不是200,抛出异常 result = await response.json() @@ -89,4 +76,3 @@ async def async_request( # return await result # else: # raise RuntimeError("请求失败",response.text) - diff --git a/pkg/api/http/controller/groups/pipelines/webchat.py b/pkg/api/http/controller/groups/pipelines/webchat.py index f8698b01..62e5da3f 100644 --- a/pkg/api/http/controller/groups/pipelines/webchat.py +++ b/pkg/api/http/controller/groups/pipelines/webchat.py @@ -14,8 +14,9 @@ class WebChatDebugRouterGroup(group.RouterGroup): async def stream_generator(generator): async for message in generator: - yield f"data: {json.dumps({'message': message})}\n\n" - yield "data: {\"type\": \"end\"}\n\n" + yield f'data: {json.dumps({"message": message})}\n\n' + yield 'data: {"type": "end"}\n\n' + try: data = await quart.request.get_json() session_type = data.get('session_type', 'person') @@ -34,18 +35,18 @@ class WebChatDebugRouterGroup(group.RouterGroup): return self.http_status(404, -1, 'WebChat adapter not found') if is_stream: - - generator = webchat_adapter.send_webchat_message(pipeline_uuid, session_type, message_chain_obj, is_stream) - - return quart.Response( - stream_generator(generator), - mimetype='text/event-stream' + generator = webchat_adapter.send_webchat_message( + pipeline_uuid, session_type, message_chain_obj, is_stream ) + return quart.Response(stream_generator(generator), mimetype='text/event-stream') + else: # result = await webchat_adapter.send_webchat_message(pipeline_uuid, session_type, message_chain_obj) result = None - async for message in webchat_adapter.send_webchat_message(pipeline_uuid, session_type, message_chain_obj): + async for message in webchat_adapter.send_webchat_message( + pipeline_uuid, session_type, message_chain_obj + ): result = message if result is not None: return self.success( @@ -56,7 +57,6 @@ class WebChatDebugRouterGroup(group.RouterGroup): else: return self.http_status(400, -1, 'message is required') - except Exception as e: return self.http_status(500, -1, f'Internal server error: {str(e)}') diff --git a/pkg/core/entities.py b/pkg/core/entities.py index 4873d9ce..1efee3fc 100644 --- a/pkg/core/entities.py +++ b/pkg/core/entities.py @@ -87,7 +87,9 @@ class Query(pydantic.BaseModel): """使用的函数,由前置处理器阶段设置""" resp_messages: ( - typing.Optional[list[llm_entities.Message]] | typing.Optional[list[platform_message.MessageChain]] | typing.Optional[list[llm_entities.MessageChunk]] + typing.Optional[list[llm_entities.Message]] + | typing.Optional[list[platform_message.MessageChain]] + | typing.Optional[list[llm_entities.MessageChunk]] ) = [] """由Process阶段生成的回复消息对象列表""" diff --git a/pkg/pipeline/cntfilter/cntfilter.py b/pkg/pipeline/cntfilter/cntfilter.py index 0bbc5103..e035c1d0 100644 --- a/pkg/pipeline/cntfilter/cntfilter.py +++ b/pkg/pipeline/cntfilter/cntfilter.py @@ -67,7 +67,7 @@ class ContentFilterStage(stage.PipelineStage): if query.pipeline_config['safety']['content-filter']['scope'] == 'output-msg': return entities.StageProcessResult(result_type=entities.ResultType.CONTINUE, new_query=query) if not message.strip(): - return entities.StageProcessResult(result_type=entities.ResultType.CONTINUE, new_query=query) + return entities.StageProcessResult(result_type=entities.ResultType.CONTINUE, new_query=query) else: for filter in self.filter_chain: if filter_entities.EnableStage.PRE in filter.enable_stages: diff --git a/pkg/pipeline/process/handlers/chat.py b/pkg/pipeline/process/handlers/chat.py index 483dd0b7..a81d8e3f 100644 --- a/pkg/pipeline/process/handlers/chat.py +++ b/pkg/pipeline/process/handlers/chat.py @@ -1,7 +1,6 @@ from __future__ import annotations import uuid -from itertools import accumulate import typing import traceback @@ -82,9 +81,7 @@ class ChatMessageHandler(handler.MessageHandler): query.resp_message_chain.pop() query.resp_messages.append(result) - self.ap.logger.info( - f'对话({query.query_id})流式响应: {self.cut_str(result.readable_str())}' - ) + self.ap.logger.info(f'对话({query.query_id})流式响应: {self.cut_str(result.readable_str())}') if result.content is not None: text_length += len(result.content) diff --git a/pkg/pipeline/respback/respback.py b/pkg/pipeline/respback/respback.py index f4153218..c7824856 100644 --- a/pkg/pipeline/respback/respback.py +++ b/pkg/pipeline/respback/respback.py @@ -3,12 +3,10 @@ from __future__ import annotations import random import asyncio -from typing_inspection.typing_objects import is_final from ...platform.types import events as platform_events from ...platform.types import message as platform_message -from ...provider import entities as llm_entities from .. import stage, entities from ...core import entities as core_entities @@ -56,6 +54,4 @@ class SendResponseBackStage(stage.PipelineStage): quote_origin=quote_origin, ) - - return entities.StageProcessResult(result_type=entities.ResultType.CONTINUE, new_query=query) diff --git a/pkg/platform/adapter.py b/pkg/platform/adapter.py index e4369efb..3412be3c 100644 --- a/pkg/platform/adapter.py +++ b/pkg/platform/adapter.py @@ -25,7 +25,6 @@ class MessagePlatformAdapter(metaclass=abc.ABCMeta): logger: EventLogger - def __init__(self, config: dict, ap: app.Application, logger: EventLogger): """初始化适配器 @@ -80,12 +79,12 @@ class MessagePlatformAdapter(metaclass=abc.ABCMeta): """ raise NotImplementedError - async def create_message_card(self, message_id:typing.Type[str,int], event:platform_events.MessageEvent) -> bool: + async def create_message_card(self, message_id: typing.Type[str, int], event: platform_events.MessageEvent) -> bool: """创建卡片消息 Args: message_id (str): 消息ID event (platform_events.MessageEvent): 消息源事件 - """ + """ return False async def is_muted(self, group_id: int) -> bool: @@ -94,8 +93,8 @@ class MessagePlatformAdapter(metaclass=abc.ABCMeta): def register_listener( self, - event_type: typing.Type[platform_message.Event], - callback: typing.Callable[[platform_message.Event, MessagePlatformAdapter], None], + event_type: typing.Type[platform_events.Event], + callback: typing.Callable[[platform_events.Event, MessagePlatformAdapter], None], ): """注册事件监听器 @@ -107,8 +106,8 @@ class MessagePlatformAdapter(metaclass=abc.ABCMeta): def unregister_listener( self, - event_type: typing.Type[platform_message.Event], - callback: typing.Callable[[platform_message.Event, MessagePlatformAdapter], None], + event_type: typing.Type[platform_events.Event], + callback: typing.Callable[[platform_events.Event, MessagePlatformAdapter], None], ): """注销事件监听器 @@ -167,7 +166,7 @@ class EventConverter: """事件转换器基类""" @staticmethod - def yiri2target(event: typing.Type[platform_message.Event]): + def yiri2target(event: typing.Type[platform_events.Event]): """将源平台事件转换为目标平台事件 Args: @@ -179,7 +178,7 @@ class EventConverter: raise NotImplementedError @staticmethod - def target2yiri(event: typing.Any) -> platform_message.Event: + def target2yiri(event: typing.Any) -> platform_events.Event: """将目标平台事件的调用参数转换为源平台的事件参数对象 Args: diff --git a/pkg/platform/sources/aiocqhttp.py b/pkg/platform/sources/aiocqhttp.py index 3f3ef512..c75d2c77 100644 --- a/pkg/platform/sources/aiocqhttp.py +++ b/pkg/platform/sources/aiocqhttp.py @@ -16,7 +16,6 @@ from ..logger import EventLogger class AiocqhttpMessageConverter(adapter.MessageConverter): - @staticmethod async def yiri2target( message_chain: platform_message.MessageChain, @@ -62,87 +61,170 @@ class AiocqhttpMessageConverter(adapter.MessageConverter): for node in msg.node_list: msg_list.extend((await AiocqhttpMessageConverter.yiri2target(node.message_chain))[0]) elif isinstance(msg, platform_message.File): - msg_list.append({"type":"file", "data":{'file': msg.url, "name": msg.name}}) + msg_list.append({'type': 'file', 'data': {'file': msg.url, 'name': msg.name}}) elif isinstance(msg, platform_message.Face): - if msg.face_type=='face': + if msg.face_type == 'face': msg_list.append(aiocqhttp.MessageSegment.face(msg.face_id)) - elif msg.face_type=='rps': + elif msg.face_type == 'rps': msg_list.append(aiocqhttp.MessageSegment.rps()) - elif msg.face_type=='dice': + elif msg.face_type == 'dice': msg_list.append(aiocqhttp.MessageSegment.dice()) - else: msg_list.append(aiocqhttp.MessageSegment.text(str(msg))) return msg_list, msg_id, msg_time @staticmethod - async def target2yiri(message: str, message_id: int = -1,bot=None): + async def target2yiri(message: str, message_id: int = -1, bot=None): print(message) message = aiocqhttp.Message(message) def get_face_name(face_id): face_code_dict = { - "2": '好色', - "4": "得意", "5": "流泪", "8": "睡", "9": "大哭", "10": "尴尬", "12": "调皮", "14": "微笑", "16": "酷", - "21": "可爱", - "23": "傲慢", "24": "饥饿", "25": "困", "26": "惊恐", "27": "流汗", "28": "憨笑", "29": "悠闲", - "30": "奋斗", - "32": "疑问", "33": "嘘", "34": "晕", "38": "敲打", "39": "再见", "41": "发抖", "42": "爱情", - "43": "跳跳", - "49": "拥抱", "53": "蛋糕", "60": "咖啡", "63": "玫瑰", "66": "爱心", "74": "太阳", "75": "月亮", - "76": "赞", - "78": "握手", "79": "胜利", "85": "飞吻", "89": "西瓜", "96": "冷汗", "97": "擦汗", "98": "抠鼻", - "99": "鼓掌", - "100": "糗大了", "101": "坏笑", "102": "左哼哼", "103": "右哼哼", "104": "哈欠", "106": "委屈", - "109": "左亲亲", - "111": "可怜", "116": "示爱", "118": "抱拳", "120": "拳头", "122": "爱你", "123": "NO", "124": "OK", - "125": "转圈", - "129": "挥手", "144": "喝彩", "147": "棒棒糖", "171": "茶", "173": "泪奔", "174": "无奈", "175": "卖萌", - "176": "小纠结", "179": "doge", "180": "惊喜", "181": "骚扰", "182": "笑哭", "183": "我最美", - "201": "点赞", - "203": "托脸", "212": "托腮", "214": "啵啵", "219": "蹭一蹭", "222": "抱抱", "227": "拍手", - "232": "佛系", - "240": "喷脸", "243": "甩头", "246": "加油抱抱", "262": "脑阔疼", "264": "捂脸", "265": "辣眼睛", - "266": "哦哟", - "267": "头秃", "268": "问号脸", "269": "暗中观察", "270": "emm", "271": "吃瓜", "272": "呵呵哒", - "273": "我酸了", - "277": "汪汪", "278": "汗", "281": "无眼笑", "282": "敬礼", "284": "面无表情", "285": "摸鱼", - "287": "哦", - "289": "睁眼", "290": "敲开心", "293": "摸锦鲤", "294": "期待", "297": "拜谢", "298": "元宝", - "299": "牛啊", - "305": "右亲亲", "306": "牛气冲天", "307": "喵喵", "314": "仔细分析", "315": "加油", "318": "崇拜", - "319": "比心", - "320": "庆祝", "322": "拒绝", "324": "吃糖", "326": "生气" + '2': '好色', + '4': '得意', + '5': '流泪', + '8': '睡', + '9': '大哭', + '10': '尴尬', + '12': '调皮', + '14': '微笑', + '16': '酷', + '21': '可爱', + '23': '傲慢', + '24': '饥饿', + '25': '困', + '26': '惊恐', + '27': '流汗', + '28': '憨笑', + '29': '悠闲', + '30': '奋斗', + '32': '疑问', + '33': '嘘', + '34': '晕', + '38': '敲打', + '39': '再见', + '41': '发抖', + '42': '爱情', + '43': '跳跳', + '49': '拥抱', + '53': '蛋糕', + '60': '咖啡', + '63': '玫瑰', + '66': '爱心', + '74': '太阳', + '75': '月亮', + '76': '赞', + '78': '握手', + '79': '胜利', + '85': '飞吻', + '89': '西瓜', + '96': '冷汗', + '97': '擦汗', + '98': '抠鼻', + '99': '鼓掌', + '100': '糗大了', + '101': '坏笑', + '102': '左哼哼', + '103': '右哼哼', + '104': '哈欠', + '106': '委屈', + '109': '左亲亲', + '111': '可怜', + '116': '示爱', + '118': '抱拳', + '120': '拳头', + '122': '爱你', + '123': 'NO', + '124': 'OK', + '125': '转圈', + '129': '挥手', + '144': '喝彩', + '147': '棒棒糖', + '171': '茶', + '173': '泪奔', + '174': '无奈', + '175': '卖萌', + '176': '小纠结', + '179': 'doge', + '180': '惊喜', + '181': '骚扰', + '182': '笑哭', + '183': '我最美', + '201': '点赞', + '203': '托脸', + '212': '托腮', + '214': '啵啵', + '219': '蹭一蹭', + '222': '抱抱', + '227': '拍手', + '232': '佛系', + '240': '喷脸', + '243': '甩头', + '246': '加油抱抱', + '262': '脑阔疼', + '264': '捂脸', + '265': '辣眼睛', + '266': '哦哟', + '267': '头秃', + '268': '问号脸', + '269': '暗中观察', + '270': 'emm', + '271': '吃瓜', + '272': '呵呵哒', + '273': '我酸了', + '277': '汪汪', + '278': '汗', + '281': '无眼笑', + '282': '敬礼', + '284': '面无表情', + '285': '摸鱼', + '287': '哦', + '289': '睁眼', + '290': '敲开心', + '293': '摸锦鲤', + '294': '期待', + '297': '拜谢', + '298': '元宝', + '299': '牛啊', + '305': '右亲亲', + '306': '牛气冲天', + '307': '喵喵', + '314': '仔细分析', + '315': '加油', + '318': '崇拜', + '319': '比心', + '320': '庆祝', + '322': '拒绝', + '324': '吃糖', + '326': '生气', } - return face_code_dict.get(face_id,'') + return face_code_dict.get(face_id, '') async def process_message_data(msg_data, reply_list): - if msg_data["type"] == "image": - image_base64, image_format = await image.qq_image_url_to_base64(msg_data["data"]['url']) - reply_list.append( - platform_message.Image(base64=f'data:image/{image_format};base64,{image_base64}')) + if msg_data['type'] == 'image': + image_base64, image_format = await image.qq_image_url_to_base64(msg_data['data']['url']) + reply_list.append(platform_message.Image(base64=f'data:image/{image_format};base64,{image_base64}')) - elif msg_data["type"] == "text": - reply_list.append(platform_message.Plain(text=msg_data["data"]["text"])) + elif msg_data['type'] == 'text': + reply_list.append(platform_message.Plain(text=msg_data['data']['text'])) - elif msg_data["type"] == "forward": # 这里来应该传入转发消息组,暂时传入qoute - for forward_msg_datas in msg_data["data"]["content"]: - for forward_msg_data in forward_msg_datas["message"]: + elif msg_data['type'] == 'forward': # 这里来应该传入转发消息组,暂时传入qoute + for forward_msg_datas in msg_data['data']['content']: + for forward_msg_data in forward_msg_datas['message']: await process_message_data(forward_msg_data, reply_list) - elif msg_data["type"] == "at": - if msg_data["data"]['qq'] == 'all': + elif msg_data['type'] == 'at': + if msg_data['data']['qq'] == 'all': reply_list.append(platform_message.AtAll()) else: reply_list.append( platform_message.At( - target=msg_data["data"]['qq'], + target=msg_data['data']['qq'], ) ) - yiri_msg_list = [] yiri_msg_list.append(platform_message.Source(id=message_id, time=datetime.datetime.now())) @@ -161,10 +243,10 @@ class AiocqhttpMessageConverter(adapter.MessageConverter): elif msg.type == 'text': yiri_msg_list.append(platform_message.Plain(text=msg.data['text'])) elif msg.type == 'image': - emoji_id = msg.data.get("emoji_package_id", None) + emoji_id = msg.data.get('emoji_package_id', None) if emoji_id: face_id = emoji_id - face_name = msg.data.get("summary", '') + face_name = msg.data.get('summary', '') image_msg = platform_message.Face(face_id=face_id, face_name=face_name) else: image_base64, image_format = await image.qq_image_url_to_base64(msg.data['url']) @@ -178,14 +260,15 @@ class AiocqhttpMessageConverter(adapter.MessageConverter): # await process_message_data(msg_data, yiri_msg_list) pass - elif msg.type == 'reply': # 此处处理引用消息传入Qoute - msg_datas = await bot.get_msg(message_id=msg.data["id"]) + msg_datas = await bot.get_msg(message_id=msg.data['id']) - for msg_data in msg_datas["message"]: + for msg_data in msg_datas['message']: await process_message_data(msg_data, reply_list) - reply_msg = platform_message.Quote(message_id=msg.data["id"],sender_id=msg_datas["user_id"],origin=reply_list) + reply_msg = platform_message.Quote( + message_id=msg.data['id'], sender_id=msg_datas['user_id'], origin=reply_list + ) yiri_msg_list.append(reply_msg) elif msg.type == 'file': @@ -193,50 +276,36 @@ class AiocqhttpMessageConverter(adapter.MessageConverter): file_id = msg.data['file_id'] file_data = await bot.get_file(file_id=file_id) file_name = file_data.get('file_name') - file_path = file_data.get('file') + # file_path = file_data.get('file') file_url = file_data.get('file_url') file_size = file_data.get('file_size') - yiri_msg_list.append(platform_message.File(id=file_id, name=file_name,url=file_url,size=file_size)) + yiri_msg_list.append(platform_message.File(id=file_id, name=file_name, url=file_url, size=file_size)) elif msg.type == 'face': face_id = msg.data['id'] face_name = msg.data['raw']['faceText'] if not face_name: face_name = get_face_name(face_id) - yiri_msg_list.append(platform_message.Face(face_id=int(face_id),face_name=face_name.replace('/',''))) + yiri_msg_list.append(platform_message.Face(face_id=int(face_id), face_name=face_name.replace('/', ''))) elif msg.type == 'rps': face_id = msg.data['result'] - yiri_msg_list.append(platform_message.Face(face_type="rps",face_id=int(face_id),face_name='猜拳')) + yiri_msg_list.append(platform_message.Face(face_type='rps', face_id=int(face_id), face_name='猜拳')) elif msg.type == 'dice': face_id = msg.data['result'] - yiri_msg_list.append(platform_message.Face(face_type='dice',face_id=int(face_id),face_name='骰子')) - - - - - - - - + yiri_msg_list.append(platform_message.Face(face_type='dice', face_id=int(face_id), face_name='骰子')) chain = platform_message.MessageChain(yiri_msg_list) return chain - - - - class AiocqhttpEventConverter(adapter.EventConverter): @staticmethod async def yiri2target(event: platform_events.MessageEvent, bot_account_id: int): return event.source_platform_object @staticmethod - async def target2yiri(event: aiocqhttp.Event,bot=None): - yiri_chain = await AiocqhttpMessageConverter.target2yiri(event.message, event.message_id,bot) - - + async def target2yiri(event: aiocqhttp.Event, bot=None): + yiri_chain = await AiocqhttpMessageConverter.target2yiri(event.message, event.message_id, bot) if event.message_type == 'group': permission = 'MEMBER' @@ -316,7 +385,6 @@ class AiocqhttpAdapter(adapter.MessagePlatformAdapter): aiocq_msg = (await AiocqhttpMessageConverter.yiri2target(message))[0] if target_type == 'group': - await self.bot.send_group_msg(group_id=int(target_id), message=aiocq_msg) elif target_type == 'person': await self.bot.send_private_msg(user_id=int(target_id), message=aiocq_msg) @@ -345,7 +413,7 @@ class AiocqhttpAdapter(adapter.MessagePlatformAdapter): async def on_message(event: aiocqhttp.Event): self.bot_account_id = event.self_id try: - return await callback(await self.event_converter.target2yiri(event,self.bot), self) + return await callback(await self.event_converter.target2yiri(event, self.bot), self) except Exception: await self.logger.error(f'Error in on_message: {traceback.format_exc()}') traceback.print_exc() diff --git a/pkg/platform/sources/dingtalk.py b/pkg/platform/sources/dingtalk.py index 187bafb0..8bd6e187 100644 --- a/pkg/platform/sources/dingtalk.py +++ b/pkg/platform/sources/dingtalk.py @@ -149,10 +149,10 @@ class DingTalkAdapter(adapter.MessagePlatformAdapter): quote_origin: bool = False, is_final: bool = False, ): - event = await DingTalkEventConverter.yiri2target( - message_source, - ) - incoming_message = event.incoming_message + # event = await DingTalkEventConverter.yiri2target( + # message_source, + # ) + # incoming_message = event.incoming_message # msg_id = incoming_message.message_id @@ -205,7 +205,6 @@ class DingTalkAdapter(adapter.MessagePlatformAdapter): self.bot.on_message('GroupMessage')(on_message) async def run_async(self): - await self.bot.start() async def kill(self) -> bool: diff --git a/pkg/platform/sources/discord.py b/pkg/platform/sources/discord.py index 4f5cac28..6cc09a72 100644 --- a/pkg/platform/sources/discord.py +++ b/pkg/platform/sources/discord.py @@ -8,7 +8,6 @@ import base64 import uuid import os import datetime -import io import aiohttp @@ -78,10 +77,10 @@ class DiscordMessageConverter(adapter.MessageConverter): # 确保路径没有空字节 clean_path = ele.path.replace('\x00', '') clean_path = os.path.abspath(clean_path) - + if not os.path.exists(clean_path): continue # 跳过不存在的文件 - + try: with open(clean_path, 'rb') as f: image_bytes = f.read() @@ -101,12 +100,13 @@ class DiscordMessageConverter(adapter.MessageConverter): filename = f'{uuid.uuid4()}.webp' # 默认保持PNG except Exception as e: - print(f"Error reading image file {clean_path}: {e}") + print(f'Error reading image file {clean_path}: {e}') continue # 跳过读取失败的文件 if image_bytes: # 使用BytesIO创建文件对象,避免路径问题 import io + image_files.append(discord.File(fp=io.BytesIO(image_bytes), filename=filename)) elif isinstance(ele, platform_message.Plain): text_string += ele.text @@ -261,25 +261,25 @@ class DiscordAdapter(adapter.MessagePlatformAdapter): async def send_message(self, target_type: str, target_id: str, message: platform_message.MessageChain): msg_to_send, image_files = await self.message_converter.yiri2target(message) - + try: # 获取频道对象 channel = self.bot.get_channel(int(target_id)) if channel is None: # 如果本地缓存中没有,尝试从API获取 channel = await self.bot.fetch_channel(int(target_id)) - + args = { 'content': msg_to_send, } - + if len(image_files) > 0: args['files'] = image_files - + await channel.send(**args) - + except Exception as e: - await self.logger.error(f"Discord send_message failed: {e}") + await self.logger.error(f'Discord send_message failed: {e}') raise e async def reply_message( diff --git a/pkg/platform/sources/lark.py b/pkg/platform/sources/lark.py index dcafbf9f..5369be00 100644 --- a/pkg/platform/sources/lark.py +++ b/pkg/platform/sources/lark.py @@ -9,7 +9,6 @@ import re import base64 import uuid import json -import time import datetime import hashlib from Crypto.Cipher import AES @@ -394,14 +393,14 @@ class LarkAdapter(adapter.MessagePlatformAdapter): if 'im.message.receive_v1' == type: try: event = await self.event_converter.target2yiri(p2v1, self.api_client) - except Exception as e: + except Exception: await self.logger.error(f'Error in lark callback: {traceback.format_exc()}') if event.__class__ in self.listeners: await self.listeners[event.__class__](event, self) return {'code': 200, 'message': 'ok'} - except Exception as e: + except Exception: await self.logger.error(f'Error in lark callback: {traceback.format_exc()}') return {'code': 500, 'message': 'error'} @@ -559,10 +558,10 @@ class LarkAdapter(adapter.MessagePlatformAdapter): elif ele['tag'] == 'md': text_message += ele['text'] - content = { - 'type': 'card_json', - 'data': {'card_id': self.card_id_dict[message_id], 'elements': {'content': text_message}}, - } + # content = { + # 'type': 'card_json', + # 'data': {'card_id': self.card_id_dict[message_id], 'elements': {'content': text_message}}, + # } request: ContentCardElementRequest = ( ContentCardElementRequest.builder() diff --git a/pkg/platform/sources/nakuru.py b/pkg/platform/sources/nakuru.py index 389a2db1..16ad54db 100644 --- a/pkg/platform/sources/nakuru.py +++ b/pkg/platform/sources/nakuru.py @@ -72,8 +72,9 @@ class NakuruProjectMessageConverter(adapter_model.MessageConverter): content=content_list, ) nakuru_forward_node_list.append(nakuru_forward_node) - except Exception as e: + except Exception: import traceback + traceback.print_exc() nakuru_msg_list.append(nakuru_forward_node_list) @@ -276,7 +277,7 @@ class NakuruAdapter(adapter_model.MessagePlatformAdapter): # 注册监听器 self.bot.receiver(source_cls.__name__)(listener_wrapper) except Exception as e: - self.logger.error(f"Error in nakuru register_listener: {traceback.format_exc()}") + self.logger.error(f'Error in nakuru register_listener: {traceback.format_exc()}') raise e def unregister_listener( diff --git a/pkg/platform/sources/officialaccount.py b/pkg/platform/sources/officialaccount.py index 030db56d..3fc1e393 100644 --- a/pkg/platform/sources/officialaccount.py +++ b/pkg/platform/sources/officialaccount.py @@ -125,8 +125,8 @@ class OfficialAccountAdapter(adapter.MessagePlatformAdapter): self.bot_account_id = event.receiver_id try: return await callback(await self.event_converter.target2yiri(event), self) - except Exception as e: - await self.logger.error(f"Error in officialaccount callback: {traceback.format_exc()}") + except Exception: + await self.logger.error(f'Error in officialaccount callback: {traceback.format_exc()}') if event_type == platform_events.FriendMessage: self.bot.on_message('text')(on_message) diff --git a/pkg/platform/sources/qqbotpy.py b/pkg/platform/sources/qqbotpy.py index 39c8dc8a..d4a4d526 100644 --- a/pkg/platform/sources/qqbotpy.py +++ b/pkg/platform/sources/qqbotpy.py @@ -501,7 +501,7 @@ class OfficialAdapter(adapter_model.MessagePlatformAdapter): for event_handler in event_handler_mapping[event_type]: setattr(self.bot, event_handler, wrapper) except Exception as e: - self.logger.error(f"Error in qqbotpy callback: {traceback.format_exc()}") + self.logger.error(f'Error in qqbotpy callback: {traceback.format_exc()}') raise e def unregister_listener( diff --git a/pkg/platform/sources/qqofficial.py b/pkg/platform/sources/qqofficial.py index c61afea4..63ab531f 100644 --- a/pkg/platform/sources/qqofficial.py +++ b/pkg/platform/sources/qqofficial.py @@ -154,10 +154,7 @@ class QQOfficialAdapter(adapter.MessagePlatformAdapter): raise ParamNotEnoughError('QQ官方机器人缺少相关配置项,请查看文档或联系管理员') self.bot = QQOfficialClient( - app_id=config['appid'], - secret=config['secret'], - token=config['token'], - logger=self.logger + app_id=config['appid'], secret=config['secret'], token=config['token'], logger=self.logger ) async def reply_message( @@ -224,8 +221,8 @@ class QQOfficialAdapter(adapter.MessagePlatformAdapter): self.bot_account_id = 'justbot' try: return await callback(await self.event_converter.target2yiri(event), self) - except Exception as e: - await self.logger.error(f"Error in qqofficial callback: {traceback.format_exc()}") + except Exception: + await self.logger.error(f'Error in qqofficial callback: {traceback.format_exc()}') if event_type == platform_events.FriendMessage: self.bot.on_message('DIRECT_MESSAGE_CREATE')(on_message) diff --git a/pkg/platform/sources/slack.py b/pkg/platform/sources/slack.py index 6dfcff59..1bd5aa2d 100644 --- a/pkg/platform/sources/slack.py +++ b/pkg/platform/sources/slack.py @@ -104,7 +104,9 @@ class SlackAdapter(adapter.MessagePlatformAdapter): if missing_keys: raise ParamNotEnoughError('Slack机器人缺少相关配置项,请查看文档或联系管理员') - self.bot = SlackClient(bot_token=self.config['bot_token'], signing_secret=self.config['signing_secret'], logger=self.logger) + self.bot = SlackClient( + bot_token=self.config['bot_token'], signing_secret=self.config['signing_secret'], logger=self.logger + ) async def reply_message( self, @@ -139,8 +141,8 @@ class SlackAdapter(adapter.MessagePlatformAdapter): self.bot_account_id = 'SlackBot' try: return await callback(await self.event_converter.target2yiri(event, self.bot), self) - except Exception as e: - await self.logger.error(f"Error in slack callback: {traceback.format_exc()}") + except Exception: + await self.logger.error(f'Error in slack callback: {traceback.format_exc()}') if event_type == platform_events.FriendMessage: self.bot.on_message('im')(on_message) diff --git a/pkg/platform/sources/telegram.py b/pkg/platform/sources/telegram.py index e021c7b7..d39bf23d 100644 --- a/pkg/platform/sources/telegram.py +++ b/pkg/platform/sources/telegram.py @@ -1,6 +1,5 @@ from __future__ import annotations -import time import telegram import telegram.ext @@ -166,7 +165,7 @@ class TelegramAdapter(adapter.MessagePlatformAdapter): lb_event = await self.event_converter.target2yiri(update, self.bot, self.bot_account_id) await self.listeners[type(lb_event)](lb_event, self) await self.is_stream_output_supported() - except Exception as e: + except Exception: await self.logger.error(f'Error in telegram callback: {traceback.format_exc()}') self.application = ApplicationBuilder().token(self.config['token']).build() diff --git a/pkg/platform/sources/webchat.py b/pkg/platform/sources/webchat.py index f7f3d964..fce28bc2 100644 --- a/pkg/platform/sources/webchat.py +++ b/pkg/platform/sources/webchat.py @@ -133,7 +133,11 @@ class WebChatAdapter(msadapter.MessagePlatformAdapter): ) # notify waiter - session = (self.webchat_group_session if isinstance(message_source, platform_events.GroupMessage) else self.webchat_person_session) + session = ( + self.webchat_group_session + if isinstance(message_source, platform_events.GroupMessage) + else self.webchat_person_session + ) if message_source.message_chain.message_id not in session.resp_waiters: # session.resp_waiters[message_source.message_chain.message_id] = asyncio.Queue() queue = session.resp_queues[message_source.message_chain.message_id] @@ -147,10 +151,8 @@ class WebChatAdapter(msadapter.MessagePlatformAdapter): # print(message_data) await queue.put(message_data) - - return message_data.model_dump() - + async def is_stream_output_supported(self) -> bool: return self.is_stream @@ -186,7 +188,10 @@ class WebChatAdapter(msadapter.MessagePlatformAdapter): await self.logger.info('WebChat调试适配器正在停止') async def send_webchat_message( - self, pipeline_uuid: str, session_type: str, message_chain_obj: typing.List[dict], + self, + pipeline_uuid: str, + session_type: str, + message_chain_obj: typing.List[dict], is_stream: bool = False, ) -> dict: self.is_stream = is_stream @@ -202,7 +207,7 @@ class WebChatAdapter(msadapter.MessagePlatformAdapter): if is_stream: use_session.resp_queues[message_id] = asyncio.Queue() - logger.debug(f"Initialized queue for message_id: {message_id}") + logger.debug(f'Initialized queue for message_id: {message_id}') use_session.get_message_list(pipeline_uuid).append( WebChatMessage( diff --git a/pkg/platform/sources/wechatpad.py b/pkg/platform/sources/wechatpad.py index fdd4a69b..5d8ec75d 100644 --- a/pkg/platform/sources/wechatpad.py +++ b/pkg/platform/sources/wechatpad.py @@ -1,5 +1,4 @@ import requests -import websockets import websocket import json import time @@ -10,53 +9,41 @@ from libs.wechatpad_api.client import WeChatPadClient import typing import asyncio import traceback -import time import re import base64 -import uuid -import json -import os import copy -import datetime import threading import quart -import aiohttp from .. import adapter -from ...pipeline.longtext.strategies import forward from ...core import app from ..types import message as platform_message from ..types import events as platform_events from ..types import entities as platform_entities -from ...utils import image from ..logger import EventLogger import xml.etree.ElementTree as ET -from typing import Optional, List, Tuple +from typing import Optional, Tuple from functools import partial import logging -class WeChatPadMessageConverter(adapter.MessageConverter): +class WeChatPadMessageConverter(adapter.MessageConverter): def __init__(self, config: dict): self.config = config - self.bot = WeChatPadClient(self.config["wechatpad_url"],self.config["token"]) - self.logger = logging.getLogger("WeChatPadMessageConverter") + self.bot = WeChatPadClient(self.config['wechatpad_url'], self.config['token']) + self.logger = logging.getLogger('WeChatPadMessageConverter') @staticmethod - async def yiri2target( - message_chain: platform_message.MessageChain - ) -> list[dict]: + async def yiri2target(message_chain: platform_message.MessageChain) -> list[dict]: content_list = [] - current_file_path = os.path.abspath(__file__) - - + # current_file_path = os.path.abspath(__file__) for component in message_chain: if isinstance(component, platform_message.At): - content_list.append({"type": "at", "target": component.target}) + content_list.append({'type': 'at', 'target': component.target}) elif isinstance(component, platform_message.Plain): - content_list.append({"type": "text", "content": component.text}) + content_list.append({'type': 'text', 'content': component.text}) elif isinstance(component, platform_message.Image): if component.url: async with httpx.AsyncClient() as client: @@ -68,15 +55,16 @@ class WeChatPadMessageConverter(adapter.MessageConverter): else: raise Exception('获取文件失败') # pass - content_list.append({"type": "image", "image": base64_str}) + content_list.append({'type': 'image', 'image': base64_str}) elif component.base64: - content_list.append({"type": "image", "image": component.base64}) + content_list.append({'type': 'image', 'image': component.base64}) elif isinstance(component, platform_message.WeChatEmoji): content_list.append( - {'type': 'WeChatEmoji', 'emoji_md5': component.emoji_md5, 'emoji_size': component.emoji_size}) + {'type': 'WeChatEmoji', 'emoji_md5': component.emoji_md5, 'emoji_size': component.emoji_size} + ) elif isinstance(component, platform_message.Voice): - content_list.append({"type": "voice", "data": component.url, "duration": component.length, "forma": 0}) + content_list.append({'type': 'voice', 'data': component.url, 'duration': component.length, 'forma': 0}) elif isinstance(component, platform_message.WeChatAppMsg): content_list.append({'type': 'WeChatAppMsg', 'app_msg': component.app_msg}) elif isinstance(component, platform_message.Forward): @@ -86,28 +74,23 @@ class WeChatPadMessageConverter(adapter.MessageConverter): return content_list - - async def target2yiri( - self, - message: dict, - bot_account_id: str - ) -> platform_message.MessageChain: + async def target2yiri(self, message: dict, bot_account_id: str) -> platform_message.MessageChain: """外部消息转平台消息""" # 数据预处理 message_list = [] ats_bot = False # 是否被@ - content = message["content"]["str"] + content = message['content']['str'] content_no_preifx = content # 群消息则去掉前缀 is_group_message = self._is_group_message(message) if is_group_message: ats_bot = self._ats_bot(message, bot_account_id) - if "@所有人" in content: + if '@所有人' in content: message_list.append(platform_message.AtAll()) elif ats_bot: message_list.append(platform_message.At(target=bot_account_id)) content_no_preifx, _ = self._extract_content_and_sender(content) - msg_type = message["msg_type"] + msg_type = message['msg_type'] # 映射消息类型到处理器方法 handler_map = { @@ -129,11 +112,7 @@ class WeChatPadMessageConverter(adapter.MessageConverter): return platform_message.MessageChain(message_list) - async def _handler_text( - self, - message: Optional[dict], - content_no_preifx: str - ) -> platform_message.MessageChain: + async def _handler_text(self, message: Optional[dict], content_no_preifx: str) -> platform_message.MessageChain: """处理文本消息 (msg_type=1)""" if message and self._is_group_message(message): pattern = r'@\S{1,20}' @@ -141,16 +120,12 @@ class WeChatPadMessageConverter(adapter.MessageConverter): return platform_message.MessageChain([platform_message.Plain(content_no_preifx)]) - async def _handler_image( - self, - message: Optional[dict], - content_no_preifx: str - ) -> platform_message.MessageChain: + async def _handler_image(self, message: Optional[dict], content_no_preifx: str) -> platform_message.MessageChain: """处理图像消息 (msg_type=3)""" try: image_xml = content_no_preifx if not image_xml: - return platform_message.MessageChain([platform_message.Unknown("[图片内容为空]")]) + return platform_message.MessageChain([platform_message.Unknown('[图片内容为空]')]) root = ET.fromstring(image_xml) # 提取img标签的属性 @@ -160,28 +135,22 @@ class WeChatPadMessageConverter(adapter.MessageConverter): cdnthumburl = img_tag.get('cdnthumburl') # cdnmidimgurl = img_tag.get('cdnmidimgurl') - image_data = self.bot.cdn_download(aeskey=aeskey, file_type=1, file_url=cdnthumburl) - if image_data["Data"]['FileData'] == '': + if image_data['Data']['FileData'] == '': image_data = self.bot.cdn_download(aeskey=aeskey, file_type=2, file_url=cdnthumburl) - base64_str = image_data["Data"]['FileData'] + base64_str = image_data['Data']['FileData'] # self.logger.info(f"data:image/png;base64,{base64_str}") - elements = [ - platform_message.Image(base64=f"data:image/png;base64,{base64_str}"), + platform_message.Image(base64=f'data:image/png;base64,{base64_str}'), # platform_message.WeChatForwardImage(xml_data=image_xml) # 微信消息转发 ] return platform_message.MessageChain(elements) except Exception as e: - self.logger.error(f"处理图片失败: {str(e)}") - return platform_message.MessageChain([platform_message.Unknown("[图片处理失败]")]) + self.logger.error(f'处理图片失败: {str(e)}') + return platform_message.MessageChain([platform_message.Unknown('[图片处理失败]')]) - async def _handler_voice( - self, - message: Optional[dict], - content_no_preifx: str - ) -> platform_message.MessageChain: + async def _handler_voice(self, message: Optional[dict], content_no_preifx: str) -> platform_message.MessageChain: """处理语音消息 (msg_type=34)""" message_List = [] try: @@ -197,39 +166,33 @@ class WeChatPadMessageConverter(adapter.MessageConverter): bufid = voicemsg.get('bufid') length = voicemsg.get('voicelength') voice_data = self.bot.get_msg_voice(buf_id=str(bufid), length=int(length), msgid=str(new_msg_id)) - audio_base64 = voice_data["Data"]['Base64'] + audio_base64 = voice_data['Data']['Base64'] # 验证语音数据有效性 if not audio_base64: - message_List.append(platform_message.Unknown(text="[语音内容为空]")) + message_List.append(platform_message.Unknown(text='[语音内容为空]')) return platform_message.MessageChain(message_List) # 转换为平台支持的语音格式(如 Silk 格式) - voice_element = platform_message.Voice( - base64=f"data:audio/silk;base64,{audio_base64}" - ) + voice_element = platform_message.Voice(base64=f'data:audio/silk;base64,{audio_base64}') message_List.append(voice_element) except KeyError as e: - self.logger.error(f"语音数据字段缺失: {str(e)}") - message_List.append(platform_message.Unknown(text="[语音数据解析失败]")) + self.logger.error(f'语音数据字段缺失: {str(e)}') + message_List.append(platform_message.Unknown(text='[语音数据解析失败]')) except Exception as e: - self.logger.error(f"处理语音消息异常: {str(e)}") - message_List.append(platform_message.Unknown(text="[语音处理失败]")) + self.logger.error(f'处理语音消息异常: {str(e)}') + message_List.append(platform_message.Unknown(text='[语音处理失败]')) return platform_message.MessageChain(message_List) - async def _handler_compound( - self, - message: Optional[dict], - content_no_preifx: str - ) -> platform_message.MessageChain: + async def _handler_compound(self, message: Optional[dict], content_no_preifx: str) -> platform_message.MessageChain: """处理复合消息 (msg_type=49),根据子类型分派""" try: xml_data = ET.fromstring(content_no_preifx) appmsg_data = xml_data.find('.//appmsg') if appmsg_data: - data_type = appmsg_data.findtext('.//type', "") + data_type = appmsg_data.findtext('.//type', '') # 二次分派处理器 sub_handler_map = { '57': self._handler_compound_quote, @@ -238,9 +201,9 @@ class WeChatPadMessageConverter(adapter.MessageConverter): '74': self._handler_compound_file, '33': self._handler_compound_mini_program, '36': self._handler_compound_mini_program, - '2000': partial(self._handler_compound_unsupported, text="[转账消息]"), - '2001': partial(self._handler_compound_unsupported, text="[红包消息]"), - '51': partial(self._handler_compound_unsupported, text="[视频号消息]"), + '2000': partial(self._handler_compound_unsupported, text='[转账消息]'), + '2001': partial(self._handler_compound_unsupported, text='[红包消息]'), + '51': partial(self._handler_compound_unsupported, text='[视频号消息]'), } handler = sub_handler_map.get(data_type, self._handler_compound_unsupported) @@ -251,56 +214,51 @@ class WeChatPadMessageConverter(adapter.MessageConverter): else: return platform_message.MessageChain([platform_message.Unknown(text=content_no_preifx)]) except Exception as e: - self.logger.error(f"解析复合消息失败: {str(e)}") + self.logger.error(f'解析复合消息失败: {str(e)}') return platform_message.MessageChain([platform_message.Unknown(text=content_no_preifx)]) async def _handler_compound_quote( - self, - message: Optional[dict], - xml_data: ET.Element + self, message: Optional[dict], xml_data: ET.Element ) -> platform_message.MessageChain: """处理引用消息 (data_type=57)""" message_list = [] -# self.logger.info("_handler_compound_quote", ET.tostring(xml_data, encoding='unicode')) + # self.logger.info("_handler_compound_quote", ET.tostring(xml_data, encoding='unicode')) appmsg_data = xml_data.find('.//appmsg') - quote_data = "" # 引用原文 - quote_id = None # 引用消息的原发送者 - tousername = None # 接收方: 所属微信的wxid - user_data = "" # 用户消息 + quote_data = '' # 引用原文 + # quote_id = None # 引用消息的原发送者 + # tousername = None # 接收方: 所属微信的wxid + user_data = '' # 用户消息 sender_id = xml_data.findtext('.//fromusername') # 发送方:单聊用户/群member # 引用消息转发 if appmsg_data: - user_data = appmsg_data.findtext('.//title') or "" + user_data = appmsg_data.findtext('.//title') or '' quote_data = appmsg_data.find('.//refermsg').findtext('.//content') - quote_id = appmsg_data.find('.//refermsg').findtext('.//chatusr') - message_list.append( - platform_message.WeChatAppMsg( - app_msg=ET.tostring(appmsg_data, encoding='unicode')) - ) - if message: - tousername = message['to_user_name']["str"] - + # quote_id = appmsg_data.find('.//refermsg').findtext('.//chatusr') + message_list.append(platform_message.WeChatAppMsg(app_msg=ET.tostring(appmsg_data, encoding='unicode'))) + # if message: + # tousername = message['to_user_name']['str'] + if quote_data: quote_data_message_list = platform_message.MessageChain() # 文本消息 try: - if "" not in quote_data: + if '' not in quote_data: quote_data_message_list.append(platform_message.Plain(quote_data)) else: # 引用消息展开 quote_data_xml = ET.fromstring(quote_data) - if quote_data_xml.find("img"): + if quote_data_xml.find('img'): quote_data_message_list.extend(await self._handler_image(None, quote_data)) - elif quote_data_xml.find("voicemsg"): + elif quote_data_xml.find('voicemsg'): quote_data_message_list.extend(await self._handler_voice(None, quote_data)) - elif quote_data_xml.find("videomsg"): + elif quote_data_xml.find('videomsg'): quote_data_message_list.extend(await self._handler_default(None, quote_data)) # 先不处理 else: # appmsg quote_data_message_list.extend(await self._handler_compound(None, quote_data)) except Exception as e: - self.logger.error(f"处理引用消息异常 expcetion:{e}") + self.logger.error(f'处理引用消息异常 expcetion:{e}') quote_data_message_list.append(platform_message.Plain(quote_data)) message_list.append( platform_message.Quote( @@ -315,15 +273,11 @@ class WeChatPadMessageConverter(adapter.MessageConverter): return platform_message.MessageChain(message_list) - async def _handler_compound_file( - self, - message: dict, - xml_data: ET.Element - ) -> platform_message.MessageChain: + async def _handler_compound_file(self, message: dict, xml_data: ET.Element) -> platform_message.MessageChain: """处理文件消息 (data_type=6)""" file_data = xml_data.find('.//appmsg') - if file_data.findtext('.//type', "") == "74": + if file_data.findtext('.//type', '') == '74': return None else: @@ -346,22 +300,21 @@ class WeChatPadMessageConverter(adapter.MessageConverter): file_data = self.bot.cdn_download(aeskey=aeskey, file_type=5, file_url=cdnthumburl) - file_base64 = file_data["Data"]['FileData'] + file_base64 = file_data['Data']['FileData'] # print(file_data) - file_size = file_data["Data"]['TotalSize'] + file_size = file_data['Data']['TotalSize'] # print(file_base64) - return platform_message.MessageChain([ - platform_message.WeChatFile(file_id=file_id, file_name=file_name, file_size=file_size, - file_base64=file_base64), - platform_message.WeChatForwardFile(xml_data=xml_data_str) - ]) + return platform_message.MessageChain( + [ + platform_message.WeChatFile( + file_id=file_id, file_name=file_name, file_size=file_size, file_base64=file_base64 + ), + platform_message.WeChatForwardFile(xml_data=xml_data_str), + ] + ) - async def _handler_compound_link( - self, - message: dict, - xml_data: ET.Element - ) -> platform_message.MessageChain: + async def _handler_compound_link(self, message: dict, xml_data: ET.Element) -> platform_message.MessageChain: """处理链接消息(如公众号文章、外部网页)""" message_list = [] try: @@ -374,56 +327,38 @@ class WeChatPadMessageConverter(adapter.MessageConverter): link_title=appmsg.findtext('title', ''), link_desc=appmsg.findtext('des', ''), link_url=appmsg.findtext('url', ''), - link_thumb_url=appmsg.findtext("thumburl", '') # 这个字段拿不到 + link_thumb_url=appmsg.findtext('thumburl', ''), # 这个字段拿不到 ) ) # 还没有发链接的接口, 暂时还需要自己构造appmsg, 先用WeChatAppMsg。 - message_list.append( - platform_message.WeChatAppMsg( - app_msg=ET.tostring(appmsg, encoding='unicode') - ) - ) + message_list.append(platform_message.WeChatAppMsg(app_msg=ET.tostring(appmsg, encoding='unicode'))) except Exception as e: - self.logger.error(f"解析链接消息失败: {str(e)}") + self.logger.error(f'解析链接消息失败: {str(e)}') return platform_message.MessageChain(message_list) async def _handler_compound_mini_program( - self, - message: dict, - xml_data: ET.Element + self, message: dict, xml_data: ET.Element ) -> platform_message.MessageChain: """处理小程序消息(如小程序卡片、服务通知)""" xml_data_str = ET.tostring(xml_data, encoding='unicode') - return platform_message.MessageChain([ - platform_message.WeChatForwardMiniPrograms(xml_data=xml_data_str) - ]) + return platform_message.MessageChain([platform_message.WeChatForwardMiniPrograms(xml_data=xml_data_str)]) - async def _handler_default( - self, - message: Optional[dict], - content_no_preifx: str - ) -> platform_message.MessageChain: + async def _handler_default(self, message: Optional[dict], content_no_preifx: str) -> platform_message.MessageChain: """处理未知消息类型""" if message: - msg_type = message["msg_type"] + msg_type = message['msg_type'] else: - msg_type = "" - return platform_message.MessageChain([ - platform_message.Unknown(text=f"[未知消息类型 msg_type:{msg_type}]") - ]) + msg_type = '' + return platform_message.MessageChain([platform_message.Unknown(text=f'[未知消息类型 msg_type:{msg_type}]')]) def _handler_compound_unsupported( - self, - message: dict, - xml_data: str, - text: Optional[str] = None + self, message: dict, xml_data: str, text: Optional[str] = None ) -> platform_message.MessageChain: """处理未支持复合消息类型(msg_type=49)子类型""" if not text: - text = f"[xml_data={xml_data}]" + text = f'[xml_data={xml_data}]' content_list = [] - content_list.append( - platform_message.Unknown(text=f"[处理未支持复合消息类型[msg_type=49]|{text}")) + content_list.append(platform_message.Unknown(text=f'[处理未支持复合消息类型[msg_type=49]|{text}')) return platform_message.MessageChain(content_list) @@ -432,7 +367,7 @@ class WeChatPadMessageConverter(adapter.MessageConverter): ats_bot = False try: to_user_name = message['to_user_name']['str'] # 接收方: 所属微信的wxid - raw_content = message["content"]["str"] # 原始消息内容 + raw_content = message['content']['str'] # 原始消息内容 content_no_prefix, _ = self._extract_content_and_sender(raw_content) # 直接艾特机器人(这个有bug,当被引用的消息里面有@bot,会套娃 # ats_bot = ats_bot or (f"@{bot_account_id}" in content_no_prefix) @@ -443,7 +378,7 @@ class WeChatPadMessageConverter(adapter.MessageConverter): msg_source = message.get('msg_source', '') or '' if len(msg_source) > 0: msg_source_data = ET.fromstring(msg_source) - at_user_list = msg_source_data.findtext("atuserlist") or "" + at_user_list = msg_source_data.findtext('atuserlist') or '' ats_bot = ats_bot or (to_user_name in at_user_list) # 引用bot if message.get('msg_type', 0) == 49: @@ -454,7 +389,7 @@ class WeChatPadMessageConverter(adapter.MessageConverter): quote_id = appmsg_data.find('.//refermsg').findtext('.//chatusr') # 引用消息的原发送者 ats_bot = ats_bot or (quote_id == tousername) except Exception as e: - self.logger.error(f"_ats_bot got except: {e}") + self.logger.error(f'_ats_bot got except: {e}') finally: return ats_bot @@ -463,47 +398,41 @@ class WeChatPadMessageConverter(adapter.MessageConverter): try: # 检查消息开头,如果有 wxid_sbitaz0mt65n22:\n 则删掉 # add: 有些用户的wxid不是上述格式。换成user_name: - regex = re.compile(r"^[a-zA-Z0-9_\-]{5,20}:") - line_split = raw_content.split("\n") + regex = re.compile(r'^[a-zA-Z0-9_\-]{5,20}:') + line_split = raw_content.split('\n') if len(line_split) > 0 and regex.match(line_split[0]): - raw_content = "\n".join(line_split[1:]) - sender_id = line_split[0].strip(":") + raw_content = '\n'.join(line_split[1:]) + sender_id = line_split[0].strip(':') return raw_content, sender_id except Exception as e: - self.logger.error(f"_extract_content_and_sender got except: {e}") + self.logger.error(f'_extract_content_and_sender got except: {e}') finally: return raw_content, None # 是否是群消息 def _is_group_message(self, message: dict) -> bool: from_user_name = message['from_user_name']['str'] - return from_user_name.endswith("@chatroom") + return from_user_name.endswith('@chatroom') class WeChatPadEventConverter(adapter.EventConverter): - def __init__(self, config: dict): self.config = config self.message_converter = WeChatPadMessageConverter(config) - self.logger = logging.getLogger("WeChatPadEventConverter") - + self.logger = logging.getLogger('WeChatPadEventConverter') + @staticmethod - async def yiri2target( - event: platform_events.MessageEvent - ) -> dict: + async def yiri2target(event: platform_events.MessageEvent) -> dict: pass - async def target2yiri( - self, - event: dict, - bot_account_id: str - ) -> platform_events.MessageEvent: - + async def target2yiri(self, event: dict, bot_account_id: str) -> platform_events.MessageEvent: # 排除公众号以及微信团队消息 - if event['from_user_name']['str'].startswith('gh_') \ - or event['from_user_name']['str']=='weixin'\ - or event['from_user_name']['str'] == "newsapp"\ - or event['from_user_name']['str'] == self.config["wxid"]: + if ( + event['from_user_name']['str'].startswith('gh_') + or event['from_user_name']['str'] == 'weixin' + or event['from_user_name']['str'] == 'newsapp' + or event['from_user_name']['str'] == self.config['wxid'] + ): return None message_chain = await self.message_converter.target2yiri(copy.deepcopy(event), bot_account_id) @@ -512,7 +441,7 @@ class WeChatPadEventConverter(adapter.EventConverter): if '@chatroom' in event['from_user_name']['str']: # 找出开头的 wxid_ 字符串,以:结尾 - sender_wxid = event['content']['str'].split(":")[0] + sender_wxid = event['content']['str'].split(':')[0] return platform_events.GroupMessage( sender=platform_entities.GroupMember( @@ -524,13 +453,13 @@ class WeChatPadEventConverter(adapter.EventConverter): name=event['from_user_name']['str'], permission=platform_entities.Permission.Member, ), - special_title="", + special_title='', join_timestamp=0, last_speak_timestamp=0, mute_time_remaining=0, ), message_chain=message_chain, - time=event["create_time"], + time=event['create_time'], source_platform_object=event, ) else: @@ -541,13 +470,13 @@ class WeChatPadEventConverter(adapter.EventConverter): remark='', ), message_chain=message_chain, - time=event["create_time"], + time=event['create_time'], source_platform_object=event, ) class WeChatPadAdapter(adapter.MessagePlatformAdapter): - name: str = "WeChatPad" # 定义适配器名称 + name: str = 'WeChatPad' # 定义适配器名称 bot: WeChatPadClient quart_app: quart.Quart @@ -580,27 +509,21 @@ class WeChatPadAdapter(adapter.MessagePlatformAdapter): # self.ap.logger.debug(f"Gewechat callback event: {data}") # print(data) - try: event = await self.event_converter.target2yiri(data.copy(), self.bot_account_id) - except Exception as e: - await self.logger.error(f"Error in wechatpad callback: {traceback.format_exc()}") + except Exception: + await self.logger.error(f'Error in wechatpad callback: {traceback.format_exc()}') if event.__class__ in self.listeners: await self.listeners[event.__class__](event, self) return 'ok' - - async def _handle_message( - self, - message: platform_message.MessageChain, - target_id: str - ): + async def _handle_message(self, message: platform_message.MessageChain, target_id: str): """统一消息处理核心逻辑""" content_list = await self.message_converter.yiri2target(message) # print(content_list) - at_targets = [item["target"] for item in content_list if item["type"] == "at"] + at_targets = [item['target'] for item in content_list if item['type'] == 'at'] # print(at_targets) # 处理@逻辑 at_targets = at_targets or [] @@ -608,7 +531,7 @@ class WeChatPadAdapter(adapter.MessagePlatformAdapter): if at_targets: member_info = self.bot.get_chatroom_member_detail( target_id, - )["Data"]["member_data"]["chatroom_member_list"] + )['Data']['member_data']['chatroom_member_list'] # 处理消息组件 for msg in content_list: @@ -616,63 +539,51 @@ class WeChatPadAdapter(adapter.MessagePlatformAdapter): if msg['type'] == 'text' and at_targets: at_nick_name_list = [] for member in member_info: - if member["user_name"] in at_targets: + if member['user_name'] in at_targets: at_nick_name_list.append(f'@{member["nick_name"]}') msg['content'] = f'{" ".join(at_nick_name_list)} {msg["content"]}' # 统一消息派发 handler_map = { 'text': lambda msg: self.bot.send_text_message( - to_wxid=target_id, - message=msg['content'], - ats=at_targets + to_wxid=target_id, message=msg['content'], ats=at_targets ), 'image': lambda msg: self.bot.send_image_message( - to_wxid=target_id, - img_url=msg["image"], - ats = at_targets + to_wxid=target_id, img_url=msg['image'], ats=at_targets ), 'WeChatEmoji': lambda msg: self.bot.send_emoji_message( - to_wxid=target_id, - emoji_md5=msg['emoji_md5'], - emoji_size=msg['emoji_size'] + to_wxid=target_id, emoji_md5=msg['emoji_md5'], emoji_size=msg['emoji_size'] ), - 'voice': lambda msg: self.bot.send_voice_message( to_wxid=target_id, voice_data=msg['data'], - voice_duration=msg["duration"], - voice_forma=msg["forma"], + voice_duration=msg['duration'], + voice_forma=msg['forma'], ), 'WeChatAppMsg': lambda msg: self.bot.send_app_message( to_wxid=target_id, app_message=msg['app_msg'], type=0, ), - 'at': lambda msg: None + 'at': lambda msg: None, } if handler := handler_map.get(msg['type']): handler(msg) # self.ap.logger.warning(f"未处理的消息类型: {ret}") else: - self.ap.logger.warning(f"未处理的消息类型: {msg['type']}") + self.ap.logger.warning(f'未处理的消息类型: {msg["type"]}') continue - async def send_message( - self, - target_type: str, - target_id: str, - message: platform_message.MessageChain - ): + async def send_message(self, target_type: str, target_id: str, message: platform_message.MessageChain): """主动发送消息""" return await self._handle_message(message, target_id) async def reply_message( - self, - message_source: platform_events.MessageEvent, - message: platform_message.MessageChain, - quote_origin: bool = False + self, + message_source: platform_events.MessageEvent, + message: platform_message.MessageChain, + quote_origin: bool = False, ): """回复消息""" if message_source.source_platform_object: @@ -683,58 +594,49 @@ class WeChatPadAdapter(adapter.MessagePlatformAdapter): pass def register_listener( - self, - event_type: typing.Type[platform_events.Event], - callback: typing.Callable[[platform_events.Event, adapter.MessagePlatformAdapter], None] + self, + event_type: typing.Type[platform_events.Event], + callback: typing.Callable[[platform_events.Event, adapter.MessagePlatformAdapter], None], ): self.listeners[event_type] = callback def unregister_listener( - self, - event_type: typing.Type[platform_events.Event], - callback: typing.Callable[[platform_events.Event, adapter.MessagePlatformAdapter], None] + self, + event_type: typing.Type[platform_events.Event], + callback: typing.Callable[[platform_events.Event, adapter.MessagePlatformAdapter], None], ): pass async def run_async(self): - - if not self.config["admin_key"] and not self.config["token"]: - raise RuntimeError("无wechatpad管理密匙,请填入配置文件后重启") + if not self.config['admin_key'] and not self.config['token']: + raise RuntimeError('无wechatpad管理密匙,请填入配置文件后重启') else: - if self.config["token"]: - self.bot = WeChatPadClient( - self.config['wechatpad_url'], - self.config["token"] - ) + if self.config['token']: + self.bot = WeChatPadClient(self.config['wechatpad_url'], self.config['token']) data = self.bot.get_login_status() self.ap.logger.info(data) - if data["Code"] == 300 and data["Text"] == "你已退出微信": + if data['Code'] == 300 and data['Text'] == '你已退出微信': response = requests.post( - f"{self.config['wechatpad_url']}/admin/GenAuthKey1?key={self.config['admin_key']}", - json={"Count": 1, "Days": 365} + f'{self.config["wechatpad_url"]}/admin/GenAuthKey1?key={self.config["admin_key"]}', + json={'Count': 1, 'Days': 365}, ) if response.status_code != 200: - raise Exception(f"获取token失败: {response.text}") - self.config["token"] = response.json()["Data"][0] + raise Exception(f'获取token失败: {response.text}') + self.config['token'] = response.json()['Data'][0] - elif not self.config["token"]: + elif not self.config['token']: response = requests.post( - f"{self.config['wechatpad_url']}/admin/GenAuthKey1?key={self.config['admin_key']}", - json={"Count": 1, "Days": 365} + f'{self.config["wechatpad_url"]}/admin/GenAuthKey1?key={self.config["admin_key"]}', + json={'Count': 1, 'Days': 365}, ) if response.status_code != 200: - raise Exception(f"获取token失败: {response.text}") - self.config["token"] = response.json()["Data"][0] + raise Exception(f'获取token失败: {response.text}') + self.config['token'] = response.json()['Data'][0] - self.bot = WeChatPadClient( - self.config['wechatpad_url'], - self.config["token"], - logger=self.logger - ) - self.ap.logger.info(self.config["token"]) + self.bot = WeChatPadClient(self.config['wechatpad_url'], self.config['token'], logger=self.logger) + self.ap.logger.info(self.config['token']) thread_1 = threading.Event() - def wechat_login_process(): # 不登录,这些先注释掉,避免登陆态尝试拉qrcode。 # login_data =self.bot.get_login_qr() @@ -742,67 +644,54 @@ class WeChatPadAdapter(adapter.MessagePlatformAdapter): # url = login_data['Data']["QrCodeUrl"] # self.ap.logger.info(login_data) - - profile =self.bot.get_profile() + profile = self.bot.get_profile() self.ap.logger.info(profile) - self.bot_account_id = profile["Data"]["userInfo"]["nickName"]["str"] - self.config["wxid"] = profile["Data"]["userInfo"]["userName"]["str"] + self.bot_account_id = profile['Data']['userInfo']['nickName']['str'] + self.config['wxid'] = profile['Data']['userInfo']['userName']['str'] thread_1.set() - # asyncio.create_task(wechat_login_process) threading.Thread(target=wechat_login_process).start() def connect_websocket_sync() -> None: - thread_1.wait() - uri = f"{self.config['wechatpad_ws']}/GetSyncMsg?key={self.config['token']}" - self.ap.logger.info(f"Connecting to WebSocket: {uri}") + uri = f'{self.config["wechatpad_ws"]}/GetSyncMsg?key={self.config["token"]}' + self.ap.logger.info(f'Connecting to WebSocket: {uri}') + def on_message(ws, message): try: data = json.loads(message) - self.ap.logger.debug(f"Received message: {data}") + self.ap.logger.debug(f'Received message: {data}') # 这里需要确保ws_message是同步的,或者使用asyncio.run调用异步方法 asyncio.run(self.ws_message(data)) except json.JSONDecodeError: - self.ap.logger.error(f"Non-JSON message: {message[:100]}...") + self.ap.logger.error(f'Non-JSON message: {message[:100]}...') def on_error(ws, error): - self.ap.logger.error(f"WebSocket error: {str(error)[:200]}") + self.ap.logger.error(f'WebSocket error: {str(error)[:200]}') def on_close(ws, close_status_code, close_msg): - self.ap.logger.info("WebSocket closed, reconnecting...") + self.ap.logger.info('WebSocket closed, reconnecting...') time.sleep(5) connect_websocket_sync() # 自动重连 def on_open(ws): - self.ap.logger.info("WebSocket connected successfully!") + self.ap.logger.info('WebSocket connected successfully!') ws = websocket.WebSocketApp( - uri, - on_message=on_message, - on_error=on_error, - on_close=on_close, - on_open=on_open - ) - ws.run_forever( - ping_interval=60, - ping_timeout=20 + uri, on_message=on_message, on_error=on_error, on_close=on_close, on_open=on_open ) + ws.run_forever(ping_interval=60, ping_timeout=20) # 直接调用同步版本(会阻塞) # connect_websocket_sync() # 这行代码会在WebSocket连接断开后才会执行 # self.ap.logger.info("WebSocket client thread started") - thread = threading.Thread( - target=connect_websocket_sync, - name="WebSocketClientThread", - daemon=True - ) + thread = threading.Thread(target=connect_websocket_sync, name='WebSocketClientThread', daemon=True) thread.start() - self.ap.logger.info("WebSocket client thread started") + self.ap.logger.info('WebSocket client thread started') async def kill(self) -> bool: pass diff --git a/pkg/platform/sources/wecom.py b/pkg/platform/sources/wecom.py index f1cc677e..7be05a85 100644 --- a/pkg/platform/sources/wecom.py +++ b/pkg/platform/sources/wecom.py @@ -157,7 +157,7 @@ class WecomAdapter(adapter.MessagePlatformAdapter): token=config['token'], EncodingAESKey=config['EncodingAESKey'], contacts_secret=config['contacts_secret'], - logger=self.logger + logger=self.logger, ) async def reply_message( @@ -201,8 +201,8 @@ class WecomAdapter(adapter.MessagePlatformAdapter): self.bot_account_id = event.receiver_id try: return await callback(await self.event_converter.target2yiri(event), self) - except Exception as e: - await self.logger.error(f"Error in wecom callback: {traceback.format_exc()}") + except Exception: + await self.logger.error(f'Error in wecom callback: {traceback.format_exc()}') if event_type == platform_events.FriendMessage: self.bot.on_message('text')(on_message) diff --git a/pkg/platform/sources/wecomcs.py b/pkg/platform/sources/wecomcs.py index aab8d394..da84ac6d 100644 --- a/pkg/platform/sources/wecomcs.py +++ b/pkg/platform/sources/wecomcs.py @@ -145,7 +145,7 @@ class WecomCSAdapter(adapter.MessagePlatformAdapter): secret=config['secret'], token=config['token'], EncodingAESKey=config['EncodingAESKey'], - logger=self.logger + logger=self.logger, ) async def reply_message( @@ -178,8 +178,8 @@ class WecomCSAdapter(adapter.MessagePlatformAdapter): self.bot_account_id = event.receiver_id try: return await callback(await self.event_converter.target2yiri(event), self) - except Exception as e: - await self.logger.error(f"Error in wecomcs callback: {traceback.format_exc()}") + except Exception: + await self.logger.error(f'Error in wecomcs callback: {traceback.format_exc()}') if event_type == platform_events.FriendMessage: self.bot.on_message('text')(on_message) diff --git a/pkg/platform/types/message.py b/pkg/platform/types/message.py index 7dad4145..ecd7cc96 100644 --- a/pkg/platform/types/message.py +++ b/pkg/platform/types/message.py @@ -812,12 +812,14 @@ class File(MessageComponent): def __str__(self): return f'[文件]{self.name}' + class Face(MessageComponent): """系统表情 此处将超级表情骰子/划拳,一同归类于face 当face_type为rps(划拳)时 face_id 对应的是手势 当face_type为dice(骰子)时 face_id 对应的是点数 """ + type: str = 'Face' """表情类型""" face_type: str = 'face' @@ -834,15 +836,15 @@ class Face(MessageComponent): elif self.face_type == 'rps': return f'[表情]{self.face_name}({self.rps_data(self.face_id)})' - - def rps_data(self,face_id): - rps_dict ={ - 1 : "布", - 2 : "剪刀", - 3 : "石头", + def rps_data(self, face_id): + rps_dict = { + 1: '布', + 2: '剪刀', + 3: '石头', } return rps_dict[face_id] + # ================ 个人微信专用组件 ================ @@ -971,5 +973,6 @@ class WeChatFile(MessageComponent): """文件地址""" file_base64: str = '' """base64""" + def __str__(self): - return f'[文件]{self.file_name}' \ No newline at end of file + return f'[文件]{self.file_name}' diff --git a/pkg/provider/entities.py b/pkg/provider/entities.py index df2b5487..ff1e4526 100644 --- a/pkg/provider/entities.py +++ b/pkg/provider/entities.py @@ -127,6 +127,7 @@ class Message(pydantic.BaseModel): class MessageChunk(pydantic.BaseModel): """消息""" + resp_message_id: typing.Optional[str] = None """消息id""" @@ -148,7 +149,7 @@ class MessageChunk(pydantic.BaseModel): tool_call_id: typing.Optional[str] = None # tool_calls: typing.Optional[list[ToolCallChunk]] = None - + is_final: bool = False def readable_str(self) -> str: @@ -210,6 +211,7 @@ class ToolCallChunk(pydantic.BaseModel): function: FunctionCall """函数调用""" + class Prompt(pydantic.BaseModel): """供AI使用的Prompt""" diff --git a/pkg/provider/modelmgr/requester.py b/pkg/provider/modelmgr/requester.py index 7830e522..1545a2e4 100644 --- a/pkg/provider/modelmgr/requester.py +++ b/pkg/provider/modelmgr/requester.py @@ -71,19 +71,18 @@ class LLMAPIRequester(metaclass=abc.ABCMeta): extra_args (dict[str, typing.Any], optional): 额外的参数. Defaults to {}. Returns: - llm_entities.Message | typing.AsyncGenerator[llm_entities.MessageChunk]: 返回消息对象 + llm_entities.Message: 返回消息对象 """ pass @abc.abstractmethod async def invoke_llm_stream( - self, - query: core_entities.Query, - model: RuntimeLLMModel, - messages: typing.List[llm_entities.Message], - funcs: typing.List[tools_entities.LLMFunction] = None, - stream: bool = False, - extra_args: dict[str, typing.Any] = {}, + self, + query: core_entities.Query, + model: RuntimeLLMModel, + messages: typing.List[llm_entities.Message], + funcs: typing.List[tools_entities.LLMFunction] = None, + extra_args: dict[str, typing.Any] = {}, ) -> llm_entities.MessageChunk: """调用API @@ -94,6 +93,6 @@ class LLMAPIRequester(metaclass=abc.ABCMeta): extra_args (dict[str, typing.Any], optional): 额外的参数. Defaults to {}. Returns: - llm_entities.Message | typing.AsyncGenerator[llm_entities.MessageChunk]: 返回消息对象 + typing.AsyncGenerator[llm_entities.MessageChunk]: 返回消息对象 """ pass diff --git a/pkg/provider/modelmgr/requesters/chatcmpl.py b/pkg/provider/modelmgr/requesters/chatcmpl.py index 6e72d78e..f05af8c3 100644 --- a/pkg/provider/modelmgr/requesters/chatcmpl.py +++ b/pkg/provider/modelmgr/requesters/chatcmpl.py @@ -8,7 +8,7 @@ import openai.types.chat.chat_completion as chat_completion import httpx from .. import errors, requester -from ....core import entities as core_entities, app +from ....core import entities as core_entities from ... import entities as llm_entities from ...tools import entities as tools_entities @@ -129,12 +129,10 @@ class OpenAIChatCompletions(requester.LLMAPIRequester): req_messages: list[dict], use_model: requester.RuntimeLLMModel, use_funcs: list[tools_entities.LLMFunction] = None, - stream: bool = False, extra_args: dict[str, typing.Any] = {}, - ) ->llm_entities.MessageChunk: + ) -> llm_entities.MessageChunk: self.client.api_key = use_model.token_mgr.get_token() - args = {} args['model'] = use_model.model_entity.name @@ -158,43 +156,42 @@ class OpenAIChatCompletions(requester.LLMAPIRequester): args['messages'] = messages - if stream: - current_content = '' - args['stream'] = True - chunk_idx = 0 - self.is_content = False - tool_calls_map: dict[str, llm_entities.ToolCall] = {} - pipeline_config = query.pipeline_config - async for chunk in self._req_stream(args, extra_body=extra_args): - # 处理流式消息 - delta_message = await self._make_msg_chunk(pipeline_config, chunk, chunk_idx) - if delta_message.content: - current_content += delta_message.content - delta_message.content = current_content - # delta_message.all_content = current_content - if delta_message.tool_calls: - for tool_call in delta_message.tool_calls: - if tool_call.id not in tool_calls_map: - tool_calls_map[tool_call.id] = llm_entities.ToolCall( - id=tool_call.id, - type=tool_call.type, - function=llm_entities.FunctionCall( - name=tool_call.function.name if tool_call.function else '', arguments='' - ), - ) - if tool_call.function and tool_call.function.arguments: - # 流式处理中,工具调用参数可能分多个chunk返回,需要追加而不是覆盖 - tool_calls_map[tool_call.id].function.arguments += tool_call.function.arguments + current_content = '' + args['stream'] = True + chunk_idx = 0 + self.is_content = False + tool_calls_map: dict[str, llm_entities.ToolCall] = {} + pipeline_config = query.pipeline_config + async for chunk in self._req_stream(args, extra_body=extra_args): + # 处理流式消息 + delta_message = await self._make_msg_chunk(pipeline_config, chunk, chunk_idx) + if delta_message.content: + current_content += delta_message.content + delta_message.content = current_content + # delta_message.all_content = current_content + if delta_message.tool_calls: + for tool_call in delta_message.tool_calls: + if tool_call.id not in tool_calls_map: + tool_calls_map[tool_call.id] = llm_entities.ToolCall( + id=tool_call.id, + type=tool_call.type, + function=llm_entities.FunctionCall( + name=tool_call.function.name if tool_call.function else '', arguments='' + ), + ) + if tool_call.function and tool_call.function.arguments: + # 流式处理中,工具调用参数可能分多个chunk返回,需要追加而不是覆盖 + tool_calls_map[tool_call.id].function.arguments += tool_call.function.arguments - chunk_idx += 1 - chunk_choices = getattr(chunk, 'choices', None) - if chunk_choices and getattr(chunk_choices[0], 'finish_reason', None): - delta_message.is_final = True - delta_message.content = current_content + chunk_idx += 1 + chunk_choices = getattr(chunk, 'choices', None) + if chunk_choices and getattr(chunk_choices[0], 'finish_reason', None): + delta_message.is_final = True + delta_message.content = current_content - if chunk_idx % 64 == 0 or delta_message.is_final: - yield delta_message - # return + if chunk_idx % 64 == 0 or delta_message.is_final: + yield delta_message + # return async def _closure( self, @@ -202,7 +199,6 @@ class OpenAIChatCompletions(requester.LLMAPIRequester): req_messages: list[dict], use_model: requester.RuntimeLLMModel, use_funcs: list[tools_entities.LLMFunction] = None, - stream: bool = False, extra_args: dict[str, typing.Any] = {}, ) -> llm_entities.Message: self.client.api_key = use_model.token_mgr.get_token() @@ -289,7 +285,6 @@ class OpenAIChatCompletions(requester.LLMAPIRequester): model: requester.RuntimeLLMModel, messages: typing.List[llm_entities.Message], funcs: typing.List[tools_entities.LLMFunction] = None, - stream: bool = False, extra_args: dict[str, typing.Any] = {}, ) -> llm_entities.MessageChunk: req_messages = [] # req_messages 仅用于类内,外部同步由 query.messages 进行 @@ -309,7 +304,6 @@ class OpenAIChatCompletions(requester.LLMAPIRequester): req_messages=req_messages, use_model=model, use_funcs=funcs, - stream=stream, extra_args=extra_args, ): yield item diff --git a/pkg/provider/modelmgr/requesters/giteeaichatcmpl.py b/pkg/provider/modelmgr/requesters/giteeaichatcmpl.py index 2a618c9f..1c19a534 100644 --- a/pkg/provider/modelmgr/requesters/giteeaichatcmpl.py +++ b/pkg/provider/modelmgr/requesters/giteeaichatcmpl.py @@ -12,7 +12,6 @@ import re import openai.types.chat.chat_completion as chat_completion - class GiteeAIChatCompletions(chatcmpl.OpenAIChatCompletions): """Gitee AI ChatCompletions API 请求器""" @@ -20,7 +19,7 @@ class GiteeAIChatCompletions(chatcmpl.OpenAIChatCompletions): 'base_url': 'https://ai.gitee.com/v1', 'timeout': 120, } - is_think:bool = False + is_think: bool = False async def _closure( self, @@ -52,15 +51,14 @@ class GiteeAIChatCompletions(chatcmpl.OpenAIChatCompletions): pipeline_config = query.pipeline_config - message = await self._make_msg(resp,pipeline_config) + message = await self._make_msg(resp, pipeline_config) return message - async def _make_msg( - self, - chat_completion: chat_completion.ChatCompletion, - pipeline_config: dict[str, typing.Any] = {'trigger': {'misc': {'remove_think': False}}}, + self, + chat_completion: chat_completion.ChatCompletion, + pipeline_config: dict[str, typing.Any] = {'trigger': {'misc': {'remove_think': False}}}, ) -> llm_entities.Message: chatcmpl_message = chat_completion.choices[0].message.model_dump() # print(chatcmpl_message.keys(), chatcmpl_message.values()) @@ -73,23 +71,25 @@ class GiteeAIChatCompletions(chatcmpl.OpenAIChatCompletions): # deepseek的reasoner模型 if pipeline_config['trigger'].get('misc', '').get('remove_think'): - chatcmpl_message['content'] = re.sub(r'.*?', '', chatcmpl_message['content'], flags=re.DOTALL) + chatcmpl_message['content'] = re.sub( + r'.*?', '', chatcmpl_message['content'], flags=re.DOTALL + ) else: if reasoning_content is not None: - chatcmpl_message['content'] = '\n' + reasoning_content + '\n\n' + chatcmpl_message['content'] + chatcmpl_message['content'] = ( + '\n' + reasoning_content + '\n\n' + chatcmpl_message['content'] + ) message = llm_entities.Message(**chatcmpl_message) return message - async def _make_msg_chunk( self, pipeline_config: dict[str, typing.Any], chat_completion: chat_completion.ChatCompletion, idx: int, ) -> llm_entities.MessageChunk: - # 处理流式chunk和完整响应的差异 # print(chat_completion.choices[0]) if hasattr(chat_completion, 'choices'): @@ -104,7 +104,6 @@ class GiteeAIChatCompletions(chatcmpl.OpenAIChatCompletions): if 'role' not in delta or delta['role'] is None: delta['role'] = 'assistant' - reasoning_content = delta['reasoning_content'] if 'reasoning_content' in delta else None delta['content'] = '' if delta['content'] is None else delta['content'] @@ -115,7 +114,7 @@ class GiteeAIChatCompletions(chatcmpl.OpenAIChatCompletions): if delta['content'] == '': self.is_think = True delta['content'] = '' - if delta['content'] == rf'': + if delta['content'] == r'': self.is_think = False delta['content'] = '' if not self.is_think: @@ -126,7 +125,6 @@ class GiteeAIChatCompletions(chatcmpl.OpenAIChatCompletions): if reasoning_content is not None: delta['content'] += reasoning_content - message = llm_entities.MessageChunk(**delta) return message @@ -137,7 +135,6 @@ class GiteeAIChatCompletions(chatcmpl.OpenAIChatCompletions): req_messages: list[dict], use_model: requester.RuntimeLLMModel, use_funcs: list[tools_entities.LLMFunction] = None, - stream: bool = False, extra_args: dict[str, typing.Any] = {}, ) -> llm_entities.Message | typing.AsyncGenerator[llm_entities.MessageChunk, None]: self.client.api_key = use_model.token_mgr.get_token() @@ -165,44 +162,38 @@ class GiteeAIChatCompletions(chatcmpl.OpenAIChatCompletions): args['messages'] = messages - if stream: - current_content = '' - args["stream"] = True - chunk_idx = 0 - self.is_content = False - tool_calls_map: dict[str, llm_entities.ToolCall] = {} - pipeline_config = query.pipeline_config - async for chunk in self._req_stream(args, extra_body=extra_args): - # 处理流式消息 - delta_message = await self._make_msg_chunk(pipeline_config,chunk,chunk_idx) - if delta_message.content: - current_content += delta_message.content - delta_message.content = current_content - # delta_message.all_content = current_content - if delta_message.tool_calls: - for tool_call in delta_message.tool_calls: - if tool_call.id not in tool_calls_map: - tool_calls_map[tool_call.id] = llm_entities.ToolCall( - id=tool_call.id, - type=tool_call.type, - function=llm_entities.FunctionCall( - name=tool_call.function.name if tool_call.function else '', - arguments='' - ), - ) - if tool_call.function and tool_call.function.arguments: - # 流式处理中,工具调用参数可能分多个chunk返回,需要追加而不是覆盖 - tool_calls_map[tool_call.id].function.arguments += tool_call.function.arguments - - - chunk_idx += 1 - chunk_choices = getattr(chunk, 'choices', None) - if chunk_choices and getattr(chunk_choices[0], 'finish_reason', None): - delta_message.is_final = True - delta_message.content = current_content - - if chunk_idx % 64 == 0 or delta_message.is_final: - - yield delta_message + current_content = '' + args['stream'] = True + chunk_idx = 0 + self.is_content = False + tool_calls_map: dict[str, llm_entities.ToolCall] = {} + pipeline_config = query.pipeline_config + async for chunk in self._req_stream(args, extra_body=extra_args): + # 处理流式消息 + delta_message = await self._make_msg_chunk(pipeline_config, chunk, chunk_idx) + if delta_message.content: + current_content += delta_message.content + delta_message.content = current_content + # delta_message.all_content = current_content + if delta_message.tool_calls: + for tool_call in delta_message.tool_calls: + if tool_call.id not in tool_calls_map: + tool_calls_map[tool_call.id] = llm_entities.ToolCall( + id=tool_call.id, + type=tool_call.type, + function=llm_entities.FunctionCall( + name=tool_call.function.name if tool_call.function else '', arguments='' + ), + ) + if tool_call.function and tool_call.function.arguments: + # 流式处理中,工具调用参数可能分多个chunk返回,需要追加而不是覆盖 + tool_calls_map[tool_call.id].function.arguments += tool_call.function.arguments + chunk_idx += 1 + chunk_choices = getattr(chunk, 'choices', None) + if chunk_choices and getattr(chunk_choices[0], 'finish_reason', None): + delta_message.is_final = True + delta_message.content = current_content + if chunk_idx % 64 == 0 or delta_message.is_final: + yield delta_message diff --git a/pkg/provider/modelmgr/requesters/modelscopechatcmpl.py b/pkg/provider/modelmgr/requesters/modelscopechatcmpl.py index 1a303d22..b98ae7ff 100644 --- a/pkg/provider/modelmgr/requesters/modelscopechatcmpl.py +++ b/pkg/provider/modelmgr/requesters/modelscopechatcmpl.py @@ -165,11 +165,10 @@ class ModelScopeChatCompletions(requester.LLMAPIRequester): return message async def _req_stream( - self, - args: dict, - extra_body: dict = {}, + self, + args: dict, + extra_body: dict = {}, ) -> chat_completion.ChatCompletion: - async for chunk in await self.client.chat.completions.create(**args, extra_body=extra_body): yield chunk @@ -179,7 +178,6 @@ class ModelScopeChatCompletions(requester.LLMAPIRequester): chat_completion: chat_completion.ChatCompletion, idx: int, ) -> llm_entities.MessageChunk: - # 处理流式chunk和完整响应的差异 # print(chat_completion.choices[0]) if hasattr(chat_completion, 'choices'): @@ -195,7 +193,6 @@ class ModelScopeChatCompletions(requester.LLMAPIRequester): if 'role' not in delta or delta['role'] is None: delta['role'] = 'assistant' - reasoning_content = delta['reasoning_content'] if 'reasoning_content' in delta else None delta['content'] = '' if delta['content'] is None else delta['content'] @@ -203,13 +200,13 @@ class ModelScopeChatCompletions(requester.LLMAPIRequester): # deepseek的reasoner模型 if pipeline_config['trigger'].get('misc', '').get('remove_think'): - if reasoning_content is not None : + if reasoning_content is not None: pass else: delta['content'] = delta['content'] else: if reasoning_content is not None and idx == 0: - delta['content'] += f'\n{reasoning_content}' + delta['content'] += f'\n{reasoning_content}' elif reasoning_content is None: if self.is_content: delta['content'] = delta['content'] @@ -219,7 +216,6 @@ class ModelScopeChatCompletions(requester.LLMAPIRequester): else: delta['content'] += reasoning_content - message = llm_entities.MessageChunk(**delta) return message @@ -230,7 +226,6 @@ class ModelScopeChatCompletions(requester.LLMAPIRequester): req_messages: list[dict], use_model: requester.RuntimeLLMModel, use_funcs: list[tools_entities.LLMFunction] = None, - stream: bool = False, extra_args: dict[str, typing.Any] = {}, ) -> llm_entities.Message | typing.AsyncGenerator[llm_entities.MessageChunk, None]: self.client.api_key = use_model.token_mgr.get_token() @@ -258,48 +253,42 @@ class ModelScopeChatCompletions(requester.LLMAPIRequester): args['messages'] = messages - if stream: - current_content = '' - args["stream"] = True - chunk_idx = 0 - self.is_content = False - tool_calls_map: dict[str, llm_entities.ToolCall] = {} - pipeline_config = query.pipeline_config - async for chunk in self._req_stream(args, extra_body=extra_args): - # 处理流式消息 - delta_message = await self._make_msg_chunk(pipeline_config,chunk,chunk_idx) - if delta_message.content: - current_content += delta_message.content - delta_message.content = current_content - # delta_message.all_content = current_content - if delta_message.tool_calls: - for tool_call in delta_message.tool_calls: - if tool_call.id not in tool_calls_map: - tool_calls_map[tool_call.id] = llm_entities.ToolCall( - id=tool_call.id, - type=tool_call.type, - function=llm_entities.FunctionCall( - name=tool_call.function.name if tool_call.function else '', - arguments='' - ), - ) - if tool_call.function and tool_call.function.arguments: - # 流式处理中,工具调用参数可能分多个chunk返回,需要追加而不是覆盖 - tool_calls_map[tool_call.id].function.arguments += tool_call.function.arguments - - - chunk_idx += 1 - chunk_choices = getattr(chunk, 'choices', None) - if chunk_choices and getattr(chunk_choices[0], 'finish_reason', None): - delta_message.is_final = True - delta_message.content = current_content - - if chunk_idx % 64 == 0 or delta_message.is_final: - - yield delta_message - # return + current_content = '' + args['stream'] = True + chunk_idx = 0 + self.is_content = False + tool_calls_map: dict[str, llm_entities.ToolCall] = {} + pipeline_config = query.pipeline_config + async for chunk in self._req_stream(args, extra_body=extra_args): + # 处理流式消息 + delta_message = await self._make_msg_chunk(pipeline_config, chunk, chunk_idx) + if delta_message.content: + current_content += delta_message.content + delta_message.content = current_content + # delta_message.all_content = current_content + if delta_message.tool_calls: + for tool_call in delta_message.tool_calls: + if tool_call.id not in tool_calls_map: + tool_calls_map[tool_call.id] = llm_entities.ToolCall( + id=tool_call.id, + type=tool_call.type, + function=llm_entities.FunctionCall( + name=tool_call.function.name if tool_call.function else '', arguments='' + ), + ) + if tool_call.function and tool_call.function.arguments: + # 流式处理中,工具调用参数可能分多个chunk返回,需要追加而不是覆盖 + tool_calls_map[tool_call.id].function.arguments += tool_call.function.arguments + chunk_idx += 1 + chunk_choices = getattr(chunk, 'choices', None) + if chunk_choices and getattr(chunk_choices[0], 'finish_reason', None): + delta_message.is_final = True + delta_message.content = current_content + if chunk_idx % 64 == 0 or delta_message.is_final: + yield delta_message + # return async def invoke_llm( self, @@ -340,16 +329,14 @@ class ModelScopeChatCompletions(requester.LLMAPIRequester): except openai.APIError as e: raise errors.RequesterError(f'请求错误: {e.message}') - async def invoke_llm_stream( self, query: core_entities.Query, model: requester.RuntimeLLMModel, messages: typing.List[llm_entities.Message], funcs: typing.List[tools_entities.LLMFunction] = None, - stream: bool = False, extra_args: dict[str, typing.Any] = {}, - ) -> llm_entities.MessageChunk: + ) -> llm_entities.MessageChunk: req_messages = [] # req_messages 仅用于类内,外部同步由 query.messages 进行 for m in messages: msg_dict = m.dict(exclude_none=True) @@ -367,7 +354,6 @@ class ModelScopeChatCompletions(requester.LLMAPIRequester): req_messages=req_messages, use_model=model, use_funcs=funcs, - stream=stream, extra_args=extra_args, ): yield item @@ -386,4 +372,4 @@ class ModelScopeChatCompletions(requester.LLMAPIRequester): except openai.RateLimitError as e: raise errors.RequesterError(f'请求过于频繁或余额不足: {e.message}') except openai.APIError as e: - raise errors.RequesterError(f'请求错误: {e.message}') \ No newline at end of file + raise errors.RequesterError(f'请求错误: {e.message}') diff --git a/pkg/provider/modelmgr/requesters/ppiochatcmpl.py b/pkg/provider/modelmgr/requesters/ppiochatcmpl.py index 85b321a7..46da6e01 100644 --- a/pkg/provider/modelmgr/requesters/ppiochatcmpl.py +++ b/pkg/provider/modelmgr/requesters/ppiochatcmpl.py @@ -5,8 +5,8 @@ import typing from . import chatcmpl import openai.types.chat.chat_completion as chat_completion -from .. import errors, requester -from ....core import entities as core_entities, app +from .. import requester +from ....core import entities as core_entities from ... import entities as llm_entities from ...tools import entities as tools_entities import re @@ -25,9 +25,9 @@ class PPIOChatCompletions(chatcmpl.OpenAIChatCompletions): is_think: bool = False async def _make_msg( - self, - chat_completion: chat_completion.ChatCompletion, - pipeline_config: dict[str, typing.Any] = {'trigger': {'misc': {'remove_think': False}}}, + self, + chat_completion: chat_completion.ChatCompletion, + pipeline_config: dict[str, typing.Any] = {'trigger': {'misc': {'remove_think': False}}}, ) -> llm_entities.Message: chatcmpl_message = chat_completion.choices[0].message.model_dump() # print(chatcmpl_message.keys(), chatcmpl_message.values()) @@ -40,21 +40,24 @@ class PPIOChatCompletions(chatcmpl.OpenAIChatCompletions): # deepseek的reasoner模型 if pipeline_config['trigger'].get('misc', '').get('remove_think'): - chatcmpl_message['content'] = re.sub(r'.*?', '', chatcmpl_message['content'], flags=re.DOTALL) + chatcmpl_message['content'] = re.sub( + r'.*?', '', chatcmpl_message['content'], flags=re.DOTALL + ) else: if reasoning_content is not None: - chatcmpl_message['content'] = '\n' + reasoning_content + '\n\n' + chatcmpl_message['content'] + chatcmpl_message['content'] = ( + '\n' + reasoning_content + '\n\n' + chatcmpl_message['content'] + ) message = llm_entities.Message(**chatcmpl_message) return message - async def _make_msg_chunk( - self, - pipeline_config: dict[str, typing.Any], - chat_completion: chat_completion.ChatCompletion, - idx: int, + self, + pipeline_config: dict[str, typing.Any], + chat_completion: chat_completion.ChatCompletion, + idx: int, ) -> llm_entities.MessageChunk: # 处理流式chunk和完整响应的差异 # print(chat_completion.choices[0]) @@ -80,7 +83,7 @@ class PPIOChatCompletions(chatcmpl.OpenAIChatCompletions): if '' in delta['content']: self.is_think = True delta['content'] = '' - if rf'' in delta['content']: + if r'' in delta['content']: self.is_think = False delta['content'] = '' if not self.is_think: @@ -95,15 +98,13 @@ class PPIOChatCompletions(chatcmpl.OpenAIChatCompletions): return message - async def _closure_stream( - self, - query: core_entities.Query, - req_messages: list[dict], - use_model: requester.RuntimeLLMModel, - use_funcs: list[tools_entities.LLMFunction] = None, - stream: bool = False, - extra_args: dict[str, typing.Any] = {}, + self, + query: core_entities.Query, + req_messages: list[dict], + use_model: requester.RuntimeLLMModel, + use_funcs: list[tools_entities.LLMFunction] = None, + extra_args: dict[str, typing.Any] = {}, ) -> llm_entities.Message | typing.AsyncGenerator[llm_entities.MessageChunk, None]: self.client.api_key = use_model.token_mgr.get_token() @@ -130,40 +131,38 @@ class PPIOChatCompletions(chatcmpl.OpenAIChatCompletions): args['messages'] = messages - if stream: - current_content = '' - args["stream"] = True - chunk_idx = 0 - self.is_content = False - tool_calls_map: dict[str, llm_entities.ToolCall] = {} - pipeline_config = query.pipeline_config - async for chunk in self._req_stream(args, extra_body=extra_args): - # 处理流式消息 - delta_message = await self._make_msg_chunk(pipeline_config, chunk, chunk_idx) - if delta_message.content: - current_content += delta_message.content - delta_message.content = current_content - # delta_message.all_content = current_content - if delta_message.tool_calls: - for tool_call in delta_message.tool_calls: - if tool_call.id not in tool_calls_map: - tool_calls_map[tool_call.id] = llm_entities.ToolCall( - id=tool_call.id, - type=tool_call.type, - function=llm_entities.FunctionCall( - name=tool_call.function.name if tool_call.function else '', - arguments='' - ), - ) - if tool_call.function and tool_call.function.arguments: - # 流式处理中,工具调用参数可能分多个chunk返回,需要追加而不是覆盖 - tool_calls_map[tool_call.id].function.arguments += tool_call.function.arguments + current_content = '' + args['stream'] = True + chunk_idx = 0 + self.is_content = False + tool_calls_map: dict[str, llm_entities.ToolCall] = {} + pipeline_config = query.pipeline_config + async for chunk in self._req_stream(args, extra_body=extra_args): + # 处理流式消息 + delta_message = await self._make_msg_chunk(pipeline_config, chunk, chunk_idx) + if delta_message.content: + current_content += delta_message.content + delta_message.content = current_content + # delta_message.all_content = current_content + if delta_message.tool_calls: + for tool_call in delta_message.tool_calls: + if tool_call.id not in tool_calls_map: + tool_calls_map[tool_call.id] = llm_entities.ToolCall( + id=tool_call.id, + type=tool_call.type, + function=llm_entities.FunctionCall( + name=tool_call.function.name if tool_call.function else '', arguments='' + ), + ) + if tool_call.function and tool_call.function.arguments: + # 流式处理中,工具调用参数可能分多个chunk返回,需要追加而不是覆盖 + tool_calls_map[tool_call.id].function.arguments += tool_call.function.arguments - chunk_idx += 1 - chunk_choices = getattr(chunk, 'choices', None) - if chunk_choices and getattr(chunk_choices[0], 'finish_reason', None): - delta_message.is_final = True - delta_message.content = current_content + chunk_idx += 1 + chunk_choices = getattr(chunk, 'choices', None) + if chunk_choices and getattr(chunk_choices[0], 'finish_reason', None): + delta_message.is_final = True + delta_message.content = current_content - if chunk_idx % 64 == 0 or delta_message.is_final: - yield delta_message + if chunk_idx % 64 == 0 or delta_message.is_final: + yield delta_message diff --git a/pkg/provider/runners/difysvapi.py b/pkg/provider/runners/difysvapi.py index 8182cc54..40a3140c 100644 --- a/pkg/provider/runners/difysvapi.py +++ b/pkg/provider/runners/difysvapi.py @@ -348,7 +348,9 @@ class DifyServiceAPIRunner(runner.RequestRunner): except AttributeError: is_stream = False - batch_pending_index = 0 + _ = is_stream + + # batch_pending_index = 0 plain_text, image_ids = await self._preprocess_user_message(query) diff --git a/pkg/provider/runners/localagent.py b/pkg/provider/runners/localagent.py index b70d4157..30c48cf6 100644 --- a/pkg/provider/runners/localagent.py +++ b/pkg/provider/runners/localagent.py @@ -63,8 +63,7 @@ class LocalAgentRunner(runner.RequestRunner): id=tool_call.id, type=tool_call.type, function=llm_entities.FunctionCall( - name=tool_call.function.name if tool_call.function else '', - arguments='' + name=tool_call.function.name if tool_call.function else '', arguments='' ), ) if tool_call.function and tool_call.function.arguments: diff --git a/pkg/utils/image.py b/pkg/utils/image.py index f69d29d2..d9518e12 100644 --- a/pkg/utils/image.py +++ b/pkg/utils/image.py @@ -204,9 +204,9 @@ async def get_slack_image_to_base64(pic_url: str, bot_token: str): try: async with aiohttp.ClientSession() as session: async with session.get(pic_url, headers=headers) as resp: - mime_type = resp.headers.get("Content-Type", "application/octet-stream") + mime_type = resp.headers.get('Content-Type', 'application/octet-stream') file_bytes = await resp.read() - base64_str = base64.b64encode(file_bytes).decode("utf-8") - return f"data:{mime_type};base64,{base64_str}" + base64_str = base64.b64encode(file_bytes).decode('utf-8') + return f'data:{mime_type};base64,{base64_str}' except Exception as e: - raise (e) \ No newline at end of file + raise (e) diff --git a/pkg/utils/importutil.py b/pkg/utils/importutil.py index 8acc5c45..1933d611 100644 --- a/pkg/utils/importutil.py +++ b/pkg/utils/importutil.py @@ -32,7 +32,7 @@ def import_dir(path: str): rel_path = full_path.replace(os.path.dirname(os.path.dirname(os.path.dirname(__file__))), '') rel_path = rel_path[1:] rel_path = rel_path.replace('/', '.')[:-3] - rel_path = rel_path.replace("\\",".") + rel_path = rel_path.replace('\\', '.') importlib.import_module(rel_path)