From 055b389353334593b5cb5e73686cebfc2de43bcb Mon Sep 17 00:00:00 2001 From: Junyan Qin Date: Sat, 10 May 2025 18:04:58 +0800 Subject: [PATCH] style: restrict line-length --- .ruff.toml | 3 + libs/dingtalk_api/api.py | 24 +- libs/official_account_api/api.py | 43 +--- libs/qq_official_api/api.py | 18 +- libs/slack_api/api.py | 160 ++++++------- libs/slack_api/slackevent.py | 96 ++++---- libs/wecom_api/WXBizMsgCrypt3.py | 7 +- libs/wecom_api/api.py | 63 ++--- libs/wecom_customer_service_api/api.py | 219 +++++++++--------- .../wecomcsevent.py | 28 ++- pkg/api/http/controller/group.py | 4 +- pkg/api/http/controller/groups/logs.py | 6 +- pkg/api/http/controller/groups/pipelines.py | 12 +- .../controller/groups/platform/adapters.py | 18 +- pkg/api/http/controller/groups/plugins.py | 4 +- .../http/controller/groups/provider/models.py | 4 +- .../controller/groups/provider/requesters.py | 18 +- pkg/api/http/controller/groups/stats.py | 4 +- pkg/api/http/controller/groups/system.py | 12 +- pkg/api/http/controller/groups/user.py | 12 +- pkg/api/http/controller/main.py | 11 +- pkg/api/http/service/bot.py | 28 +-- pkg/api/http/service/model.py | 25 +- pkg/api/http/service/pipeline.py | 20 +- pkg/api/http/service/user.py | 12 +- pkg/command/cmdmgr.py | 26 +-- pkg/command/operator.py | 4 +- pkg/command/operators/cmd.py | 12 +- pkg/command/operators/delc.py | 33 +-- pkg/command/operators/func.py | 4 +- pkg/command/operators/help.py | 4 +- pkg/command/operators/last.py | 27 +-- pkg/command/operators/list.py | 16 +- pkg/command/operators/model.py | 32 +-- pkg/command/operators/next.py | 27 +-- pkg/command/operators/ollama.py | 40 +--- pkg/command/operators/plugin.py | 136 +++-------- pkg/command/operators/prompt.py | 8 +- pkg/command/operators/resend.py | 8 +- pkg/command/operators/reset.py | 4 +- pkg/command/operators/update.py | 8 +- pkg/command/operators/version.py | 4 +- pkg/config/manager.py | 4 +- pkg/core/app.py | 4 +- pkg/core/bootutils/log.py | 8 +- pkg/core/entities.py | 23 +- .../m001_sensitive_word_migration.py | 10 +- .../m002_openai_config_migration.py | 4 +- pkg/core/migrations/m007_qcg_center_url.py | 4 +- .../m008_ad_fixwin_config_migrate.py | 8 +- pkg/core/migrations/m013_http_api_config.py | 5 +- .../m017_dify_api_timeout_params.py | 3 +- .../migrations/m023_siliconflow_config.py | 4 +- .../migrations/m033_dify_thinking_config.py | 9 +- .../m034_gewechat_file_url_config.py | 4 +- .../migrations/m038_tg_dingtalk_markdown.py | 11 +- .../m039_modelscope_cfg_completion.py | 15 +- pkg/core/migrations/m040_ppio_config.py | 15 +- pkg/core/taskmgr.py | 36 +-- pkg/discover/engine.py | 45 +--- pkg/entity/persistence/bot.py | 4 +- pkg/entity/persistence/model.py | 4 +- pkg/entity/persistence/pipeline.py | 8 +- pkg/entity/persistence/plugin.py | 4 +- pkg/entity/persistence/user.py | 4 +- pkg/persistence/databases/sqlite.py | 4 +- pkg/persistence/mgr.py | 36 +-- pkg/pipeline/bansess/bansess.py | 12 +- pkg/pipeline/cntfilter/cntfilter.py | 32 +-- pkg/pipeline/cntfilter/filter.py | 4 +- .../cntfilter/filters/baiduexamine.py | 12 +- pkg/pipeline/cntfilter/filters/banwords.py | 8 +- pkg/pipeline/cntfilter/filters/cntignore.py | 4 +- pkg/pipeline/controller.py | 22 +- pkg/pipeline/longtext/longtext.py | 20 +- pkg/pipeline/longtext/strategies/forward.py | 4 +- pkg/pipeline/longtext/strategies/image.py | 17 +- pkg/pipeline/longtext/strategy.py | 4 +- pkg/pipeline/msgtrun/msgtrun.py | 8 +- pkg/pipeline/pipelinemgr.py | 62 ++--- pkg/pipeline/preproc/preproc.py | 26 +-- pkg/pipeline/process/handlers/chat.py | 28 +-- pkg/pipeline/process/handlers/command.py | 41 +--- pkg/pipeline/ratelimit/algos/fixedwin.py | 4 +- pkg/pipeline/respback/respback.py | 12 +- pkg/pipeline/resprule/resprule.py | 16 +- pkg/pipeline/resprule/rules/atbot.py | 5 +- pkg/pipeline/resprule/rules/random.py | 4 +- pkg/pipeline/wrapper/wrapper.py | 56 ++--- pkg/platform/adapter.py | 12 +- pkg/platform/botmgr.py | 24 +- pkg/platform/sources/aiocqhttp.py | 54 +---- pkg/platform/sources/dingtalk.py | 20 +- pkg/platform/sources/discord.py | 32 +-- pkg/platform/sources/gewechat.py | 216 +++++------------ pkg/platform/sources/lark.py | 91 ++------ pkg/platform/sources/nakuru.py | 63 ++--- pkg/platform/sources/officialaccount.py | 32 +-- pkg/platform/sources/qqbotpy.py | 110 ++------- pkg/platform/sources/qqofficial.py | 50 +--- pkg/platform/sources/slack.py | 133 ++++------- pkg/platform/sources/telegram.py | 42 +--- pkg/platform/sources/wecom.py | 78 ++----- pkg/platform/sources/wecomcs.py | 134 +++++------ pkg/platform/types/base.py | 5 +- pkg/platform/types/events.py | 8 +- pkg/platform/types/message.py | 66 ++---- pkg/plugin/context.py | 15 +- pkg/plugin/installers/github.py | 8 +- pkg/plugin/loaders/classic.py | 44 +--- pkg/plugin/loaders/manifest.py | 24 +- pkg/plugin/manager.py | 60 ++--- pkg/provider/entities.py | 12 +- pkg/provider/modelmgr/modelmgr.py | 24 +- .../modelmgr/requesters/anthropicmsgs.py | 16 +- pkg/provider/modelmgr/requesters/chatcmpl.py | 22 +- .../modelmgr/requesters/modelscopechatcmpl.py | 111 +++++---- .../modelmgr/requesters/ollamachat.py | 17 +- .../modelmgr/requesters/ppiochatcmpl.py | 7 +- pkg/provider/runner.py | 4 +- pkg/provider/runners/dashscopeapi.py | 24 +- pkg/provider/runners/difysvapi.py | 69 ++---- pkg/provider/runners/localagent.py | 16 +- pkg/provider/session/sessionmgr.py | 5 +- pkg/provider/tools/loader.py | 4 +- pkg/provider/tools/loaders/mcp.py | 28 +-- pkg/provider/tools/loaders/plugin.py | 4 +- pkg/provider/tools/toolmgr.py | 16 +- pkg/utils/announce.py | 10 +- pkg/utils/image.py | 12 +- pkg/utils/importutil.py | 4 +- pkg/utils/ip.py | 4 +- pkg/utils/proxy.py | 18 +- pkg/utils/version.py | 22 +- 134 files changed, 1096 insertions(+), 2595 deletions(-) diff --git a/.ruff.toml b/.ruff.toml index 992262ff..5fb48089 100644 --- a/.ruff.toml +++ b/.ruff.toml @@ -1,3 +1,6 @@ + +line-length = 120 + [lint] ignore = [ diff --git a/libs/dingtalk_api/api.py b/libs/dingtalk_api/api.py index 883b7455..b66d72a5 100644 --- a/libs/dingtalk_api/api.py +++ b/libs/dingtalk_api/api.py @@ -25,9 +25,7 @@ class DingTalkClient: self.secret = client_secret # 在 DingTalkClient 中传入自己作为参数,避免循环导入 self.EchoTextHandler = EchoTextHandler(self) - self.client.register_callback_handler( - dingtalk_stream.chatbot.ChatbotMessage.TOPIC, self.EchoTextHandler - ) + self.client.register_callback_handler(dingtalk_stream.chatbot.ChatbotMessage.TOPIC, self.EchoTextHandler) self._message_handlers = { 'example': [], } @@ -86,9 +84,7 @@ class DingTalkClient: if response.status_code == 200: file_bytes = response.content - base64_str = base64.b64encode(file_bytes).decode( - 'utf-8' - ) # 返回字符串格式 + base64_str = base64.b64encode(file_bytes).decode('utf-8') # 返回字符串格式 return base64_str else: raise Exception('获取文件失败') @@ -151,9 +147,7 @@ class DingTalkClient: for handler in self._message_handlers[msg_type]: await handler(event) - async def get_message( - self, incoming_message: dingtalk_stream.chatbot.ChatbotMessage - ): + async def get_message(self, incoming_message: dingtalk_stream.chatbot.ChatbotMessage): try: # print(json.dumps(incoming_message.to_dict(), indent=4, ensure_ascii=False)) message_data = { @@ -170,9 +164,7 @@ class DingTalkClient: if 'text' in item: message_data['Content'] = item['text'] if incoming_message.get_image_list()[0]: - message_data['Picture'] = await self.download_image( - incoming_message.get_image_list()[0] - ) + message_data['Picture'] = await self.download_image(incoming_message.get_image_list()[0]) message_data['Type'] = 'text' elif incoming_message.message_type == 'text': @@ -180,15 +172,11 @@ class DingTalkClient: message_data['Type'] = 'text' elif incoming_message.message_type == 'picture': - message_data['Picture'] = await self.download_image( - incoming_message.get_image_list()[0] - ) + message_data['Picture'] = await self.download_image(incoming_message.get_image_list()[0]) message_data['Type'] = 'image' elif incoming_message.message_type == 'audio': - message_data['Audio'] = await self.get_audio_url( - incoming_message.to_dict()['content']['downloadCode'] - ) + message_data['Audio'] = await self.get_audio_url(incoming_message.to_dict()['content']['downloadCode']) message_data['Type'] = 'audio' diff --git a/libs/official_account_api/api.py b/libs/official_account_api/api.py index fc392c30..094aeb36 100644 --- a/libs/official_account_api/api.py +++ b/libs/official_account_api/api.py @@ -68,9 +68,7 @@ class OAClient: elif request.method == 'POST': encryt_msg = await request.data wxcpt = WXBizMsgCrypt(self.token, self.aes, self.appid) - ret, xml_msg = wxcpt.DecryptMsg( - encryt_msg, msg_signature, timestamp, nonce - ) + ret, xml_msg = wxcpt.DecryptMsg(encryt_msg, msg_signature, timestamp, nonce) xml_msg = xml_msg.decode('utf-8') if ret != 0: @@ -112,9 +110,7 @@ class OAClient: # create_time=int(time.time()), # content = "请求失效:暂不支持公众号超过15秒的请求,如有需求,请联系 LangBot 团队。" # ) - print( - '请求失效:暂不支持公众号超过15秒的请求,如有需求,请联系 LangBot 团队。' - ) + print('请求失效:暂不支持公众号超过15秒的请求,如有需求,请联系 LangBot 团队。') return '' except Exception: @@ -128,12 +124,8 @@ class OAClient: 'FromUserName': root.find('FromUserName').text, 'CreateTime': int(root.find('CreateTime').text), 'MsgType': root.find('MsgType').text, - 'Content': root.find('Content').text - if root.find('Content') is not None - else None, - 'MsgId': int(root.find('MsgId').text) - if root.find('MsgId') is not None - else None, + 'Content': root.find('Content').text if root.find('Content') is not None else None, + 'MsgId': int(root.find('MsgId').text) if root.find('MsgId') is not None else None, } return message_data @@ -225,9 +217,7 @@ class OAClientForLongerResponse: elif request.method == 'POST': encryt_msg = await request.data wxcpt = WXBizMsgCrypt(self.token, self.aes, self.appid) - ret, xml_msg = wxcpt.DecryptMsg( - encryt_msg, msg_signature, timestamp, nonce - ) + ret, xml_msg = wxcpt.DecryptMsg(encryt_msg, msg_signature, timestamp, nonce) xml_msg = xml_msg.decode('utf-8') if ret != 0: @@ -238,18 +228,12 @@ class OAClientForLongerResponse: from_user = root.find('FromUserName').text to_user = root.find('ToUserName').text - if ( - self.msg_queue.get(from_user) - and self.msg_queue[from_user][0]['content'] - ): + if self.msg_queue.get(from_user) and self.msg_queue[from_user][0]['content']: queue_top = self.msg_queue[from_user].pop(0) queue_content = queue_top['content'] # 弹出用户消息 - if ( - self.user_msg_queue.get(from_user) - and self.user_msg_queue[from_user] - ): + if self.user_msg_queue.get(from_user) and self.user_msg_queue[from_user]: self.user_msg_queue[from_user].pop(0) response_xml = xml_template.format( @@ -268,10 +252,7 @@ class OAClientForLongerResponse: content=self.loading_message, ) - if ( - self.user_msg_queue.get(from_user) - and self.user_msg_queue[from_user][0]['content'] - ): + if self.user_msg_queue.get(from_user) and self.user_msg_queue[from_user][0]['content']: return response_xml else: message_data = await self.get_message(xml_msg) @@ -299,12 +280,8 @@ class OAClientForLongerResponse: 'FromUserName': root.find('FromUserName').text, 'CreateTime': int(root.find('CreateTime').text), 'MsgType': root.find('MsgType').text, - 'Content': root.find('Content').text - if root.find('Content') is not None - else None, - 'MsgId': int(root.find('MsgId').text) - if root.find('MsgId') is not None - else None, + 'Content': root.find('Content').text if root.find('Content') is not None else None, + 'MsgId': int(root.find('MsgId').text) if root.find('MsgId') is not None else None, } return message_data diff --git a/libs/qq_official_api/api.py b/libs/qq_official_api/api.py index 89360881..dbdbcf4a 100644 --- a/libs/qq_official_api/api.py +++ b/libs/qq_official_api/api.py @@ -144,15 +144,9 @@ class QQOfficialClient: 'group_openid': msg.get('d', {}).get('group_openid', {}), } attachments = msg.get('d', {}).get('attachments', []) - image_attachments = [ - attachment['url'] - for attachment in attachments - if await self.is_image(attachment) - ] + image_attachments = [attachment['url'] for attachment in attachments if await self.is_image(attachment)] image_attachments_type = [ - attachment['content_type'] - for attachment in attachments - if await self.is_image(attachment) + attachment['content_type'] for attachment in attachments if await self.is_image(attachment) ] if image_attachments: message_data['image_attachments'] = image_attachments[0] @@ -211,9 +205,7 @@ class QQOfficialClient: else: raise Exception(response.read().decode()) - async def send_channle_group_text_msg( - self, channel_id: str, content: str, msg_id: str - ): + async def send_channle_group_text_msg(self, channel_id: str, content: str, msg_id: str): """发送频道群聊消息""" if not await self.check_access_token(): await self.get_access_token() @@ -235,9 +227,7 @@ class QQOfficialClient: else: raise Exception(response) - async def send_channle_private_text_msg( - self, guild_id: str, content: str, msg_id: str - ): + async def send_channle_private_text_msg(self, guild_id: str, content: str, msg_id: str): """发送频道私聊消息""" if not await self.check_access_token(): await self.get_access_token() diff --git a/libs/slack_api/api.py b/libs/slack_api/api.py index 86239ce9..441692ab 100644 --- a/libs/slack_api/api.py +++ b/libs/slack_api/api.py @@ -1,59 +1,57 @@ import json -from quart import Quart, jsonify,request +from quart import Quart, jsonify, request from slack_sdk.web.async_client import AsyncWebClient from .slackevent import SlackEvent -from typing import Callable, Dict, Any -from pkg.platform.types import events as platform_events, message as platform_message - -class SlackClient(): - - def __init__(self,bot_token:str,signing_secret:str): - - self.bot_token = bot_token - self.signing_secret = signing_secret - self.app = Quart(__name__) - self.client = AsyncWebClient(self.bot_token) - self.app.add_url_rule('/callback/command', 'handle_callback', self.handle_callback_request, methods=['GET', 'POST']) - self._message_handlers = { - "example":[], - } - self.bot_user_id = None # 避免机器人回复自己的消息 - - async def handle_callback_request(self): - try: - body = await request.get_data() - data = json.loads(body) - if 'type' in data: - if data['type'] == 'url_verification': - return data['challenge'] - - bot_user_id = data.get("event",{}).get("bot_id","") - - if self.bot_user_id and bot_user_id == self.bot_user_id: - return jsonify({'status': 'ok'}) - - - # 处理私信 - if data and data.get("event", {}).get("channel_type") in ["im"]: - event = SlackEvent.from_payload(data) - await self._handle_message(event) - return jsonify({'status': 'ok'}) - - #处理群聊 - if data.get("event",{}).get("type") == 'app_mention': - data.setdefault("event", {})["channel_type"] = "channel" - event = SlackEvent.from_payload(data) - await self._handle_message(event) - return jsonify({'status':'ok'}) - - return jsonify({'status': 'ok'}) - - except Exception as e: - raise(e) - +from typing import Callable +from pkg.platform.types import events as platform_events - async def _handle_message(self, event: SlackEvent): +class SlackClient: + def __init__(self, bot_token: str, signing_secret: str): + self.bot_token = bot_token + self.signing_secret = signing_secret + self.app = Quart(__name__) + self.client = AsyncWebClient(self.bot_token) + self.app.add_url_rule( + '/callback/command', 'handle_callback', self.handle_callback_request, methods=['GET', 'POST'] + ) + self._message_handlers = { + 'example': [], + } + self.bot_user_id = None # 避免机器人回复自己的消息 + + async def handle_callback_request(self): + try: + body = await request.get_data() + data = json.loads(body) + if 'type' in data: + if data['type'] == 'url_verification': + return data['challenge'] + + bot_user_id = data.get('event', {}).get('bot_id', '') + + if self.bot_user_id and bot_user_id == self.bot_user_id: + return jsonify({'status': 'ok'}) + + # 处理私信 + if data and data.get('event', {}).get('channel_type') in ['im']: + event = SlackEvent.from_payload(data) + await self._handle_message(event) + return jsonify({'status': 'ok'}) + + # 处理群聊 + if data.get('event', {}).get('type') == 'app_mention': + data.setdefault('event', {})['channel_type'] = 'channel' + event = SlackEvent.from_payload(data) + await self._handle_message(event) + return jsonify({'status': 'ok'}) + + return jsonify({'status': 'ok'}) + + except Exception as e: + raise (e) + + async def _handle_message(self, event: SlackEvent): """ 处理消息事件。 """ @@ -62,50 +60,38 @@ class SlackClient(): for handler in self._message_handlers[msg_type]: await handler(event) - def on_message(self, msg_type: str): + def on_message(self, msg_type: str): """注册消息类型处理器""" + def decorator(func: Callable[[platform_events.Event], None]): if msg_type not in self._message_handlers: self._message_handlers[msg_type] = [] self._message_handlers[msg_type].append(func) return func + return decorator - async def send_message_to_channel(self,text:str,channel_id:str): - try: - response = await self.client.chat_postMessage( - channel=channel_id, - text=text - ) - if self.bot_user_id is None and response.get("ok"): - self.bot_user_id = response["message"]["bot_id"] - return - except Exception as e: - raise e + async def send_message_to_channel(self, text: str, channel_id: str): + try: + response = await self.client.chat_postMessage(channel=channel_id, text=text) + if self.bot_user_id is None and response.get('ok'): + self.bot_user_id = response['message']['bot_id'] + return + except Exception as e: + raise e - async def send_message_to_one(self,text:str,user_id:str): - try: - response = await self.client.chat_postMessage( - channel = '@'+user_id, - text= text - ) - if self.bot_user_id is None and response.get("ok"): - self.bot_user_id = response["message"]["bot_id"] - - return - except Exception as e: - raise e - - async def run_task(self, host: str, port: int, *args, **kwargs): - """ - 启动 Quart 应用。 - """ - await self.app.run_task(host=host, port=port, *args, **kwargs) - - - - - - + async def send_message_to_one(self, text: str, user_id: str): + try: + response = await self.client.chat_postMessage(channel='@' + user_id, text=text) + if self.bot_user_id is None and response.get('ok'): + self.bot_user_id = response['message']['bot_id'] + return + except Exception as e: + raise e + async def run_task(self, host: str, port: int, *args, **kwargs): + """ + 启动 Quart 应用。 + """ + await self.app.run_task(host=host, port=port, *args, **kwargs) diff --git a/libs/slack_api/slackevent.py b/libs/slack_api/slackevent.py index 5a6e9f90..77177816 100644 --- a/libs/slack_api/slackevent.py +++ b/libs/slack_api/slackevent.py @@ -1,86 +1,82 @@ from typing import Dict, Any, Optional + class SlackEvent(dict): @staticmethod - def from_payload(payload: Dict[str, Any]) -> Optional["SlackEvent"]: + def from_payload(payload: Dict[str, Any]) -> Optional['SlackEvent']: try: event = SlackEvent(payload) return event except KeyError: return None - + @property def text(self) -> str: - - if self.get("event", {}).get("channel_type") == "im": - blocks = self.get("event", {}).get("blocks", []) - if not blocks: - return "" + if self.get('event', {}).get('channel_type') == 'im': + blocks = self.get('event', {}).get('blocks', []) + if not blocks: + return '' - elements = blocks[0].get("elements", []) - if not elements: - return "" + elements = blocks[0].get('elements', []) + if not elements: + return '' - elements = elements[0].get("elements", []) - text = "" + elements = elements[0].get('elements', []) + text = '' - for el in elements: - if el.get("type") == "text": - text += el.get("text", "") - elif el.get("type") == "link": - text += el.get("url", "") + for el in elements: + if el.get('type') == 'text': + text += el.get('text', '') + elif el.get('type') == 'link': + text += el.get('url', '') - return text + return text - - if self.get("event",{}).get("channel_type") == 'channel': - message_text = "" - for block in self.get("event", {}).get("blocks", []): - if block.get("type") == "rich_text": - for element in block.get("elements", []): - if element.get("type") == "rich_text_section": + if self.get('event', {}).get('channel_type') == 'channel': + message_text = '' + for block in self.get('event', {}).get('blocks', []): + if block.get('type') == 'rich_text': + for element in block.get('elements', []): + if element.get('type') == 'rich_text_section': parts = [] - for el in element.get("elements", []): - if el.get("type") == "text": - parts.append(el["text"]) - elif el.get("type") == "link": - parts.append(el["url"]) - message_text = "".join(parts) - + for el in element.get('elements', []): + if el.get('type') == 'text': + parts.append(el['text']) + elif el.get('type') == 'link': + parts.append(el['url']) + message_text = ''.join(parts) + return message_text - - @property def user_id(self) -> Optional[str]: - return self.get("event", {}).get("user","") - + return self.get('event', {}).get('user', '') + @property def channel_id(self) -> Optional[str]: - return self.get("event", {}).get("channel","") - + return self.get('event', {}).get('channel', '') + @property def type(self) -> str: - """ message对应私聊,app_mention对应频道at """ - return self.get("event", {}).get("channel_type", "") - + """message对应私聊,app_mention对应频道at""" + return self.get('event', {}).get('channel_type', '') + @property def message_id(self) -> str: - return self.get("event_id","") - + return self.get('event_id', '') + @property def pic_url(self) -> str: """提取 Slack 事件中的图片 URL""" - files = self.get("event", {}).get("files", []) + files = self.get('event', {}).get('files', []) if files: - return files[0].get("url_private", "") + return files[0].get('url_private', '') return None - - + @property def sender_name(self) -> str: - return self.get("event", {}).get("user","") - + return self.get('event', {}).get('user', '') + def __getattr__(self, key: str) -> Optional[Any]: return self.get(key) @@ -88,4 +84,4 @@ class SlackEvent(dict): self[key] = value def __repr__(self) -> str: - return f"" + return f'' diff --git a/libs/wecom_api/WXBizMsgCrypt3.py b/libs/wecom_api/WXBizMsgCrypt3.py index a9a7bc89..ceb5e71a 100644 --- a/libs/wecom_api/WXBizMsgCrypt3.py +++ b/libs/wecom_api/WXBizMsgCrypt3.py @@ -147,12 +147,7 @@ class Prpcrypt(object): """ # 16位随机字符串添加到明文开头 text = text.encode() - text = ( - self.get_random_str() - + struct.pack('I', socket.htonl(len(text))) - + text - + receiveid.encode() - ) + text = self.get_random_str() + struct.pack('I', socket.htonl(len(text))) + text + receiveid.encode() # 使用自定义的填充方式对明文进行补位填充 pkcs7 = PKCS7Encoder() diff --git a/libs/wecom_api/api.py b/libs/wecom_api/api.py index 8993885b..f4a62be0 100644 --- a/libs/wecom_api/api.py +++ b/libs/wecom_api/api.py @@ -45,9 +45,7 @@ class WecomClient: return bool(self.access_token and self.access_token.strip()) async def check_access_token_for_contacts(self): - return bool( - self.access_token_for_contacts and self.access_token_for_contacts.strip() - ) + return bool(self.access_token_for_contacts and self.access_token_for_contacts.strip()) async def get_access_token(self, secret): url = f'https://qyapi.weixin.qq.com/cgi-bin/gettoken?corpid={self.corpid}&corpsecret={secret}' @@ -61,15 +59,9 @@ class WecomClient: async def get_users(self): if not self.check_access_token_for_contacts(): - self.access_token_for_contacts = await self.get_access_token( - self.secret_for_contacts - ) + self.access_token_for_contacts = await self.get_access_token(self.secret_for_contacts) - url = ( - self.base_url - + '/user/list_id?access_token=' - + self.access_token_for_contacts - ) + url = self.base_url + '/user/list_id?access_token=' + self.access_token_for_contacts async with httpx.AsyncClient() as client: params = { 'cursor': '', @@ -88,15 +80,9 @@ class WecomClient: async def send_to_all(self, content: str, agent_id: int): if not self.check_access_token_for_contacts(): - self.access_token_for_contacts = await self.get_access_token( - self.secret_for_contacts - ) + self.access_token_for_contacts = await self.get_access_token(self.secret_for_contacts) - url = ( - self.base_url - + '/message/send?access_token=' - + self.access_token_for_contacts - ) + url = self.base_url + '/message/send?access_token=' + self.access_token_for_contacts user_ids = await self.get_users() user_ids_string = '|'.join(user_ids) async with httpx.AsyncClient() as client: @@ -187,27 +173,21 @@ class WecomClient: if request.method == 'GET': echostr = request.args.get('echostr') - ret, reply_echo_str = self.wxcpt.VerifyURL( - msg_signature, timestamp, nonce, echostr - ) + ret, reply_echo_str = self.wxcpt.VerifyURL(msg_signature, timestamp, nonce, echostr) if ret != 0: raise Exception(f'验证失败,错误码: {ret}') return reply_echo_str elif request.method == 'POST': encrypt_msg = await request.data - ret, xml_msg = self.wxcpt.DecryptMsg( - encrypt_msg, msg_signature, timestamp, nonce - ) + ret, xml_msg = self.wxcpt.DecryptMsg(encrypt_msg, msg_signature, timestamp, nonce) if ret != 0: raise Exception(f'消息解密失败,错误码: {ret}') # 解析消息并处理 message_data = await self.get_message(xml_msg) if message_data: - event = WecomEvent.from_payload( - message_data - ) # 转换为 WecomEvent 对象 + event = WecomEvent.from_payload(message_data) # 转换为 WecomEvent 对象 if event: await self._handle_message(event) @@ -253,23 +233,13 @@ class WecomClient: 'FromUserName': root.find('FromUserName').text, 'CreateTime': int(root.find('CreateTime').text), 'MsgType': root.find('MsgType').text, - 'Content': root.find('Content').text - if root.find('Content') is not None - else None, - 'MsgId': int(root.find('MsgId').text) - if root.find('MsgId') is not None - else None, - 'AgentID': int(root.find('AgentID').text) - if root.find('AgentID') is not None - else None, + 'Content': root.find('Content').text if root.find('Content') is not None else None, + 'MsgId': int(root.find('MsgId').text) if root.find('MsgId') is not None else None, + 'AgentID': int(root.find('AgentID').text) if root.find('AgentID') is not None else None, } if message_data['MsgType'] == 'image': - message_data['MediaId'] = ( - root.find('MediaId').text if root.find('MediaId') is not None else None - ) - message_data['PicUrl'] = ( - root.find('PicUrl').text if root.find('PicUrl') is not None else None - ) + message_data['MediaId'] = root.find('MediaId').text if root.find('MediaId') is not None else None + message_data['PicUrl'] = root.find('PicUrl').text if root.find('PicUrl') is not None else None return message_data @@ -298,12 +268,7 @@ class WecomClient: if not await self.check_access_token(): self.access_token = await self.get_access_token(self.secret) - url = ( - self.base_url - + '/media/upload?access_token=' - + self.access_token - + '&type=file' - ) + url = self.base_url + '/media/upload?access_token=' + self.access_token + '&type=file' file_bytes = None file_name = 'uploaded_file.txt' diff --git a/libs/wecom_customer_service_api/api.py b/libs/wecom_customer_service_api/api.py index 04e5398b..965fefcd 100644 --- a/libs/wecom_customer_service_api/api.py +++ b/libs/wecom_customer_service_api/api.py @@ -6,60 +6,61 @@ import httpx import traceback from quart import Quart import xml.etree.ElementTree as ET -from typing import Callable, Dict, Any +from typing import Callable from .wecomcsevent import WecomCSEvent -from pkg.platform.types import events as platform_events, message as platform_message +from pkg.platform.types import message as platform_message import aiofiles -class WecomCSClient(): - def __init__(self,corpid:str,secret:str,token:str,EncodingAESKey:str): +class WecomCSClient: + def __init__(self, corpid: str, secret: str, token: str, EncodingAESKey: str): self.corpid = corpid self.secret = secret - self.access_token_for_contacts ='' + self.access_token_for_contacts = '' self.token = token self.aes = EncodingAESKey self.base_url = 'https://qyapi.weixin.qq.com/cgi-bin' self.access_token = '' self.app = Quart(__name__) self.wxcpt = WXBizMsgCrypt(self.token, self.aes, self.corpid) - self.app.add_url_rule('/callback/command', 'handle_callback', self.handle_callback_request, methods=['GET', 'POST']) + self.app.add_url_rule( + '/callback/command', 'handle_callback', self.handle_callback_request, methods=['GET', 'POST'] + ) self._message_handlers = { - "example":[], + 'example': [], } async def get_pic_url(self, media_id: str): if not await self.check_access_token(): self.access_token = await self.get_access_token(self.secret) - url = f"{self.base_url}/media/get?access_token={self.access_token}&media_id={media_id}" + url = f'{self.base_url}/media/get?access_token={self.access_token}&media_id={media_id}' async with httpx.AsyncClient() as client: response = await client.get(url) - if response.headers.get("Content-Type", "").startswith("application/json"): + if response.headers.get('Content-Type', '').startswith('application/json'): data = response.json() if data.get('errcode') in [40014, 42001]: self.access_token = await self.get_access_token(self.secret) return await self.get_pic_url(media_id) else: - raise Exception("Failed to get image: " + str(data)) + raise Exception('Failed to get image: ' + str(data)) # 否则是图片,转成 base64 image_bytes = response.content - content_type = response.headers.get("Content-Type", "") - base64_str = base64.b64encode(image_bytes).decode("utf-8") - base64_str = f"data:{content_type};base64,{base64_str}" + content_type = response.headers.get('Content-Type', '') + base64_str = base64.b64encode(image_bytes).decode('utf-8') + base64_str = f'data:{content_type};base64,{base64_str}' return base64_str - - #access——token操作 + # access——token操作 async def check_access_token(self): return bool(self.access_token and self.access_token.strip()) async def check_access_token_for_contacts(self): return bool(self.access_token_for_contacts and self.access_token_for_contacts.strip()) - async def get_access_token(self,secret): + async def get_access_token(self, secret): url = f'https://qyapi.weixin.qq.com/cgi-bin/gettoken?corpid={self.corpid}&corpsecret={secret}' async with httpx.AsyncClient() as client: response = await client.get(url) @@ -67,118 +68,115 @@ class WecomCSClient(): if 'access_token' in data: return data['access_token'] else: - raise Exception(f"未获取access token: {data}") - - async def get_detailed_message_list(self,xml_msg:str): + raise Exception(f'未获取access token: {data}') + + async def get_detailed_message_list(self, xml_msg: str): # 在本方法中解析消息,并且获得消息的具体内容 if isinstance(xml_msg, bytes): xml_msg = xml_msg.decode('utf-8') root = ET.fromstring(xml_msg) - token = root.find("Token").text - open_kfid = root.find("OpenKfId").text - + token = root.find('Token').text + open_kfid = root.find('OpenKfId').text + # if open_kfid in self.openkfid_list: # return None # else: # self.openkfid_list.append(open_kfid) - + if not await self.check_access_token(): self.access_token = await self.get_access_token(self.secret) - - url = self.base_url+'/kf/sync_msg?access_token='+ self.access_token + + url = self.base_url + '/kf/sync_msg?access_token=' + self.access_token async with httpx.AsyncClient() as client: params = { - "token": token, - "voice_format": 0, - "open_kfid": open_kfid, + 'token': token, + 'voice_format': 0, + 'open_kfid': open_kfid, } - response = await client.post(url,json=params) + response = await client.post(url, json=params) data = response.json() if data['errcode'] == 40014 or data['errcode'] == 42001: self.access_token = await self.get_access_token(self.secret) return await self.get_detailed_message_list(xml_msg) if data['errcode'] != 0: - raise Exception("Failed to get message") - + raise Exception('Failed to get message') + last_msg_data = data['msg_list'][-1] - open_kfid = last_msg_data.get("open_kfid") + open_kfid = last_msg_data.get('open_kfid') # 进行获取图片操作 - if last_msg_data.get("msgtype") == "image": - media_id = last_msg_data.get("image").get("media_id") + if last_msg_data.get('msgtype') == 'image': + media_id = last_msg_data.get('image').get('media_id') picurl = await self.get_pic_url(media_id) - last_msg_data["picurl"] = picurl + last_msg_data['picurl'] = picurl # await self.change_service_status(userid=external_userid,openkfid=open_kfid,servicer=servicer) return last_msg_data - - - async def change_service_status(self,userid:str,openkfid:str,servicer:str): + + async def change_service_status(self, userid: str, openkfid: str, servicer: str): if not await self.check_access_token(): self.access_token = await self.get_access_token(self.secret) - url = self.base_url+"/kf/service_state/get?access_token="+self.access_token + url = self.base_url + '/kf/service_state/get?access_token=' + self.access_token async with httpx.AsyncClient() as client: params = { - "open_kfid" : openkfid, - "external_userid" : userid, - "service_state" : 1, - "servicer_userid" : servicer, + 'open_kfid': openkfid, + 'external_userid': userid, + 'service_state': 1, + 'servicer_userid': servicer, } - response = await client.post(url,json=params) + response = await client.post(url, json=params) data = response.json() if data['errcode'] == 40014 or data['errcode'] == 42001: self.access_token = await self.get_access_token(self.secret) - return await self.change_service_status(userid,openkfid) + return await self.change_service_status(userid, openkfid) if data['errcode'] != 0: - raise Exception("Failed to change service status: "+str(data)) - + raise Exception('Failed to change service status: ' + str(data)) - async def send_image(self,user_id:str,agent_id:int,media_id:str): + async def send_image(self, user_id: str, agent_id: int, media_id: str): if not await self.check_access_token(): self.access_token = await self.get_access_token(self.secret) - url = self.base_url+'/media/upload?access_token='+self.access_token + url = self.base_url + '/media/upload?access_token=' + self.access_token async with httpx.AsyncClient() as client: params = { - "touser" : user_id, - "toparty" : "", - "totag":"", - "agentid" : agent_id, - "msgtype" : "image", - "image" : { - "media_id" : media_id, + 'touser': user_id, + 'toparty': '', + 'totag': '', + 'agentid': agent_id, + 'msgtype': 'image', + 'image': { + 'media_id': media_id, }, - "safe":0, - "enable_id_trans": 0, - "enable_duplicate_check": 0, - "duplicate_check_interval": 1800 + 'safe': 0, + 'enable_id_trans': 0, + 'enable_duplicate_check': 0, + 'duplicate_check_interval': 1800, } try: - response = await client.post(url,json=params) + response = await client.post(url, json=params) data = response.json() except Exception as e: - raise Exception("Failed to send image: "+str(e)) + raise Exception('Failed to send image: ' + str(e)) # 企业微信错误码40014和42001,代表accesstoken问题 if data['errcode'] == 40014 or data['errcode'] == 42001: self.access_token = await self.get_access_token(self.secret) - return await self.send_image(user_id,agent_id,media_id) + return await self.send_image(user_id, agent_id, media_id) if data['errcode'] != 0: - raise Exception("Failed to send image: "+str(data)) - + raise Exception('Failed to send image: ' + str(data)) - async def send_text_msg(self, open_kfid: str, external_userid: str, msgid: str,content:str): + async def send_text_msg(self, open_kfid: str, external_userid: str, msgid: str, content: str): if not await self.check_access_token(): self.access_token = await self.get_access_token(self.secret) - url = f"https://qyapi.weixin.qq.com/cgi-bin/kf/send_msg?access_token={self.access_token}" + url = f'https://qyapi.weixin.qq.com/cgi-bin/kf/send_msg?access_token={self.access_token}' payload = { - "touser": external_userid, - "open_kfid": open_kfid, - "msgid": msgid, - "msgtype": "text", - "text": { - "content": content, - } + 'touser': external_userid, + 'open_kfid': open_kfid, + 'msgid': msgid, + 'msgtype': 'text', + 'text': { + 'content': content, + }, } async with httpx.AsyncClient() as client: @@ -187,46 +185,44 @@ class WecomCSClient(): data = response.json() if data['errcode'] == 40014 or data['errcode'] == 42001: self.access_token = await self.get_access_token(self.secret) - return await self.send_text_msg(open_kfid,external_userid,msgid,content) + return await self.send_text_msg(open_kfid, external_userid, msgid, content) if data['errcode'] != 0: - raise Exception("Failed to send message") + raise Exception('Failed to send message') return data - async def handle_callback_request(self): """ 处理回调请求,包括 GET 验证和 POST 消息接收。 """ try: + msg_signature = request.args.get('msg_signature') + timestamp = request.args.get('timestamp') + nonce = request.args.get('nonce') - msg_signature = request.args.get("msg_signature") - timestamp = request.args.get("timestamp") - nonce = request.args.get("nonce") - - if request.method == "GET": - echostr = request.args.get("echostr") + if request.method == 'GET': + echostr = request.args.get('echostr') ret, reply_echo_str = self.wxcpt.VerifyURL(msg_signature, timestamp, nonce, echostr) if ret != 0: - raise Exception(f"验证失败,错误码: {ret}") + raise Exception(f'验证失败,错误码: {ret}') return reply_echo_str - elif request.method == "POST": + elif request.method == 'POST': encrypt_msg = await request.data ret, xml_msg = self.wxcpt.DecryptMsg(encrypt_msg, msg_signature, timestamp, nonce) if ret != 0: - raise Exception(f"消息解密失败,错误码: {ret}") + raise Exception(f'消息解密失败,错误码: {ret}') # 解析消息并处理 message_data = await self.get_detailed_message_list(xml_msg) if message_data is not None: - event = WecomCSEvent.from_payload(message_data) + event = WecomCSEvent.from_payload(message_data) if event: await self._handle_message(event) - return "success" + return 'success' except Exception as e: traceback.print_exc() - return f"Error processing request: {str(e)}", 400 + return f'Error processing request: {str(e)}', 400 async def run_task(self, host: str, port: int, *args, **kwargs): """ @@ -238,11 +234,13 @@ class WecomCSClient(): """ 注册消息类型处理器。 """ + def decorator(func: Callable[[WecomCSEvent], None]): if msg_type not in self._message_handlers: self._message_handlers[msg_type] = [] self._message_handlers[msg_type].append(func) return func + return decorator async def _handle_message(self, event: WecomCSEvent): @@ -254,25 +252,23 @@ class WecomCSClient(): for handler in self._message_handlers[msg_type]: await handler(event) - @staticmethod async def get_image_type(image_bytes: bytes) -> str: """ 通过图片的magic numbers判断图片类型 """ magic_numbers = { - b'\xFF\xD8\xFF': 'jpg', - b'\x89\x50\x4E\x47': 'png', + b'\xff\xd8\xff': 'jpg', + b'\x89\x50\x4e\x47': 'png', b'\x47\x49\x46': 'gif', - b'\x42\x4D': 'bmp', - b'\x00\x00\x01\x00': 'ico' + b'\x42\x4d': 'bmp', + b'\x00\x00\x01\x00': 'ico', } - + for magic, ext in magic_numbers.items(): if image_bytes.startswith(magic): return ext return 'jpg' # 默认返回jpg - async def upload_to_work(self, image: platform_message.Image): """ @@ -283,7 +279,7 @@ class WecomCSClient(): url = self.base_url + '/media/upload?access_token=' + self.access_token + '&type=file' file_bytes = None - file_name = "uploaded_file.txt" + file_name = 'uploaded_file.txt' # 获取文件的二进制数据 if image.path: @@ -302,20 +298,22 @@ class WecomCSClient(): padded_base64 = base64_data + '=' * padding file_bytes = base64.b64decode(padded_base64) except binascii.Error as e: - raise ValueError(f"Invalid base64 string: {str(e)}") + raise ValueError(f'Invalid base64 string: {str(e)}') else: - raise ValueError("image对象出错") + raise ValueError('image对象出错') # 设置 multipart/form-data 格式的文件 - boundary = "-------------------------acebdf13572468" - headers = { - 'Content-Type': f'multipart/form-data; boundary={boundary}' - } + boundary = '-------------------------acebdf13572468' + headers = {'Content-Type': f'multipart/form-data; boundary={boundary}'} body = ( - f"--{boundary}\r\n" - f"Content-Disposition: form-data; name=\"media\"; filename=\"{file_name}\"; filelength={len(file_bytes)}\r\n" - f"Content-Type: application/octet-stream\r\n\r\n" - ).encode('utf-8') + file_bytes + f"\r\n--{boundary}--\r\n".encode('utf-8') + ( + f'--{boundary}\r\n' + f'Content-Disposition: form-data; name="media"; filename="{file_name}"; filelength={len(file_bytes)}\r\n' + f'Content-Type: application/octet-stream\r\n\r\n' + ).encode('utf-8') + + file_bytes + + f'\r\n--{boundary}--\r\n'.encode('utf-8') + ) # 上传文件 async with httpx.AsyncClient() as client: @@ -325,19 +323,18 @@ class WecomCSClient(): self.access_token = await self.get_access_token(self.secret) media_id = await self.upload_to_work(image) if data.get('errcode', 0) != 0: - raise Exception("failed to upload file") + raise Exception('failed to upload file') media_id = data.get('media_id') return media_id - async def download_image_to_bytes(self,url:str) -> bytes: + async def download_image_to_bytes(self, url: str) -> bytes: async with httpx.AsyncClient() as client: response = await client.get(url) response.raise_for_status() return response.content - #进行media_id的获取 + # 进行media_id的获取 async def get_media_id(self, image: platform_message.Image): - media_id = await self.upload_to_work(image=image) return media_id diff --git a/libs/wecom_customer_service_api/wecomcsevent.py b/libs/wecom_customer_service_api/wecomcsevent.py index 8dc0e30d..ee830a73 100644 --- a/libs/wecom_customer_service_api/wecomcsevent.py +++ b/libs/wecom_customer_service_api/wecomcsevent.py @@ -9,7 +9,7 @@ class WecomCSEvent(dict): """ @staticmethod - def from_payload(payload: Dict[str, Any]) -> Optional["WecomCSEvent"]: + def from_payload(payload: Dict[str, Any]) -> Optional['WecomCSEvent']: """ 从企业微信(客服会话)事件数据构造 `WecomEvent` 对象。 @@ -21,7 +21,7 @@ class WecomCSEvent(dict): """ try: event = WecomCSEvent(payload) - _ = event.type, + _ = (event.type,) return event except KeyError: return None @@ -34,8 +34,8 @@ class WecomCSEvent(dict): Returns: str: 事件类型。 """ - return self.get("msgtype", "") - + return self.get('msgtype', '') + @property def user_id(self) -> Optional[str]: """ @@ -44,7 +44,7 @@ class WecomCSEvent(dict): Returns: Optional[str]: 用户 ID。 """ - return self.get("external_userid") + return self.get('external_userid') @property def receiver_id(self) -> Optional[str]: @@ -54,8 +54,8 @@ class WecomCSEvent(dict): Returns: Optional[str]: 接收者 ID。 """ - return self.get("open_kfid","") - + return self.get('open_kfid', '') + @property def picurl(self) -> Optional[str]: """ @@ -65,7 +65,7 @@ class WecomCSEvent(dict): Optional[str]: 图片 URL。 """ - return self.get("picurl","") + return self.get('picurl', '') @property def message_id(self) -> Optional[str]: @@ -75,7 +75,7 @@ class WecomCSEvent(dict): Returns: Optional[str]: 消息 ID。 """ - return self.get("msgid") + return self.get('msgid') @property def message(self) -> Optional[str]: @@ -85,12 +85,11 @@ class WecomCSEvent(dict): Returns: Optional[str]: 消息内容。 """ - if self.get("msgtype") == 'text': - return self.get("text").get("content","") + if self.get('msgtype') == 'text': + return self.get('text').get('content', '') else: return None - @property def timestamp(self) -> Optional[int]: """ @@ -99,8 +98,7 @@ class WecomCSEvent(dict): Returns: Optional[int]: 时间戳。 """ - return self.get("send_time") - + return self.get('send_time') def __getattr__(self, key: str) -> Optional[Any]: """ @@ -131,4 +129,4 @@ class WecomCSEvent(dict): Returns: str: 字符串表示。 """ - return f"" + return f'' diff --git a/pkg/api/http/controller/group.py b/pkg/api/http/controller/group.py index 343ec690..ce366539 100644 --- a/pkg/api/http/controller/group.py +++ b/pkg/api/http/controller/group.py @@ -65,9 +65,7 @@ class RouterGroup(abc.ABC): async def handler_error(*args, **kwargs): if auth_type == AuthType.USER_TOKEN: # 从Authorization头中获取token - token = quart.request.headers.get('Authorization', '').replace( - 'Bearer ', '' - ) + token = quart.request.headers.get('Authorization', '').replace('Bearer ', '') if not token: return self.http_status(401, -1, '未提供有效的用户令牌') diff --git a/pkg/api/http/controller/groups/logs.py b/pkg/api/http/controller/groups/logs.py index b0643cb6..e3bff9db 100644 --- a/pkg/api/http/controller/groups/logs.py +++ b/pkg/api/http/controller/groups/logs.py @@ -14,10 +14,8 @@ class LogsRouterGroup(group.RouterGroup): start_page_number = int(quart.request.args.get('start_page_number', 0)) start_offset = int(quart.request.args.get('start_offset', 0)) - logs_str, end_page_number, end_offset = ( - self.ap.log_cache.get_log_by_pointer( - start_page_number=start_page_number, start_offset=start_offset - ) + logs_str, end_page_number, end_offset = self.ap.log_cache.get_log_by_pointer( + start_page_number=start_page_number, start_offset=start_offset ) return self.success( diff --git a/pkg/api/http/controller/groups/pipelines.py b/pkg/api/http/controller/groups/pipelines.py index 02564e58..1a8036cc 100644 --- a/pkg/api/http/controller/groups/pipelines.py +++ b/pkg/api/http/controller/groups/pipelines.py @@ -11,23 +11,17 @@ class PipelinesRouterGroup(group.RouterGroup): @self.route('', methods=['GET', 'POST']) async def _() -> str: if quart.request.method == 'GET': - return self.success( - data={'pipelines': await self.ap.pipeline_service.get_pipelines()} - ) + return self.success(data={'pipelines': await self.ap.pipeline_service.get_pipelines()}) elif quart.request.method == 'POST': json_data = await quart.request.json - pipeline_uuid = await self.ap.pipeline_service.create_pipeline( - json_data - ) + pipeline_uuid = await self.ap.pipeline_service.create_pipeline(json_data) return self.success(data={'uuid': pipeline_uuid}) @self.route('/_/metadata', methods=['GET']) async def _() -> str: - return self.success( - data={'configs': await self.ap.pipeline_service.get_pipeline_metadata()} - ) + return self.success(data={'configs': await self.ap.pipeline_service.get_pipeline_metadata()}) @self.route('/', methods=['GET', 'PUT', 'DELETE']) async def _(pipeline_uuid: str) -> str: diff --git a/pkg/api/http/controller/groups/platform/adapters.py b/pkg/api/http/controller/groups/platform/adapters.py index f6fb8489..4136791c 100644 --- a/pkg/api/http/controller/groups/platform/adapters.py +++ b/pkg/api/http/controller/groups/platform/adapters.py @@ -8,30 +8,20 @@ class AdaptersRouterGroup(group.RouterGroup): async def initialize(self) -> None: @self.route('', methods=['GET']) async def _() -> str: - return self.success( - data={'adapters': self.ap.platform_mgr.get_available_adapters_info()} - ) + return self.success(data={'adapters': self.ap.platform_mgr.get_available_adapters_info()}) @self.route('/', methods=['GET']) async def _(adapter_name: str) -> str: - adapter_info = self.ap.platform_mgr.get_available_adapter_info_by_name( - adapter_name - ) + adapter_info = self.ap.platform_mgr.get_available_adapter_info_by_name(adapter_name) if adapter_info is None: return self.http_status(404, -1, 'adapter not found') return self.success(data={'adapter': adapter_info}) - @self.route( - '//icon', methods=['GET'], auth_type=group.AuthType.NONE - ) + @self.route('//icon', methods=['GET'], auth_type=group.AuthType.NONE) async def _(adapter_name: str) -> quart.Response: - adapter_manifest = ( - self.ap.platform_mgr.get_available_adapter_manifest_by_name( - adapter_name - ) - ) + adapter_manifest = self.ap.platform_mgr.get_available_adapter_manifest_by_name(adapter_name) if adapter_manifest is None: return self.http_status(404, -1, 'adapter not found') diff --git a/pkg/api/http/controller/groups/plugins.py b/pkg/api/http/controller/groups/plugins.py index 1deecca6..daf6ea7d 100644 --- a/pkg/api/http/controller/groups/plugins.py +++ b/pkg/api/http/controller/groups/plugins.py @@ -92,9 +92,7 @@ class PluginsRouterGroup(group.RouterGroup): await self.ap.plugin_mgr.reorder_plugins(data.get('plugins')) return self.success() - @self.route( - '/install/github', methods=['POST'], auth_type=group.AuthType.USER_TOKEN - ) + @self.route('/install/github', methods=['POST'], auth_type=group.AuthType.USER_TOKEN) async def _() -> str: data = await quart.request.json diff --git a/pkg/api/http/controller/groups/provider/models.py b/pkg/api/http/controller/groups/provider/models.py index b3680c4f..683fac01 100644 --- a/pkg/api/http/controller/groups/provider/models.py +++ b/pkg/api/http/controller/groups/provider/models.py @@ -9,9 +9,7 @@ class LLMModelsRouterGroup(group.RouterGroup): @self.route('', methods=['GET', 'POST']) async def _() -> str: if quart.request.method == 'GET': - return self.success( - data={'models': await self.ap.model_service.get_llm_models()} - ) + return self.success(data={'models': await self.ap.model_service.get_llm_models()}) elif quart.request.method == 'POST': json_data = await quart.request.json diff --git a/pkg/api/http/controller/groups/provider/requesters.py b/pkg/api/http/controller/groups/provider/requesters.py index 5d93b9cb..0f999288 100644 --- a/pkg/api/http/controller/groups/provider/requesters.py +++ b/pkg/api/http/controller/groups/provider/requesters.py @@ -8,30 +8,20 @@ class RequestersRouterGroup(group.RouterGroup): async def initialize(self) -> None: @self.route('', methods=['GET']) async def _() -> quart.Response: - return self.success( - data={'requesters': self.ap.model_mgr.get_available_requesters_info()} - ) + return self.success(data={'requesters': self.ap.model_mgr.get_available_requesters_info()}) @self.route('/', methods=['GET']) async def _(requester_name: str) -> quart.Response: - requester_info = self.ap.model_mgr.get_available_requester_info_by_name( - requester_name - ) + requester_info = self.ap.model_mgr.get_available_requester_info_by_name(requester_name) if requester_info is None: return self.http_status(404, -1, 'requester not found') return self.success(data={'requester': requester_info}) - @self.route( - '//icon', methods=['GET'], auth_type=group.AuthType.NONE - ) + @self.route('//icon', methods=['GET'], auth_type=group.AuthType.NONE) async def _(requester_name: str) -> quart.Response: - requester_manifest = ( - self.ap.model_mgr.get_available_requester_manifest_by_name( - requester_name - ) - ) + requester_manifest = self.ap.model_mgr.get_available_requester_manifest_by_name(requester_name) if requester_manifest is None: return self.http_status(404, -1, 'requester not found') diff --git a/pkg/api/http/controller/groups/stats.py b/pkg/api/http/controller/groups/stats.py index 7b1d4353..8c8e9113 100644 --- a/pkg/api/http/controller/groups/stats.py +++ b/pkg/api/http/controller/groups/stats.py @@ -8,9 +8,7 @@ class StatsRouterGroup(group.RouterGroup): async def _() -> str: conv_count = 0 for session in self.ap.sess_mgr.session_list: - conv_count += len( - session.conversations if session.conversations is not None else [] - ) + conv_count += len(session.conversations if session.conversations is not None else []) return self.success( data={ diff --git a/pkg/api/http/controller/groups/system.py b/pkg/api/http/controller/groups/system.py index c586ea27..c4cab602 100644 --- a/pkg/api/http/controller/groups/system.py +++ b/pkg/api/http/controller/groups/system.py @@ -13,9 +13,7 @@ class SystemRouterGroup(group.RouterGroup): data={ 'version': constants.semantic_version, 'debug': constants.debug_mode, - 'enabled_platform_count': len( - self.ap.platform_mgr.get_running_adapters() - ), + 'enabled_platform_count': len(self.ap.platform_mgr.get_running_adapters()), } ) @@ -28,9 +26,7 @@ class SystemRouterGroup(group.RouterGroup): return self.success(data=self.ap.task_mgr.get_tasks_dict(task_type)) - @self.route( - '/tasks/', methods=['GET'], auth_type=group.AuthType.USER_TOKEN - ) + @self.route('/tasks/', methods=['GET'], auth_type=group.AuthType.USER_TOKEN) async def _(task_id: str) -> str: task = self.ap.task_mgr.get_task_by_id(int(task_id)) @@ -48,9 +44,7 @@ class SystemRouterGroup(group.RouterGroup): await self.ap.reload(scope=scope) return self.success() - @self.route( - '/_debug/exec', methods=['POST'], auth_type=group.AuthType.USER_TOKEN - ) + @self.route('/_debug/exec', methods=['POST'], auth_type=group.AuthType.USER_TOKEN) async def _() -> str: if not constants.debug_mode: return self.http_status(403, 403, 'Forbidden') diff --git a/pkg/api/http/controller/groups/user.py b/pkg/api/http/controller/groups/user.py index 4c330782..498efaa4 100644 --- a/pkg/api/http/controller/groups/user.py +++ b/pkg/api/http/controller/groups/user.py @@ -10,9 +10,7 @@ class UserRouterGroup(group.RouterGroup): @self.route('/init', methods=['GET', 'POST'], auth_type=group.AuthType.NONE) async def _() -> str: if quart.request.method == 'GET': - return self.success( - data={'initialized': await self.ap.user_service.is_initialized()} - ) + return self.success(data={'initialized': await self.ap.user_service.is_initialized()}) if await self.ap.user_service.is_initialized(): return self.fail(1, '系统已初始化') @@ -31,17 +29,13 @@ class UserRouterGroup(group.RouterGroup): json_data = await quart.request.json try: - token = await self.ap.user_service.authenticate( - json_data['user'], json_data['password'] - ) + token = await self.ap.user_service.authenticate(json_data['user'], json_data['password']) except argon2.exceptions.VerifyMismatchError: return self.fail(1, '用户名或密码错误') return self.success(data={'token': token}) - @self.route( - '/check-token', methods=['GET'], auth_type=group.AuthType.USER_TOKEN - ) + @self.route('/check-token', methods=['GET'], auth_type=group.AuthType.USER_TOKEN) async def _(user_email: str) -> str: token = await self.ap.user_service.generate_jwt_token(user_email) diff --git a/pkg/api/http/controller/main.py b/pkg/api/http/controller/main.py index 0c8e2f70..3c8097b8 100644 --- a/pkg/api/http/controller/main.py +++ b/pkg/api/http/controller/main.py @@ -70,15 +70,12 @@ class HTTPController: @self.quart_app.route('/') async def index(): - return await quart.send_from_directory( - frontend_path, 'index.html', mimetype='text/html' - ) + return await quart.send_from_directory(frontend_path, 'index.html', mimetype='text/html') @self.quart_app.route('/') async def static_file(path: str): if not ( - os.path.exists(os.path.join(frontend_path, path)) - and os.path.isfile(os.path.join(frontend_path, path)) + os.path.exists(os.path.join(frontend_path, path)) and os.path.isfile(os.path.join(frontend_path, path)) ): if os.path.exists(os.path.join(frontend_path, path + '.html')): path += '.html' @@ -110,6 +107,4 @@ class HTTPController: elif path.endswith('.txt'): mimetype = 'text/plain' - return await quart.send_from_directory( - frontend_path, path, mimetype=mimetype - ) + return await quart.send_from_directory(frontend_path, path, mimetype=mimetype) diff --git a/pkg/api/http/service/bot.py b/pkg/api/http/service/bot.py index 23e9fa5b..e562a310 100644 --- a/pkg/api/http/service/bot.py +++ b/pkg/api/http/service/bot.py @@ -18,23 +18,16 @@ class BotService: async def get_bots(self) -> list[dict]: """获取所有机器人""" - result = await self.ap.persistence_mgr.execute_async( - sqlalchemy.select(persistence_bot.Bot) - ) + result = await self.ap.persistence_mgr.execute_async(sqlalchemy.select(persistence_bot.Bot)) bots = result.all() - return [ - self.ap.persistence_mgr.serialize_model(persistence_bot.Bot, bot) - for bot in bots - ] + return [self.ap.persistence_mgr.serialize_model(persistence_bot.Bot, bot) for bot in bots] async def get_bot(self, bot_uuid: str) -> dict | None: """获取机器人""" result = await self.ap.persistence_mgr.execute_async( - sqlalchemy.select(persistence_bot.Bot).where( - persistence_bot.Bot.uuid == bot_uuid - ) + sqlalchemy.select(persistence_bot.Bot).where(persistence_bot.Bot.uuid == bot_uuid) ) bot = result.first() @@ -60,9 +53,7 @@ class BotService: bot_data['use_pipeline_uuid'] = pipeline.uuid bot_data['use_pipeline_name'] = pipeline.name - await self.ap.persistence_mgr.execute_async( - sqlalchemy.insert(persistence_bot.Bot).values(bot_data) - ) + await self.ap.persistence_mgr.execute_async(sqlalchemy.insert(persistence_bot.Bot).values(bot_data)) bot = await self.get_bot(bot_data['uuid']) @@ -79,8 +70,7 @@ class BotService: if 'use_pipeline_uuid' in bot_data: result = await self.ap.persistence_mgr.execute_async( sqlalchemy.select(persistence_pipeline.LegacyPipeline).where( - persistence_pipeline.LegacyPipeline.uuid - == bot_data['use_pipeline_uuid'] + persistence_pipeline.LegacyPipeline.uuid == bot_data['use_pipeline_uuid'] ) ) pipeline = result.first() @@ -90,9 +80,7 @@ class BotService: raise Exception('Pipeline not found') await self.ap.persistence_mgr.execute_async( - sqlalchemy.update(persistence_bot.Bot) - .values(bot_data) - .where(persistence_bot.Bot.uuid == bot_uuid) + sqlalchemy.update(persistence_bot.Bot).values(bot_data).where(persistence_bot.Bot.uuid == bot_uuid) ) await self.ap.platform_mgr.remove_bot(bot_uuid) @@ -108,7 +96,5 @@ class BotService: """删除机器人""" await self.ap.platform_mgr.remove_bot(bot_uuid) await self.ap.persistence_mgr.execute_async( - sqlalchemy.delete(persistence_bot.Bot).where( - persistence_bot.Bot.uuid == bot_uuid - ) + sqlalchemy.delete(persistence_bot.Bot).where(persistence_bot.Bot.uuid == bot_uuid) ) diff --git a/pkg/api/http/service/model.py b/pkg/api/http/service/model.py index 8a71bf1c..080abb9d 100644 --- a/pkg/api/http/service/model.py +++ b/pkg/api/http/service/model.py @@ -15,22 +15,15 @@ class ModelsService: self.ap = ap async def get_llm_models(self) -> list[dict]: - result = await self.ap.persistence_mgr.execute_async( - sqlalchemy.select(persistence_model.LLMModel) - ) + result = await self.ap.persistence_mgr.execute_async(sqlalchemy.select(persistence_model.LLMModel)) models = result.all() - return [ - self.ap.persistence_mgr.serialize_model(persistence_model.LLMModel, model) - for model in models - ] + return [self.ap.persistence_mgr.serialize_model(persistence_model.LLMModel, model) for model in models] async def create_llm_model(self, model_data: dict) -> str: model_data['uuid'] = str(uuid.uuid4()) - await self.ap.persistence_mgr.execute_async( - sqlalchemy.insert(persistence_model.LLMModel).values(**model_data) - ) + await self.ap.persistence_mgr.execute_async(sqlalchemy.insert(persistence_model.LLMModel).values(**model_data)) llm_model = await self.get_llm_model(model_data['uuid']) @@ -53,9 +46,7 @@ class ModelsService: async def get_llm_model(self, model_uuid: str) -> dict | None: result = await self.ap.persistence_mgr.execute_async( - sqlalchemy.select(persistence_model.LLMModel).where( - persistence_model.LLMModel.uuid == model_uuid - ) + sqlalchemy.select(persistence_model.LLMModel).where(persistence_model.LLMModel.uuid == model_uuid) ) model = result.first() @@ -63,9 +54,7 @@ class ModelsService: if model is None: return None - return self.ap.persistence_mgr.serialize_model( - persistence_model.LLMModel, model - ) + return self.ap.persistence_mgr.serialize_model(persistence_model.LLMModel, model) async def update_llm_model(self, model_uuid: str, model_data: dict) -> None: if 'uuid' in model_data: @@ -85,9 +74,7 @@ class ModelsService: async def delete_llm_model(self, model_uuid: str) -> None: await self.ap.persistence_mgr.execute_async( - sqlalchemy.delete(persistence_model.LLMModel).where( - persistence_model.LLMModel.uuid == model_uuid - ) + sqlalchemy.delete(persistence_model.LLMModel).where(persistence_model.LLMModel.uuid == model_uuid) ) await self.ap.model_mgr.remove_llm_model(model_uuid) diff --git a/pkg/api/http/service/pipeline.py b/pkg/api/http/service/pipeline.py index 0dd73ef2..ee648db0 100644 --- a/pkg/api/http/service/pipeline.py +++ b/pkg/api/http/service/pipeline.py @@ -39,15 +39,11 @@ class PipelineService: ] async def get_pipelines(self) -> list[dict]: - result = await self.ap.persistence_mgr.execute_async( - sqlalchemy.select(persistence_pipeline.LegacyPipeline) - ) + result = await self.ap.persistence_mgr.execute_async(sqlalchemy.select(persistence_pipeline.LegacyPipeline)) pipelines = result.all() return [ - self.ap.persistence_mgr.serialize_model( - persistence_pipeline.LegacyPipeline, pipeline - ) + self.ap.persistence_mgr.serialize_model(persistence_pipeline.LegacyPipeline, pipeline) for pipeline in pipelines ] @@ -63,23 +59,17 @@ class PipelineService: if pipeline is None: return None - return self.ap.persistence_mgr.serialize_model( - persistence_pipeline.LegacyPipeline, pipeline - ) + return self.ap.persistence_mgr.serialize_model(persistence_pipeline.LegacyPipeline, pipeline) async def create_pipeline(self, pipeline_data: dict, default: bool = False) -> str: pipeline_data['uuid'] = str(uuid.uuid4()) pipeline_data['for_version'] = self.ap.ver_mgr.get_current_version() pipeline_data['stages'] = default_stage_order.copy() pipeline_data['is_default'] = default - pipeline_data['config'] = json.load( - open('templates/default-pipeline-config.json', 'r', encoding='utf-8') - ) + pipeline_data['config'] = json.load(open('templates/default-pipeline-config.json', 'r', encoding='utf-8')) await self.ap.persistence_mgr.execute_async( - sqlalchemy.insert(persistence_pipeline.LegacyPipeline).values( - **pipeline_data - ) + sqlalchemy.insert(persistence_pipeline.LegacyPipeline).values(**pipeline_data) ) pipeline = await self.get_pipeline(pipeline_data['uuid']) diff --git a/pkg/api/http/service/user.py b/pkg/api/http/service/user.py index 03a64576..782aad75 100644 --- a/pkg/api/http/service/user.py +++ b/pkg/api/http/service/user.py @@ -17,9 +17,7 @@ class UserService: self.ap = ap async def is_initialized(self) -> bool: - result = await self.ap.persistence_mgr.execute_async( - sqlalchemy.select(user.User).limit(1) - ) + result = await self.ap.persistence_mgr.execute_async(sqlalchemy.select(user.User).limit(1)) result_list = result.all() return result_list is not None and len(result_list) > 0 @@ -30,9 +28,7 @@ class UserService: hashed_password = ph.hash(password) await self.ap.persistence_mgr.execute_async( - sqlalchemy.insert(user.User).values( - user=user_email, password=hashed_password - ) + sqlalchemy.insert(user.User).values(user=user_email, password=hashed_password) ) async def get_user_by_email(self, user_email: str) -> user.User | None: @@ -41,9 +37,7 @@ class UserService: ) result_list = result.all() - return ( - result_list[0] if result_list is not None and len(result_list) > 0 else None - ) + return result_list[0] if result_list is not None and len(result_list) > 0 else None async def authenticate(self, user_email: str, password: str) -> str | None: result = await self.ap.persistence_mgr.execute_async( diff --git a/pkg/command/cmdmgr.py b/pkg/command/cmdmgr.py index 10d76067..1bd03fcf 100644 --- a/pkg/command/cmdmgr.py +++ b/pkg/command/cmdmgr.py @@ -40,18 +40,14 @@ class CommandManager: # 应用命令权限配置 for cls in operator.preregistered_operators: if cls.path in self.ap.instance_config.data['command']['privilege']: - cls.lowest_privilege = self.ap.instance_config.data['command'][ - 'privilege' - ][cls.path] + cls.lowest_privilege = self.ap.instance_config.data['command']['privilege'][cls.path] # 实例化所有类 self.cmd_list = [cls(self.ap) for cls in operator.preregistered_operators] # 设置所有类的子节点 for cmd in self.cmd_list: - cmd.children = [ - child for child in self.cmd_list if child.parent_class == cmd.__class__ - ] + cmd.children = [child for child in self.cmd_list if child.parent_class == cmd.__class__] # 初始化所有类 for cmd in self.cmd_list: @@ -68,10 +64,7 @@ class CommandManager: found = False if len(context.crt_params) > 0: # 查找下一个参数是否对应此节点的某个子节点名 for oper in operator_list: - if ( - context.crt_params[0] == oper.name - or context.crt_params[0] in oper.alias - ) and ( + if (context.crt_params[0] == oper.name or context.crt_params[0] in oper.alias) and ( oper.parent_class is None or oper.parent_class == operator.__class__ ): found = True @@ -85,14 +78,10 @@ class CommandManager: if not found: # 如果下一个参数未在此节点的子节点中找到,则执行此节点或者报错 if operator is None: - yield entities.CommandReturn( - error=errors.CommandNotFoundError(context.crt_params[0]) - ) + yield entities.CommandReturn(error=errors.CommandNotFoundError(context.crt_params[0])) else: if operator.lowest_privilege > context.privilege: - yield entities.CommandReturn( - error=errors.CommandPrivilegeError(operator.name) - ) + yield entities.CommandReturn(error=errors.CommandPrivilegeError(operator.name)) else: async for ret in operator.execute(context): yield ret @@ -107,10 +96,7 @@ class CommandManager: privilege = 1 - if ( - f'{query.launcher_type.value}_{query.launcher_id}' - in self.ap.instance_config.data['admins'] - ): + if f'{query.launcher_type.value}_{query.launcher_id}' in self.ap.instance_config.data['admins']: privilege = 2 ctx = entities.ExecuteContext( diff --git a/pkg/command/operator.py b/pkg/command/operator.py index 7072edf7..9ee3de37 100644 --- a/pkg/command/operator.py +++ b/pkg/command/operator.py @@ -95,9 +95,7 @@ class CommandOperator(metaclass=abc.ABCMeta): pass @abc.abstractmethod - async def execute( - self, context: entities.ExecuteContext - ) -> typing.AsyncGenerator[entities.CommandReturn, None]: + async def execute(self, context: entities.ExecuteContext) -> typing.AsyncGenerator[entities.CommandReturn, None]: """实现此方法以执行命令 支持多次yield以返回多个结果。 diff --git a/pkg/command/operators/cmd.py b/pkg/command/operators/cmd.py index a13d5b35..f5a69a7b 100644 --- a/pkg/command/operators/cmd.py +++ b/pkg/command/operators/cmd.py @@ -9,9 +9,7 @@ from .. import operator, entities, errors class CmdOperator(operator.CommandOperator): """命令列表""" - async def execute( - self, context: entities.ExecuteContext - ) -> typing.AsyncGenerator[entities.CommandReturn, None]: + async def execute(self, context: entities.ExecuteContext) -> typing.AsyncGenerator[entities.CommandReturn, None]: """执行""" if len(context.crt_params) == 0: reply_str = '当前所有命令: \n\n' @@ -30,16 +28,12 @@ class CmdOperator(operator.CommandOperator): cmd = None for _cmd in self.ap.cmd_mgr.cmd_list: - if (cmd_name == _cmd.name or cmd_name in _cmd.alias) and ( - _cmd.parent_class is None - ): + if (cmd_name == _cmd.name or cmd_name in _cmd.alias) and (_cmd.parent_class is None): cmd = _cmd break if cmd is None: - yield entities.CommandReturn( - error=errors.CommandNotFoundError(cmd_name) - ) + yield entities.CommandReturn(error=errors.CommandNotFoundError(cmd_name)) else: reply_str = f'{cmd.name}: {cmd.help}\n\n' reply_str += f'使用方法: \n{cmd.usage}' diff --git a/pkg/command/operators/delc.py b/pkg/command/operators/delc.py index 9ae507f5..7e72ff3c 100644 --- a/pkg/command/operators/delc.py +++ b/pkg/command/operators/delc.py @@ -5,55 +5,38 @@ import typing from .. import operator, entities, errors -@operator.operator_class( - name='del', help='删除当前会话的历史记录', usage='!del <序号>\n!del all' -) +@operator.operator_class(name='del', help='删除当前会话的历史记录', usage='!del <序号>\n!del all') class DelOperator(operator.CommandOperator): - async def execute( - self, context: entities.ExecuteContext - ) -> typing.AsyncGenerator[entities.CommandReturn, None]: + async def execute(self, context: entities.ExecuteContext) -> typing.AsyncGenerator[entities.CommandReturn, None]: if context.session.conversations: delete_index = 0 if len(context.crt_params) > 0: try: delete_index = int(context.crt_params[0]) except Exception: - yield entities.CommandReturn( - error=errors.CommandOperationError('索引必须是整数') - ) + yield entities.CommandReturn(error=errors.CommandOperationError('索引必须是整数')) return if delete_index < 0 or delete_index >= len(context.session.conversations): - yield entities.CommandReturn( - error=errors.CommandOperationError('索引超出范围') - ) + yield entities.CommandReturn(error=errors.CommandOperationError('索引超出范围')) return # 倒序 to_delete_index = len(context.session.conversations) - 1 - delete_index - if ( - context.session.conversations[to_delete_index] - == context.session.using_conversation - ): + if context.session.conversations[to_delete_index] == context.session.using_conversation: context.session.using_conversation = None del context.session.conversations[to_delete_index] yield entities.CommandReturn(text=f'已删除对话: {delete_index}') else: - yield entities.CommandReturn( - error=errors.CommandOperationError('当前没有对话') - ) + yield entities.CommandReturn(error=errors.CommandOperationError('当前没有对话')) -@operator.operator_class( - name='all', help='删除此会话的所有历史记录', parent_class=DelOperator -) +@operator.operator_class(name='all', help='删除此会话的所有历史记录', parent_class=DelOperator) class DelAllOperator(operator.CommandOperator): - async def execute( - self, context: entities.ExecuteContext - ) -> typing.AsyncGenerator[entities.CommandReturn, None]: + async def execute(self, context: entities.ExecuteContext) -> typing.AsyncGenerator[entities.CommandReturn, None]: context.session.conversations = [] context.session.using_conversation = None diff --git a/pkg/command/operators/func.py b/pkg/command/operators/func.py index 9cb3fd32..648cc5e2 100644 --- a/pkg/command/operators/func.py +++ b/pkg/command/operators/func.py @@ -6,9 +6,7 @@ from .. import operator, entities @operator.operator_class(name='func', help='查看所有已注册的内容函数', usage='!func') class FuncOperator(operator.CommandOperator): - async def execute( - self, context: entities.ExecuteContext - ) -> AsyncGenerator[entities.CommandReturn, None]: + async def execute(self, context: entities.ExecuteContext) -> AsyncGenerator[entities.CommandReturn, None]: reply_str = '当前已启用的内容函数: \n\n' index = 1 diff --git a/pkg/command/operators/help.py b/pkg/command/operators/help.py index c718d4b9..91ad66dc 100644 --- a/pkg/command/operators/help.py +++ b/pkg/command/operators/help.py @@ -7,9 +7,7 @@ from .. import operator, entities @operator.operator_class(name='help', help='显示帮助', usage='!help\n!help <命令名称>') class HelpOperator(operator.CommandOperator): - async def execute( - self, context: entities.ExecuteContext - ) -> typing.AsyncGenerator[entities.CommandReturn, None]: + async def execute(self, context: entities.ExecuteContext) -> typing.AsyncGenerator[entities.CommandReturn, None]: help = 'LangBot - 大语言模型原生即时通信机器人平台\n链接:https://langbot.app' help += '\n发送命令 !cmd 可查看命令列表' diff --git a/pkg/command/operators/last.py b/pkg/command/operators/last.py index 7e2f2453..25b1fc6a 100644 --- a/pkg/command/operators/last.py +++ b/pkg/command/operators/last.py @@ -8,36 +8,21 @@ from .. import operator, entities, errors @operator.operator_class(name='last', help='切换到前一个对话', usage='!last') class LastOperator(operator.CommandOperator): - async def execute( - self, context: entities.ExecuteContext - ) -> typing.AsyncGenerator[entities.CommandReturn, None]: + async def execute(self, context: entities.ExecuteContext) -> typing.AsyncGenerator[entities.CommandReturn, None]: if context.session.conversations: # 找到当前会话的上一个会话 for index in range(len(context.session.conversations) - 1, -1, -1): - if ( - context.session.conversations[index] - == context.session.using_conversation - ): + if context.session.conversations[index] == context.session.using_conversation: if index == 0: - yield entities.CommandReturn( - error=errors.CommandOperationError('已经是第一个对话了') - ) + yield entities.CommandReturn(error=errors.CommandOperationError('已经是第一个对话了')) return else: - context.session.using_conversation = ( - context.session.conversations[index - 1] - ) - time_str = ( - context.session.using_conversation.create_time.strftime( - '%Y-%m-%d %H:%M:%S' - ) - ) + context.session.using_conversation = context.session.conversations[index - 1] + time_str = context.session.using_conversation.create_time.strftime('%Y-%m-%d %H:%M:%S') yield entities.CommandReturn( text=f'已切换到上一个对话: {index} {time_str}: {context.session.using_conversation.messages[0].readable_str()}' ) return else: - yield entities.CommandReturn( - error=errors.CommandOperationError('当前没有对话') - ) + yield entities.CommandReturn(error=errors.CommandOperationError('当前没有对话')) diff --git a/pkg/command/operators/list.py b/pkg/command/operators/list.py index 1aa63c94..70ff3945 100644 --- a/pkg/command/operators/list.py +++ b/pkg/command/operators/list.py @@ -5,22 +5,16 @@ import typing from .. import operator, entities, errors -@operator.operator_class( - name='list', help='列出此会话中的所有历史对话', usage='!list\n!list <页码>' -) +@operator.operator_class(name='list', help='列出此会话中的所有历史对话', usage='!list\n!list <页码>') class ListOperator(operator.CommandOperator): - async def execute( - self, context: entities.ExecuteContext - ) -> typing.AsyncGenerator[entities.CommandReturn, None]: + async def execute(self, context: entities.ExecuteContext) -> typing.AsyncGenerator[entities.CommandReturn, None]: page = 0 if len(context.crt_params) > 0: try: page = int(context.crt_params[0] - 1) except Exception: - yield entities.CommandReturn( - error=errors.CommandOperationError('页码应为整数') - ) + yield entities.CommandReturn(error=errors.CommandOperationError('页码应为整数')) return record_per_page = 10 @@ -38,7 +32,9 @@ class ListOperator(operator.CommandOperator): using_conv_index = index if index >= page * record_per_page and index < (page + 1) * record_per_page: - content += f'{index} {time_str}: {conv.messages[0].readable_str() if len(conv.messages) > 0 else "无内容"}\n' + content += ( + f'{index} {time_str}: {conv.messages[0].readable_str() if len(conv.messages) > 0 else "无内容"}\n' + ) index += 1 if content == '': diff --git a/pkg/command/operators/model.py b/pkg/command/operators/model.py index cc3ef5b9..07b7c0cd 100644 --- a/pkg/command/operators/model.py +++ b/pkg/command/operators/model.py @@ -14,9 +14,7 @@ from .. import operator, entities, errors class ModelOperator(operator.CommandOperator): """Model命令""" - async def execute( - self, context: entities.ExecuteContext - ) -> typing.AsyncGenerator[entities.CommandReturn, None]: + async def execute(self, context: entities.ExecuteContext) -> typing.AsyncGenerator[entities.CommandReturn, None]: content = '模型列表:\n' model_list = self.ap.model_mgr.model_list @@ -31,15 +29,11 @@ class ModelOperator(operator.CommandOperator): yield entities.CommandReturn(text=content.strip()) -@operator.operator_class( - name='show', help='显示模型详情', privilege=2, parent_class=ModelOperator -) +@operator.operator_class(name='show', help='显示模型详情', privilege=2, parent_class=ModelOperator) class ModelShowOperator(operator.CommandOperator): """Model Show命令""" - async def execute( - self, context: entities.ExecuteContext - ) -> typing.AsyncGenerator[entities.CommandReturn, None]: + async def execute(self, context: entities.ExecuteContext) -> typing.AsyncGenerator[entities.CommandReturn, None]: model_name = context.crt_params[0] model = None @@ -49,9 +43,7 @@ class ModelShowOperator(operator.CommandOperator): break if model is None: - yield entities.CommandReturn( - error=errors.CommandError(f'未找到模型 {model_name}') - ) + yield entities.CommandReturn(error=errors.CommandError(f'未找到模型 {model_name}')) else: content = '模型详情\n' content += f'名称: {model.name}\n' @@ -65,15 +57,11 @@ class ModelShowOperator(operator.CommandOperator): yield entities.CommandReturn(text=content.strip()) -@operator.operator_class( - name='set', help='设置默认使用模型', privilege=2, parent_class=ModelOperator -) +@operator.operator_class(name='set', help='设置默认使用模型', privilege=2, parent_class=ModelOperator) class ModelSetOperator(operator.CommandOperator): """Model Set命令""" - async def execute( - self, context: entities.ExecuteContext - ) -> typing.AsyncGenerator[entities.CommandReturn, None]: + async def execute(self, context: entities.ExecuteContext) -> typing.AsyncGenerator[entities.CommandReturn, None]: model_name = context.crt_params[0] model = None @@ -83,12 +71,8 @@ class ModelSetOperator(operator.CommandOperator): break if model is None: - yield entities.CommandReturn( - error=errors.CommandError(f'未找到模型 {model_name}') - ) + yield entities.CommandReturn(error=errors.CommandError(f'未找到模型 {model_name}')) else: self.ap.provider_cfg.data['model'] = model_name await self.ap.provider_cfg.dump_config() - yield entities.CommandReturn( - text=f'已设置当前使用模型为 {model_name},重置会话以生效' - ) + yield entities.CommandReturn(text=f'已设置当前使用模型为 {model_name},重置会话以生效') diff --git a/pkg/command/operators/next.py b/pkg/command/operators/next.py index ef5ae103..938c8331 100644 --- a/pkg/command/operators/next.py +++ b/pkg/command/operators/next.py @@ -7,36 +7,21 @@ from .. import operator, entities, errors @operator.operator_class(name='next', help='切换到后一个对话', usage='!next') class NextOperator(operator.CommandOperator): - async def execute( - self, context: entities.ExecuteContext - ) -> typing.AsyncGenerator[entities.CommandReturn, None]: + async def execute(self, context: entities.ExecuteContext) -> typing.AsyncGenerator[entities.CommandReturn, None]: if context.session.conversations: # 找到当前会话的下一个会话 for index in range(len(context.session.conversations)): - if ( - context.session.conversations[index] - == context.session.using_conversation - ): + if context.session.conversations[index] == context.session.using_conversation: if index == len(context.session.conversations) - 1: - yield entities.CommandReturn( - error=errors.CommandOperationError('已经是最后一个对话了') - ) + yield entities.CommandReturn(error=errors.CommandOperationError('已经是最后一个对话了')) return else: - context.session.using_conversation = ( - context.session.conversations[index + 1] - ) - time_str = ( - context.session.using_conversation.create_time.strftime( - '%Y-%m-%d %H:%M:%S' - ) - ) + context.session.using_conversation = context.session.conversations[index + 1] + time_str = context.session.using_conversation.create_time.strftime('%Y-%m-%d %H:%M:%S') yield entities.CommandReturn( text=f'已切换到后一个对话: {index} {time_str}: {context.session.using_conversation.messages[0].content}' ) return else: - yield entities.CommandReturn( - error=errors.CommandOperationError('当前没有对话') - ) + yield entities.CommandReturn(error=errors.CommandOperationError('当前没有对话')) diff --git a/pkg/command/operators/ollama.py b/pkg/command/operators/ollama.py index 7e65d440..93061f7d 100644 --- a/pkg/command/operators/ollama.py +++ b/pkg/command/operators/ollama.py @@ -13,9 +13,7 @@ from .. import operator, entities, errors usage='!ollama\n!ollama show <模型名>\n!ollama pull <模型名>\n!ollama del <模型名>', ) class OllamaOperator(operator.CommandOperator): - async def execute( - self, context: entities.ExecuteContext - ) -> typing.AsyncGenerator[entities.CommandReturn, None]: + async def execute(self, context: entities.ExecuteContext) -> typing.AsyncGenerator[entities.CommandReturn, None]: try: content: str = '模型列表:\n' model_list: list = ollama.list().get('models', []) @@ -25,9 +23,7 @@ class OllamaOperator(operator.CommandOperator): content += f'大小: {bytes_to_mb(model["size"])}MB\n\n' yield entities.CommandReturn(text=f'{content.strip()}') except ollama.ResponseError: - yield entities.CommandReturn( - error=errors.CommandError('无法获取模型列表,请确认 Ollama 服务正常') - ) + yield entities.CommandReturn(error=errors.CommandError('无法获取模型列表,请确认 Ollama 服务正常')) def bytes_to_mb(num_bytes): @@ -35,13 +31,9 @@ def bytes_to_mb(num_bytes): return format(mb, '.2f') -@operator.operator_class( - name='show', help='ollama模型详情', privilege=2, parent_class=OllamaOperator -) +@operator.operator_class(name='show', help='ollama模型详情', privilege=2, parent_class=OllamaOperator) class OllamaShowOperator(operator.CommandOperator): - async def execute( - self, context: entities.ExecuteContext - ) -> typing.AsyncGenerator[entities.CommandReturn, None]: + async def execute(self, context: entities.ExecuteContext) -> typing.AsyncGenerator[entities.CommandReturn, None]: content: str = '模型详情:\n' try: show: dict = ollama.show(model=context.crt_params[0]) @@ -60,27 +52,19 @@ class OllamaShowOperator(operator.CommandOperator): content += json.dumps(show, indent=4) yield entities.CommandReturn(text=content.strip()) except ollama.ResponseError: - yield entities.CommandReturn( - error=errors.CommandError('无法获取模型详情,请确认 Ollama 服务正常') - ) + yield entities.CommandReturn(error=errors.CommandError('无法获取模型详情,请确认 Ollama 服务正常')) -@operator.operator_class( - name='pull', help='ollama模型拉取', privilege=2, parent_class=OllamaOperator -) +@operator.operator_class(name='pull', help='ollama模型拉取', privilege=2, parent_class=OllamaOperator) class OllamaPullOperator(operator.CommandOperator): - async def execute( - self, context: entities.ExecuteContext - ) -> typing.AsyncGenerator[entities.CommandReturn, None]: + async def execute(self, context: entities.ExecuteContext) -> typing.AsyncGenerator[entities.CommandReturn, None]: try: model_list: list = ollama.list().get('models', []) if context.crt_params[0] in [model['name'] for model in model_list]: yield entities.CommandReturn(text='模型已存在') return except ollama.ResponseError: - yield entities.CommandReturn( - error=errors.CommandError('无法获取模型列表,请确认 Ollama 服务正常') - ) + yield entities.CommandReturn(error=errors.CommandError('无法获取模型列表,请确认 Ollama 服务正常')) return on_progress: bool = False @@ -108,13 +92,9 @@ class OllamaPullOperator(operator.CommandOperator): yield entities.CommandReturn(text=f'拉取失败: {e.error}') -@operator.operator_class( - name='del', help='ollama模型删除', privilege=2, parent_class=OllamaOperator -) +@operator.operator_class(name='del', help='ollama模型删除', privilege=2, parent_class=OllamaOperator) class OllamaDelOperator(operator.CommandOperator): - async def execute( - self, context: entities.ExecuteContext - ) -> typing.AsyncGenerator[entities.CommandReturn, None]: + async def execute(self, context: entities.ExecuteContext) -> typing.AsyncGenerator[entities.CommandReturn, None]: try: ret: str = ollama.delete(model=context.crt_params[0])['status'] except ollama.ResponseError as e: diff --git a/pkg/command/operators/plugin.py b/pkg/command/operators/plugin.py index 1bf4c7af..40ec0e3a 100644 --- a/pkg/command/operators/plugin.py +++ b/pkg/command/operators/plugin.py @@ -11,9 +11,7 @@ from .. import operator, entities, errors usage='!plugin\n!plugin get <插件仓库地址>\n!plugin update\n!plugin del <插件名>\n!plugin on <插件名>\n!plugin off <插件名>', ) class PluginOperator(operator.CommandOperator): - async def execute( - self, context: entities.ExecuteContext - ) -> typing.AsyncGenerator[entities.CommandReturn, None]: + async def execute(self, context: entities.ExecuteContext) -> typing.AsyncGenerator[entities.CommandReturn, None]: plugin_list = self.ap.plugin_mgr.plugins() reply_str = '所有插件({}):\n'.format(len(plugin_list)) idx = 0 @@ -32,17 +30,11 @@ class PluginOperator(operator.CommandOperator): yield entities.CommandReturn(text=reply_str) -@operator.operator_class( - name='get', help='安装插件', privilege=2, parent_class=PluginOperator -) +@operator.operator_class(name='get', help='安装插件', privilege=2, parent_class=PluginOperator) class PluginGetOperator(operator.CommandOperator): - async def execute( - self, context: entities.ExecuteContext - ) -> typing.AsyncGenerator[entities.CommandReturn, None]: + async def execute(self, context: entities.ExecuteContext) -> typing.AsyncGenerator[entities.CommandReturn, None]: if len(context.crt_params) == 0: - yield entities.CommandReturn( - error=errors.ParamNotEnoughError('请提供插件仓库地址') - ) + yield entities.CommandReturn(error=errors.ParamNotEnoughError('请提供插件仓库地址')) else: repo = context.crt_params[0] @@ -53,22 +45,14 @@ class PluginGetOperator(operator.CommandOperator): yield entities.CommandReturn(text='插件安装成功,请重启程序以加载插件') except Exception as e: traceback.print_exc() - yield entities.CommandReturn( - error=errors.CommandError('插件安装失败: ' + str(e)) - ) + yield entities.CommandReturn(error=errors.CommandError('插件安装失败: ' + str(e))) -@operator.operator_class( - name='update', help='更新插件', privilege=2, parent_class=PluginOperator -) +@operator.operator_class(name='update', help='更新插件', privilege=2, parent_class=PluginOperator) class PluginUpdateOperator(operator.CommandOperator): - async def execute( - self, context: entities.ExecuteContext - ) -> typing.AsyncGenerator[entities.CommandReturn, None]: + async def execute(self, context: entities.ExecuteContext) -> typing.AsyncGenerator[entities.CommandReturn, None]: if len(context.crt_params) == 0: - yield entities.CommandReturn( - error=errors.ParamNotEnoughError('请提供插件名称') - ) + yield entities.CommandReturn(error=errors.ParamNotEnoughError('请提供插件名称')) else: plugin_name = context.crt_params[0] @@ -78,27 +62,17 @@ class PluginUpdateOperator(operator.CommandOperator): if plugin_container is not None: yield entities.CommandReturn(text='正在更新插件...') await self.ap.plugin_mgr.update_plugin(plugin_name) - yield entities.CommandReturn( - text='插件更新成功,请重启程序以加载插件' - ) + yield entities.CommandReturn(text='插件更新成功,请重启程序以加载插件') else: - yield entities.CommandReturn( - error=errors.CommandError('插件更新失败: 未找到插件') - ) + yield entities.CommandReturn(error=errors.CommandError('插件更新失败: 未找到插件')) except Exception as e: traceback.print_exc() - yield entities.CommandReturn( - error=errors.CommandError('插件更新失败: ' + str(e)) - ) + yield entities.CommandReturn(error=errors.CommandError('插件更新失败: ' + str(e))) -@operator.operator_class( - name='all', help='更新所有插件', privilege=2, parent_class=PluginUpdateOperator -) +@operator.operator_class(name='all', help='更新所有插件', privilege=2, parent_class=PluginUpdateOperator) class PluginUpdateAllOperator(operator.CommandOperator): - async def execute( - self, context: entities.ExecuteContext - ) -> typing.AsyncGenerator[entities.CommandReturn, None]: + async def execute(self, context: entities.ExecuteContext) -> typing.AsyncGenerator[entities.CommandReturn, None]: try: plugins = [p.plugin_name for p in self.ap.plugin_mgr.plugins()] @@ -111,32 +85,20 @@ class PluginUpdateAllOperator(operator.CommandOperator): updated.append(plugin_name) except Exception as e: traceback.print_exc() - yield entities.CommandReturn( - error=errors.CommandError('插件更新失败: ' + str(e)) - ) - yield entities.CommandReturn( - text='已更新插件: {}'.format(', '.join(updated)) - ) + yield entities.CommandReturn(error=errors.CommandError('插件更新失败: ' + str(e))) + yield entities.CommandReturn(text='已更新插件: {}'.format(', '.join(updated))) else: yield entities.CommandReturn(text='没有可更新的插件') except Exception as e: traceback.print_exc() - yield entities.CommandReturn( - error=errors.CommandError('插件更新失败: ' + str(e)) - ) + yield entities.CommandReturn(error=errors.CommandError('插件更新失败: ' + str(e))) -@operator.operator_class( - name='del', help='删除插件', privilege=2, parent_class=PluginOperator -) +@operator.operator_class(name='del', help='删除插件', privilege=2, parent_class=PluginOperator) class PluginDelOperator(operator.CommandOperator): - async def execute( - self, context: entities.ExecuteContext - ) -> typing.AsyncGenerator[entities.CommandReturn, None]: + async def execute(self, context: entities.ExecuteContext) -> typing.AsyncGenerator[entities.CommandReturn, None]: if len(context.crt_params) == 0: - yield entities.CommandReturn( - error=errors.ParamNotEnoughError('请提供插件名称') - ) + yield entities.CommandReturn(error=errors.ParamNotEnoughError('请提供插件名称')) else: plugin_name = context.crt_params[0] @@ -146,79 +108,49 @@ class PluginDelOperator(operator.CommandOperator): if plugin_container is not None: yield entities.CommandReturn(text='正在删除插件...') await self.ap.plugin_mgr.uninstall_plugin(plugin_name) - yield entities.CommandReturn( - text='插件删除成功,请重启程序以加载插件' - ) + yield entities.CommandReturn(text='插件删除成功,请重启程序以加载插件') else: - yield entities.CommandReturn( - error=errors.CommandError('插件删除失败: 未找到插件') - ) + yield entities.CommandReturn(error=errors.CommandError('插件删除失败: 未找到插件')) except Exception as e: traceback.print_exc() - yield entities.CommandReturn( - error=errors.CommandError('插件删除失败: ' + str(e)) - ) + yield entities.CommandReturn(error=errors.CommandError('插件删除失败: ' + str(e))) -@operator.operator_class( - name='on', help='启用插件', privilege=2, parent_class=PluginOperator -) +@operator.operator_class(name='on', help='启用插件', privilege=2, parent_class=PluginOperator) class PluginEnableOperator(operator.CommandOperator): - async def execute( - self, context: entities.ExecuteContext - ) -> typing.AsyncGenerator[entities.CommandReturn, None]: + async def execute(self, context: entities.ExecuteContext) -> typing.AsyncGenerator[entities.CommandReturn, None]: if len(context.crt_params) == 0: - yield entities.CommandReturn( - error=errors.ParamNotEnoughError('请提供插件名称') - ) + yield entities.CommandReturn(error=errors.ParamNotEnoughError('请提供插件名称')) else: plugin_name = context.crt_params[0] try: if await self.ap.plugin_mgr.update_plugin_switch(plugin_name, True): - yield entities.CommandReturn( - text='已启用插件: {}'.format(plugin_name) - ) + yield entities.CommandReturn(text='已启用插件: {}'.format(plugin_name)) else: yield entities.CommandReturn( - error=errors.CommandError( - '插件状态修改失败: 未找到插件 {}'.format(plugin_name) - ) + error=errors.CommandError('插件状态修改失败: 未找到插件 {}'.format(plugin_name)) ) except Exception as e: traceback.print_exc() - yield entities.CommandReturn( - error=errors.CommandError('插件状态修改失败: ' + str(e)) - ) + yield entities.CommandReturn(error=errors.CommandError('插件状态修改失败: ' + str(e))) -@operator.operator_class( - name='off', help='禁用插件', privilege=2, parent_class=PluginOperator -) +@operator.operator_class(name='off', help='禁用插件', privilege=2, parent_class=PluginOperator) class PluginDisableOperator(operator.CommandOperator): - async def execute( - self, context: entities.ExecuteContext - ) -> typing.AsyncGenerator[entities.CommandReturn, None]: + async def execute(self, context: entities.ExecuteContext) -> typing.AsyncGenerator[entities.CommandReturn, None]: if len(context.crt_params) == 0: - yield entities.CommandReturn( - error=errors.ParamNotEnoughError('请提供插件名称') - ) + yield entities.CommandReturn(error=errors.ParamNotEnoughError('请提供插件名称')) else: plugin_name = context.crt_params[0] try: if await self.ap.plugin_mgr.update_plugin_switch(plugin_name, False): - yield entities.CommandReturn( - text='已禁用插件: {}'.format(plugin_name) - ) + yield entities.CommandReturn(text='已禁用插件: {}'.format(plugin_name)) else: yield entities.CommandReturn( - error=errors.CommandError( - '插件状态修改失败: 未找到插件 {}'.format(plugin_name) - ) + error=errors.CommandError('插件状态修改失败: 未找到插件 {}'.format(plugin_name)) ) except Exception as e: traceback.print_exc() - yield entities.CommandReturn( - error=errors.CommandError('插件状态修改失败: ' + str(e)) - ) + yield entities.CommandReturn(error=errors.CommandError('插件状态修改失败: ' + str(e))) diff --git a/pkg/command/operators/prompt.py b/pkg/command/operators/prompt.py index 41f42de4..fdcba2bd 100644 --- a/pkg/command/operators/prompt.py +++ b/pkg/command/operators/prompt.py @@ -7,14 +7,10 @@ from .. import operator, entities, errors @operator.operator_class(name='prompt', help='查看当前对话的前文', usage='!prompt') class PromptOperator(operator.CommandOperator): - async def execute( - self, context: entities.ExecuteContext - ) -> typing.AsyncGenerator[entities.CommandReturn, None]: + async def execute(self, context: entities.ExecuteContext) -> typing.AsyncGenerator[entities.CommandReturn, None]: """执行""" if context.session.using_conversation is None: - yield entities.CommandReturn( - error=errors.CommandOperationError('当前没有对话') - ) + yield entities.CommandReturn(error=errors.CommandOperationError('当前没有对话')) else: reply_str = '当前对话所有内容:\n\n' diff --git a/pkg/command/operators/resend.py b/pkg/command/operators/resend.py index 44e5a35c..39789fef 100644 --- a/pkg/command/operators/resend.py +++ b/pkg/command/operators/resend.py @@ -5,13 +5,9 @@ import typing from .. import operator, entities, errors -@operator.operator_class( - name='resend', help='重发当前会话的最后一条消息', usage='!resend' -) +@operator.operator_class(name='resend', help='重发当前会话的最后一条消息', usage='!resend') class ResendOperator(operator.CommandOperator): - async def execute( - self, context: entities.ExecuteContext - ) -> typing.AsyncGenerator[entities.CommandReturn, None]: + async def execute(self, context: entities.ExecuteContext) -> typing.AsyncGenerator[entities.CommandReturn, None]: # 回滚到最后一条用户message前 if context.session.using_conversation is None: yield entities.CommandReturn(error=errors.CommandError('当前没有对话')) diff --git a/pkg/command/operators/reset.py b/pkg/command/operators/reset.py index 7ef54e08..008143a1 100644 --- a/pkg/command/operators/reset.py +++ b/pkg/command/operators/reset.py @@ -7,9 +7,7 @@ from .. import operator, entities @operator.operator_class(name='reset', help='重置当前会话', usage='!reset') class ResetOperator(operator.CommandOperator): - async def execute( - self, context: entities.ExecuteContext - ) -> typing.AsyncGenerator[entities.CommandReturn, None]: + async def execute(self, context: entities.ExecuteContext) -> typing.AsyncGenerator[entities.CommandReturn, None]: """执行""" context.session.using_conversation = None diff --git a/pkg/command/operators/update.py b/pkg/command/operators/update.py index 775ee26a..9eda3a6c 100644 --- a/pkg/command/operators/update.py +++ b/pkg/command/operators/update.py @@ -8,9 +8,7 @@ from .. import operator, entities, errors @operator.operator_class(name='update', help='更新程序', usage='!update', privilege=2) class UpdateCommand(operator.CommandOperator): - async def execute( - self, context: entities.ExecuteContext - ) -> typing.AsyncGenerator[entities.CommandReturn, None]: + async def execute(self, context: entities.ExecuteContext) -> typing.AsyncGenerator[entities.CommandReturn, None]: try: yield entities.CommandReturn(text='正在进行更新...') if await self.ap.ver_mgr.update_all(): @@ -19,6 +17,4 @@ class UpdateCommand(operator.CommandOperator): yield entities.CommandReturn(text='当前已是最新版本') except Exception as e: traceback.print_exc() - yield entities.CommandReturn( - error=errors.CommandError('更新失败: ' + str(e)) - ) + yield entities.CommandReturn(error=errors.CommandError('更新失败: ' + str(e))) diff --git a/pkg/command/operators/version.py b/pkg/command/operators/version.py index 267b1113..200875aa 100644 --- a/pkg/command/operators/version.py +++ b/pkg/command/operators/version.py @@ -7,9 +7,7 @@ from .. import operator, entities @operator.operator_class(name='version', help='显示版本信息', usage='!version') class VersionCommand(operator.CommandOperator): - async def execute( - self, context: entities.ExecuteContext - ) -> typing.AsyncGenerator[entities.CommandReturn, None]: + async def execute(self, context: entities.ExecuteContext) -> typing.AsyncGenerator[entities.CommandReturn, None]: reply_str = f'当前版本: \n{self.ap.ver_mgr.get_current_version()}' try: diff --git a/pkg/config/manager.py b/pkg/config/manager.py index 2385c6b5..c2e6bdf4 100644 --- a/pkg/config/manager.py +++ b/pkg/config/manager.py @@ -41,9 +41,7 @@ class ConfigManager: self.file.save_sync(self.data) -async def load_python_module_config( - config_name: str, template_name: str, completion: bool = True -) -> ConfigManager: +async def load_python_module_config(config_name: str, template_name: str, completion: bool = True) -> ConfigManager: """加载Python模块配置文件 Args: diff --git a/pkg/core/app.py b/pkg/core/app.py index beb5415c..692e2b8a 100644 --- a/pkg/core/app.py +++ b/pkg/core/app.py @@ -160,9 +160,7 @@ class Application: """打印访问 webui 的提示""" if not os.path.exists(os.path.join('.', 'web/out')): - self.logger.warning( - 'WebUI 文件缺失,请根据文档获取:https://docs.langbot.app/webui/intro.html' - ) + self.logger.warning('WebUI 文件缺失,请根据文档获取:https://docs.langbot.app/webui/intro.html') return host_ip = '127.0.0.1' diff --git a/pkg/core/bootutils/log.py b/pkg/core/bootutils/log.py index df65e1ba..eb6806fa 100644 --- a/pkg/core/bootutils/log.py +++ b/pkg/core/bootutils/log.py @@ -26,9 +26,7 @@ async def init_logging(extra_handlers: list[logging.Handler] = None) -> logging. if constants.debug_mode: level = logging.DEBUG - log_file_name = 'data/logs/langbot-%s.log' % time.strftime( - '%Y-%m-%d', time.localtime() - ) + log_file_name = 'data/logs/langbot-%s.log' % time.strftime('%Y-%m-%d', time.localtime()) qcg_logger = logging.getLogger('langbot') @@ -43,9 +41,7 @@ async def init_logging(extra_handlers: list[logging.Handler] = None) -> logging. stream_handler = logging.StreamHandler(sys.stdout) # stream_handler.setLevel(level) # stream_handler.setFormatter(color_formatter) - stream_handler.stream = open( - sys.stdout.fileno(), mode='w', encoding='utf-8', buffering=1 - ) + stream_handler.stream = open(sys.stdout.fileno(), mode='w', encoding='utf-8', buffering=1) log_handlers: list[logging.Handler] = [ stream_handler, diff --git a/pkg/core/entities.py b/pkg/core/entities.py index 9eddc935..e2ea3d45 100644 --- a/pkg/core/entities.py +++ b/pkg/core/entities.py @@ -87,8 +87,7 @@ class Query(pydantic.BaseModel): """使用的函数,由前置处理器阶段设置""" resp_messages: ( - typing.Optional[list[llm_entities.Message]] - | typing.Optional[list[platform_message.MessageChain]] + typing.Optional[list[llm_entities.Message]] | typing.Optional[list[platform_message.MessageChain]] ) = [] """由Process阶段生成的回复消息对象列表""" @@ -130,13 +129,9 @@ class Conversation(pydantic.BaseModel): messages: list[llm_entities.Message] - create_time: typing.Optional[datetime.datetime] = pydantic.Field( - default_factory=datetime.datetime.now - ) + create_time: typing.Optional[datetime.datetime] = pydantic.Field(default_factory=datetime.datetime.now) - update_time: typing.Optional[datetime.datetime] = pydantic.Field( - default_factory=datetime.datetime.now - ) + update_time: typing.Optional[datetime.datetime] = pydantic.Field(default_factory=datetime.datetime.now) use_llm_model: requester.RuntimeLLMModel @@ -162,17 +157,11 @@ class Session(pydantic.BaseModel): using_conversation: typing.Optional[Conversation] = None - conversations: typing.Optional[list[Conversation]] = pydantic.Field( - default_factory=list - ) + conversations: typing.Optional[list[Conversation]] = pydantic.Field(default_factory=list) - create_time: typing.Optional[datetime.datetime] = pydantic.Field( - default_factory=datetime.datetime.now - ) + create_time: typing.Optional[datetime.datetime] = pydantic.Field(default_factory=datetime.datetime.now) - update_time: typing.Optional[datetime.datetime] = pydantic.Field( - default_factory=datetime.datetime.now - ) + update_time: typing.Optional[datetime.datetime] = pydantic.Field(default_factory=datetime.datetime.now) semaphore: typing.Optional[asyncio.Semaphore] = None """当前会话的信号量,用于限制并发""" diff --git a/pkg/core/migrations/m001_sensitive_word_migration.py b/pkg/core/migrations/m001_sensitive_word_migration.py index 72200346..35cb076f 100644 --- a/pkg/core/migrations/m001_sensitive_word_migration.py +++ b/pkg/core/migrations/m001_sensitive_word_migration.py @@ -11,16 +11,14 @@ class SensitiveWordMigration(migration.Migration): async def need_migrate(self) -> bool: """判断当前环境是否需要运行此迁移""" - return os.path.exists( - 'data/config/sensitive-words.json' - ) and not os.path.exists('data/metadata/sensitive-words.json') + return os.path.exists('data/config/sensitive-words.json') and not os.path.exists( + 'data/metadata/sensitive-words.json' + ) async def run(self): """执行迁移""" # 移动文件 - os.rename( - 'data/config/sensitive-words.json', 'data/metadata/sensitive-words.json' - ) + os.rename('data/config/sensitive-words.json', 'data/metadata/sensitive-words.json') # 重新加载配置 await self.ap.sensitive_meta.load_config() diff --git a/pkg/core/migrations/m002_openai_config_migration.py b/pkg/core/migrations/m002_openai_config_migration.py index 6892110f..9a35370c 100644 --- a/pkg/core/migrations/m002_openai_config_migration.py +++ b/pkg/core/migrations/m002_openai_config_migration.py @@ -23,9 +23,7 @@ class OpenAIConfigMigration(migration.Migration): self.ap.provider_cfg.data['keys']['openai'] = old_openai_config['api-keys'] - self.ap.provider_cfg.data['model'] = old_openai_config[ - 'chat-completions-params' - ]['model'] + self.ap.provider_cfg.data['model'] = old_openai_config['chat-completions-params']['model'] del old_openai_config['chat-completions-params']['model'] diff --git a/pkg/core/migrations/m007_qcg_center_url.py b/pkg/core/migrations/m007_qcg_center_url.py index b3fcd853..2783e079 100644 --- a/pkg/core/migrations/m007_qcg_center_url.py +++ b/pkg/core/migrations/m007_qcg_center_url.py @@ -15,8 +15,6 @@ class QCGCenterURLConfigMigration(migration.Migration): """执行迁移""" if 'qcg-center-url' not in self.ap.system_cfg.data: - self.ap.system_cfg.data['qcg-center-url'] = ( - 'https://api.qchatgpt.rockchin.top/api/v2' - ) + self.ap.system_cfg.data['qcg-center-url'] = 'https://api.qchatgpt.rockchin.top/api/v2' await self.ap.system_cfg.dump_config() diff --git a/pkg/core/migrations/m008_ad_fixwin_config_migrate.py b/pkg/core/migrations/m008_ad_fixwin_config_migrate.py index 96fd58e7..964e819b 100644 --- a/pkg/core/migrations/m008_ad_fixwin_config_migrate.py +++ b/pkg/core/migrations/m008_ad_fixwin_config_migrate.py @@ -9,9 +9,7 @@ class AdFixwinConfigMigration(migration.Migration): async def need_migrate(self) -> bool: """判断当前环境是否需要运行此迁移""" - return isinstance( - self.ap.pipeline_cfg.data['rate-limit']['fixwin']['default'], int - ) + return isinstance(self.ap.pipeline_cfg.data['rate-limit']['fixwin']['default'], int) async def run(self): """执行迁移""" @@ -19,9 +17,7 @@ class AdFixwinConfigMigration(migration.Migration): for session_name in self.ap.pipeline_cfg.data['rate-limit']['fixwin']: temp_dict = { 'window-size': 60, - 'limit': self.ap.pipeline_cfg.data['rate-limit']['fixwin'][ - session_name - ], + 'limit': self.ap.pipeline_cfg.data['rate-limit']['fixwin'][session_name], } self.ap.pipeline_cfg.data['rate-limit']['fixwin'][session_name] = temp_dict diff --git a/pkg/core/migrations/m013_http_api_config.py b/pkg/core/migrations/m013_http_api_config.py index 55aff2b9..80e7b74f 100644 --- a/pkg/core/migrations/m013_http_api_config.py +++ b/pkg/core/migrations/m013_http_api_config.py @@ -9,10 +9,7 @@ class HttpApiConfigMigration(migration.Migration): async def need_migrate(self) -> bool: """判断当前环境是否需要运行此迁移""" - return ( - 'http-api' not in self.ap.system_cfg.data - or 'persistence' not in self.ap.system_cfg.data - ) + return 'http-api' not in self.ap.system_cfg.data or 'persistence' not in self.ap.system_cfg.data async def run(self): """执行迁移""" diff --git a/pkg/core/migrations/m017_dify_api_timeout_params.py b/pkg/core/migrations/m017_dify_api_timeout_params.py index 7ce9133c..67635fb5 100644 --- a/pkg/core/migrations/m017_dify_api_timeout_params.py +++ b/pkg/core/migrations/m017_dify_api_timeout_params.py @@ -11,8 +11,7 @@ class DifyAPITimeoutParamsMigration(migration.Migration): """判断当前环境是否需要运行此迁移""" return ( 'timeout' not in self.ap.provider_cfg.data['dify-service-api']['chat'] - or 'timeout' - not in self.ap.provider_cfg.data['dify-service-api']['workflow'] + or 'timeout' not in self.ap.provider_cfg.data['dify-service-api']['workflow'] or 'agent' not in self.ap.provider_cfg.data['dify-service-api'] ) diff --git a/pkg/core/migrations/m023_siliconflow_config.py b/pkg/core/migrations/m023_siliconflow_config.py index fdf696eb..31b9ee8e 100644 --- a/pkg/core/migrations/m023_siliconflow_config.py +++ b/pkg/core/migrations/m023_siliconflow_config.py @@ -10,9 +10,7 @@ class SiliconFlowConfigMigration(migration.Migration): async def need_migrate(self) -> bool: """判断当前环境是否需要运行此迁移""" - return ( - 'siliconflow-chat-completions' not in self.ap.provider_cfg.data['requester'] - ) + return 'siliconflow-chat-completions' not in self.ap.provider_cfg.data['requester'] async def run(self): """执行迁移""" diff --git a/pkg/core/migrations/m033_dify_thinking_config.py b/pkg/core/migrations/m033_dify_thinking_config.py index d25a4aad..7269765a 100644 --- a/pkg/core/migrations/m033_dify_thinking_config.py +++ b/pkg/core/migrations/m033_dify_thinking_config.py @@ -13,17 +13,12 @@ class DifyThinkingConfigMigration(migration.Migration): if 'options' not in self.ap.provider_cfg.data['dify-service-api']: return True - if ( - 'convert-thinking-tips' - not in self.ap.provider_cfg.data['dify-service-api']['options'] - ): + if 'convert-thinking-tips' not in self.ap.provider_cfg.data['dify-service-api']['options']: return True return False async def run(self): """执行迁移""" - self.ap.provider_cfg.data['dify-service-api']['options'] = { - 'convert-thinking-tips': 'plain' - } + self.ap.provider_cfg.data['dify-service-api']['options'] = {'convert-thinking-tips': 'plain'} await self.ap.provider_cfg.dump_config() diff --git a/pkg/core/migrations/m034_gewechat_file_url_config.py b/pkg/core/migrations/m034_gewechat_file_url_config.py index 8c3e0a83..512b75b1 100644 --- a/pkg/core/migrations/m034_gewechat_file_url_config.py +++ b/pkg/core/migrations/m034_gewechat_file_url_config.py @@ -24,8 +24,6 @@ class GewechatFileUrlConfigMigration(migration.Migration): if adapter['adapter'] == 'gewechat': if 'gewechat_file_url' not in adapter: parsed_url = urlparse(adapter['gewechat_url']) - adapter['gewechat_file_url'] = ( - f'{parsed_url.scheme}://{parsed_url.hostname}:2532' - ) + adapter['gewechat_file_url'] = f'{parsed_url.scheme}://{parsed_url.hostname}:2532' await self.ap.platform_cfg.dump_config() diff --git a/pkg/core/migrations/m038_tg_dingtalk_markdown.py b/pkg/core/migrations/m038_tg_dingtalk_markdown.py index 1123c6b2..c0a85a44 100644 --- a/pkg/core/migrations/m038_tg_dingtalk_markdown.py +++ b/pkg/core/migrations/m038_tg_dingtalk_markdown.py @@ -3,24 +3,23 @@ from __future__ import annotations from .. import migration -@migration.migration_class("tg-dingtalk-markdown", 38) +@migration.migration_class('tg-dingtalk-markdown', 38) class TgDingtalkMarkdownMigration(migration.Migration): """迁移""" async def need_migrate(self) -> bool: """判断当前环境是否需要运行此迁移""" - + for adapter in self.ap.platform_cfg.data['platform-adapters']: - if adapter['adapter'] in ['dingtalk','telegram']: + if adapter['adapter'] in ['dingtalk', 'telegram']: if 'markdown_card' not in adapter: return True return False - + async def run(self): """执行迁移""" for adapter in self.ap.platform_cfg.data['platform-adapters']: - if adapter['adapter'] in ['dingtalk','telegram']: + if adapter['adapter'] in ['dingtalk', 'telegram']: if 'markdown_card' not in adapter: adapter['markdown_card'] = False await self.ap.platform_cfg.dump_config() - \ No newline at end of file diff --git a/pkg/core/migrations/m039_modelscope_cfg_completion.py b/pkg/core/migrations/m039_modelscope_cfg_completion.py index 8e574911..9eec0344 100644 --- a/pkg/core/migrations/m039_modelscope_cfg_completion.py +++ b/pkg/core/migrations/m039_modelscope_cfg_completion.py @@ -3,20 +3,19 @@ from __future__ import annotations from .. import migration -@migration.migration_class("modelscope-config-completion", 39) +@migration.migration_class('modelscope-config-completion', 39) class ModelScopeConfigCompletionMigration(migration.Migration): - """ModelScope配置迁移 - """ + """ModelScope配置迁移""" async def need_migrate(self) -> bool: - """判断当前环境是否需要运行此迁移 - """ - return 'modelscope-chat-completions' not in self.ap.provider_cfg.data['requester'] \ + """判断当前环境是否需要运行此迁移""" + return ( + 'modelscope-chat-completions' not in self.ap.provider_cfg.data['requester'] or 'modelscope' not in self.ap.provider_cfg.data['keys'] + ) async def run(self): - """执行迁移 - """ + """执行迁移""" if 'modelscope-chat-completions' not in self.ap.provider_cfg.data['requester']: self.ap.provider_cfg.data['requester']['modelscope-chat-completions'] = { 'base-url': 'https://api-inference.modelscope.cn/v1', diff --git a/pkg/core/migrations/m040_ppio_config.py b/pkg/core/migrations/m040_ppio_config.py index cd218d87..d4d82b98 100644 --- a/pkg/core/migrations/m040_ppio_config.py +++ b/pkg/core/migrations/m040_ppio_config.py @@ -3,20 +3,19 @@ from __future__ import annotations from .. import migration -@migration.migration_class("ppio-config", 40) +@migration.migration_class('ppio-config', 40) class PPIOConfigMigration(migration.Migration): - """PPIO配置迁移 - """ + """PPIO配置迁移""" async def need_migrate(self) -> bool: - """判断当前环境是否需要运行此迁移 - """ - return 'ppio-chat-completions' not in self.ap.provider_cfg.data['requester'] \ + """判断当前环境是否需要运行此迁移""" + return ( + 'ppio-chat-completions' not in self.ap.provider_cfg.data['requester'] or 'ppio' not in self.ap.provider_cfg.data['keys'] + ) async def run(self): - """执行迁移 - """ + """执行迁移""" if 'ppio-chat-completions' not in self.ap.provider_cfg.data['requester']: self.ap.provider_cfg.data['requester']['ppio-chat-completions'] = { 'base-url': 'https://api.ppinfra.com/v3/openai', diff --git a/pkg/core/taskmgr.py b/pkg/core/taskmgr.py index ae2394cf..0f756118 100644 --- a/pkg/core/taskmgr.py +++ b/pkg/core/taskmgr.py @@ -35,9 +35,7 @@ class TaskContext: if action is not None: self.set_current_action(action) - self._log( - f'{datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S")} | {self.current_action} | {msg}' - ) + self._log(f'{datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S")} | {self.current_action} | {msg}') def to_dict(self) -> dict: return {'current_action': self.current_action, 'log': self.log} @@ -104,9 +102,7 @@ class TaskWrapper: name: str = '', label: str = '', context: TaskContext = None, - scopes: list[core_entities.LifecycleControlScope] = [ - core_entities.LifecycleControlScope.APPLICATION - ], + scopes: list[core_entities.LifecycleControlScope] = [core_entities.LifecycleControlScope.APPLICATION], ): self.id = TaskWrapper._id_index TaskWrapper._id_index += 1 @@ -141,7 +137,9 @@ class TaskWrapper: exception_traceback = 'Traceback (most recent call last):\n' for frame in self.task_stack: - exception_traceback += f' File "{frame.f_code.co_filename}", line {frame.f_lineno}, in {frame.f_code.co_name}\n' + exception_traceback += ( + f' File "{frame.f_code.co_filename}", line {frame.f_lineno}, in {frame.f_code.co_name}\n' + ) exception_traceback += f' {self.assume_exception().__str__()}\n' @@ -156,13 +154,9 @@ class TaskWrapper: 'runtime': { 'done': self.task.done(), 'state': self.task._state, - 'exception': self.assume_exception().__str__() - if self.assume_exception() is not None - else None, + 'exception': self.assume_exception().__str__() if self.assume_exception() is not None else None, 'exception_traceback': exception_traceback, - 'result': self.assume_result().__str__() - if self.assume_result() is not None - else None, + 'result': self.assume_result().__str__() if self.assume_result() is not None else None, }, } @@ -191,13 +185,9 @@ class AsyncTaskManager: name: str = '', label: str = '', context: TaskContext = None, - scopes: list[core_entities.LifecycleControlScope] = [ - core_entities.LifecycleControlScope.APPLICATION - ], + scopes: list[core_entities.LifecycleControlScope] = [core_entities.LifecycleControlScope.APPLICATION], ) -> TaskWrapper: - wrapper = TaskWrapper( - self.ap, coro, task_type, kind, name, label, context, scopes - ) + wrapper = TaskWrapper(self.ap, coro, task_type, kind, name, label, context, scopes) self.tasks.append(wrapper) return wrapper @@ -208,9 +198,7 @@ class AsyncTaskManager: name: str = '', label: str = '', context: TaskContext = None, - scopes: list[core_entities.LifecycleControlScope] = [ - core_entities.LifecycleControlScope.APPLICATION - ], + scopes: list[core_entities.LifecycleControlScope] = [core_entities.LifecycleControlScope.APPLICATION], ) -> TaskWrapper: return self.create_task(coro, 'user', kind, name, label, context, scopes) @@ -225,9 +213,7 @@ class AsyncTaskManager: type: str = None, ) -> dict: return { - 'tasks': [ - t.to_dict() for t in self.tasks if type is None or t.task_type == type - ], + 'tasks': [t.to_dict() for t in self.tasks if type is None or t.task_type == type], 'id_index': TaskWrapper._id_index, } diff --git a/pkg/discover/engine.py b/pkg/discover/engine.py index be23a4ac..2224ba48 100644 --- a/pkg/discover/engine.py +++ b/pkg/discover/engine.py @@ -114,9 +114,7 @@ class Component(pydantic.BaseModel): _execution: Execution """组件执行""" - def __init__( - self, owner: str, manifest: typing.Dict[str, typing.Any], rel_path: str - ): + def __init__(self, owner: str, manifest: typing.Dict[str, typing.Any], rel_path: str): super().__init__( owner=owner, manifest=manifest, @@ -125,19 +123,12 @@ class Component(pydantic.BaseModel): ) self._metadata = Metadata(**manifest['metadata']) self._spec = manifest['spec'] - self._execution = ( - Execution(**manifest['execution']) if 'execution' in manifest else None - ) + self._execution = Execution(**manifest['execution']) if 'execution' in manifest else None @classmethod def is_component_manifest(cls, manifest: typing.Dict[str, typing.Any]) -> bool: """判断是否为组件清单""" - return ( - 'apiVersion' in manifest - and 'kind' in manifest - and 'metadata' in manifest - and 'spec' in manifest - ) + return 'apiVersion' in manifest and 'kind' in manifest and 'metadata' in manifest and 'spec' in manifest @property def kind(self) -> str: @@ -200,9 +191,7 @@ class ComponentDiscoveryEngine: def __init__(self, ap: app.Application): self.ap = ap - def load_component_manifest( - self, path: str, owner: str = 'builtin', no_save: bool = False - ) -> Component | None: + def load_component_manifest(self, path: str, owner: str = 'builtin', no_save: bool = False) -> Component | None: """加载组件清单""" with open(path, 'r', encoding='utf-8') as f: manifest = yaml.safe_load(f) @@ -229,18 +218,12 @@ class ComponentDiscoveryEngine: if depth > max_depth: return for file in os.listdir(path): - if (not os.path.isdir(os.path.join(path, file))) and ( - file.endswith('.yaml') or file.endswith('.yml') - ): - comp = self.load_component_manifest( - os.path.join(path, file), owner, no_save - ) + if (not os.path.isdir(os.path.join(path, file))) and (file.endswith('.yaml') or file.endswith('.yml')): + comp = self.load_component_manifest(os.path.join(path, file), owner, no_save) if comp is not None: components.append(comp) elif os.path.isdir(os.path.join(path, file)): - recursive_load_component_manifests_in_dir( - os.path.join(path, file), depth + 1 - ) + recursive_load_component_manifests_in_dir(os.path.join(path, file), depth + 1) recursive_load_component_manifests_in_dir(path) return components @@ -259,18 +242,12 @@ class ComponentDiscoveryEngine: for dir in group['fromDirs']: path = dir['path'] max_depth = dir['maxDepth'] if 'maxDepth' in dir else 1 - components.extend( - self.load_component_manifests_in_dir( - path, owner, no_save, max_depth - ) - ) + components.extend(self.load_component_manifests_in_dir(path, owner, no_save, max_depth)) return components def discover_blueprint(self, blueprint_manifest_path: str, owner: str = 'builtin'): """发现蓝图""" - blueprint_manifest = self.load_component_manifest( - blueprint_manifest_path, owner, no_save=True - ) + blueprint_manifest = self.load_component_manifest(blueprint_manifest_path, owner, no_save=True) if blueprint_manifest is None: raise ValueError(f'Invalid blueprint manifest: {blueprint_manifest_path}') assert blueprint_manifest.kind == 'Blueprint', '`Kind` must be `Blueprint`' @@ -297,9 +274,7 @@ class ComponentDiscoveryEngine: return [] return self.components[kind] - def find_components( - self, kind: str, component_list: typing.List[Component] - ) -> typing.List[Component]: + def find_components(self, kind: str, component_list: typing.List[Component]) -> typing.List[Component]: """查找组件""" result: typing.List[Component] = [] for component in component_list: diff --git a/pkg/entity/persistence/bot.py b/pkg/entity/persistence/bot.py index 86932cac..3c08f4ec 100644 --- a/pkg/entity/persistence/bot.py +++ b/pkg/entity/persistence/bot.py @@ -16,9 +16,7 @@ class Bot(Base): enable = sqlalchemy.Column(sqlalchemy.Boolean, nullable=False, default=False) use_pipeline_name = sqlalchemy.Column(sqlalchemy.String(255), nullable=True) use_pipeline_uuid = sqlalchemy.Column(sqlalchemy.String(255), nullable=True) - created_at = sqlalchemy.Column( - sqlalchemy.DateTime, nullable=False, server_default=sqlalchemy.func.now() - ) + created_at = sqlalchemy.Column(sqlalchemy.DateTime, nullable=False, server_default=sqlalchemy.func.now()) updated_at = sqlalchemy.Column( sqlalchemy.DateTime, nullable=False, diff --git a/pkg/entity/persistence/model.py b/pkg/entity/persistence/model.py index 65e016f3..9eb2ccef 100644 --- a/pkg/entity/persistence/model.py +++ b/pkg/entity/persistence/model.py @@ -16,9 +16,7 @@ class LLMModel(Base): api_keys = sqlalchemy.Column(sqlalchemy.JSON, nullable=False) abilities = sqlalchemy.Column(sqlalchemy.JSON, nullable=False, default=[]) extra_args = sqlalchemy.Column(sqlalchemy.JSON, nullable=False, default={}) - created_at = sqlalchemy.Column( - sqlalchemy.DateTime, nullable=False, server_default=sqlalchemy.func.now() - ) + created_at = sqlalchemy.Column(sqlalchemy.DateTime, nullable=False, server_default=sqlalchemy.func.now()) updated_at = sqlalchemy.Column( sqlalchemy.DateTime, nullable=False, diff --git a/pkg/entity/persistence/pipeline.py b/pkg/entity/persistence/pipeline.py index ca854203..56e2cae9 100644 --- a/pkg/entity/persistence/pipeline.py +++ b/pkg/entity/persistence/pipeline.py @@ -11,9 +11,7 @@ class LegacyPipeline(Base): uuid = sqlalchemy.Column(sqlalchemy.String(255), primary_key=True, unique=True) name = sqlalchemy.Column(sqlalchemy.String(255), nullable=False) description = sqlalchemy.Column(sqlalchemy.String(255), nullable=False) - created_at = sqlalchemy.Column( - sqlalchemy.DateTime, nullable=False, server_default=sqlalchemy.func.now() - ) + created_at = sqlalchemy.Column(sqlalchemy.DateTime, nullable=False, server_default=sqlalchemy.func.now()) updated_at = sqlalchemy.Column( sqlalchemy.DateTime, nullable=False, @@ -35,9 +33,7 @@ class PipelineRunRecord(Base): uuid = sqlalchemy.Column(sqlalchemy.String(255), primary_key=True, unique=True) pipeline_uuid = sqlalchemy.Column(sqlalchemy.String(255), nullable=False) status = sqlalchemy.Column(sqlalchemy.String(255), nullable=False) - created_at = sqlalchemy.Column( - sqlalchemy.DateTime, nullable=False, server_default=sqlalchemy.func.now() - ) + created_at = sqlalchemy.Column(sqlalchemy.DateTime, nullable=False, server_default=sqlalchemy.func.now()) updated_at = sqlalchemy.Column( sqlalchemy.DateTime, nullable=False, diff --git a/pkg/entity/persistence/plugin.py b/pkg/entity/persistence/plugin.py index 94d6b8b4..30db6bd6 100644 --- a/pkg/entity/persistence/plugin.py +++ b/pkg/entity/persistence/plugin.py @@ -13,9 +13,7 @@ class PluginSetting(Base): enabled = sqlalchemy.Column(sqlalchemy.Boolean, nullable=False, default=True) priority = sqlalchemy.Column(sqlalchemy.Integer, nullable=False, default=0) config = sqlalchemy.Column(sqlalchemy.JSON, nullable=False, default=dict) - created_at = sqlalchemy.Column( - sqlalchemy.DateTime, nullable=False, server_default=sqlalchemy.func.now() - ) + created_at = sqlalchemy.Column(sqlalchemy.DateTime, nullable=False, server_default=sqlalchemy.func.now()) updated_at = sqlalchemy.Column( sqlalchemy.DateTime, nullable=False, diff --git a/pkg/entity/persistence/user.py b/pkg/entity/persistence/user.py index a0d9f168..04a5b374 100644 --- a/pkg/entity/persistence/user.py +++ b/pkg/entity/persistence/user.py @@ -9,9 +9,7 @@ class User(Base): id = sqlalchemy.Column(sqlalchemy.Integer, primary_key=True) user = sqlalchemy.Column(sqlalchemy.String(255), nullable=False) password = sqlalchemy.Column(sqlalchemy.String(255), nullable=False) - created_at = sqlalchemy.Column( - sqlalchemy.DateTime, nullable=False, server_default=sqlalchemy.func.now() - ) + created_at = sqlalchemy.Column(sqlalchemy.DateTime, nullable=False, server_default=sqlalchemy.func.now()) updated_at = sqlalchemy.Column( sqlalchemy.DateTime, nullable=False, diff --git a/pkg/persistence/databases/sqlite.py b/pkg/persistence/databases/sqlite.py index 1b12def8..7b095e61 100644 --- a/pkg/persistence/databases/sqlite.py +++ b/pkg/persistence/databases/sqlite.py @@ -11,6 +11,4 @@ class SQLiteDatabaseManager(database.BaseDatabaseManager): async def initialize(self) -> None: sqlite_path = 'data/langbot.db' - self.engine = sqlalchemy_asyncio.create_async_engine( - f'sqlite+aiosqlite:///{sqlite_path}' - ) + self.engine = sqlalchemy_asyncio.create_async_engine(f'sqlite+aiosqlite:///{sqlite_path}') diff --git a/pkg/persistence/mgr.py b/pkg/persistence/mgr.py index 3a66762a..69382e6d 100644 --- a/pkg/persistence/mgr.py +++ b/pkg/persistence/mgr.py @@ -58,24 +58,18 @@ class PersistenceManager: for item in metadata.initial_metadata: # check if the item exists result = await self.execute_async( - sqlalchemy.select(metadata.Metadata).where( - metadata.Metadata.key == item['key'] - ) + sqlalchemy.select(metadata.Metadata).where(metadata.Metadata.key == item['key']) ) row = result.first() if row is None: - await self.execute_async( - sqlalchemy.insert(metadata.Metadata).values(item) - ) + await self.execute_async(sqlalchemy.insert(metadata.Metadata).values(item)) # write default pipeline result = await self.execute_async(sqlalchemy.select(pipeline.LegacyPipeline)) if result.first() is None: self.ap.logger.info('Creating default pipeline...') - pipeline_config = json.load( - open('templates/default-pipeline-config.json', 'r', encoding='utf-8') - ) + pipeline_config = json.load(open('templates/default-pipeline-config.json', 'r', encoding='utf-8')) pipeline_data = { 'uuid': str(uuid.uuid4()), @@ -87,16 +81,12 @@ class PersistenceManager: 'config': pipeline_config, } - await self.execute_async( - sqlalchemy.insert(pipeline.LegacyPipeline).values(pipeline_data) - ) + await self.execute_async(sqlalchemy.insert(pipeline.LegacyPipeline).values(pipeline_data)) # ================================= # run migrations database_version = await self.execute_async( - sqlalchemy.select(metadata.Metadata).where( - metadata.Metadata.key == 'database_version' - ) + sqlalchemy.select(metadata.Metadata).where(metadata.Metadata.key == 'database_version') ) database_version = int(database_version.fetchone()[1]) @@ -122,17 +112,11 @@ class PersistenceManager: .values({'value': str(migration_instance.number)}) ) last_migration_number = migration_instance.number - self.ap.logger.info( - f'Migration {migration_instance.number} completed.' - ) + self.ap.logger.info(f'Migration {migration_instance.number} completed.') - self.ap.logger.info( - f'Successfully upgraded database to version {last_migration_number}.' - ) + self.ap.logger.info(f'Successfully upgraded database to version {last_migration_number}.') - async def execute_async( - self, *args, **kwargs - ) -> sqlalchemy.engine.cursor.CursorResult: + async def execute_async(self, *args, **kwargs) -> sqlalchemy.engine.cursor.CursorResult: async with self.get_db_engine().connect() as conn: result = await conn.execute(*args, **kwargs) await conn.commit() @@ -141,9 +125,7 @@ class PersistenceManager: def get_db_engine(self) -> sqlalchemy_asyncio.AsyncEngine: return self.db.get_engine() - def serialize_model( - self, model: typing.Type[sqlalchemy.Base], data: sqlalchemy.Base - ) -> dict: + def serialize_model(self, model: typing.Type[sqlalchemy.Base], data: sqlalchemy.Base) -> dict: return { column.name: getattr(data, column.name) if not isinstance(getattr(data, column.name), (datetime.datetime)) diff --git a/pkg/pipeline/bansess/bansess.py b/pkg/pipeline/bansess/bansess.py index dad6a3ab..3b927a55 100644 --- a/pkg/pipeline/bansess/bansess.py +++ b/pkg/pipeline/bansess/bansess.py @@ -14,9 +14,7 @@ class BanSessionCheckStage(stage.PipelineStage): async def initialize(self, pipeline_config: dict): pass - async def process( - self, query: core_entities.Query, stage_inst_name: str - ) -> entities.StageProcessResult: + async def process(self, query: core_entities.Query, stage_inst_name: str) -> entities.StageProcessResult: found = False mode = query.pipeline_config['trigger']['access-control']['mode'] @@ -41,11 +39,7 @@ class BanSessionCheckStage(stage.PipelineStage): ctn = not found return entities.StageProcessResult( - result_type=entities.ResultType.CONTINUE - if ctn - else entities.ResultType.INTERRUPT, + result_type=entities.ResultType.CONTINUE if ctn else entities.ResultType.INTERRUPT, new_query=query, - console_notice=f'根据访问控制忽略消息: {query.launcher_type.value}_{query.launcher_id}' - if not ctn - else '', + console_notice=f'根据访问控制忽略消息: {query.launcher_type.value}_{query.launcher_id}' if not ctn else '', ) diff --git a/pkg/pipeline/cntfilter/cntfilter.py b/pkg/pipeline/cntfilter/cntfilter.py index 6547cb16..879b1295 100644 --- a/pkg/pipeline/cntfilter/cntfilter.py +++ b/pkg/pipeline/cntfilter/cntfilter.py @@ -65,9 +65,7 @@ class ContentFilterStage(stage.PipelineStage): """ if query.pipeline_config['safety']['content-filter']['scope'] == 'output-msg': - return entities.StageProcessResult( - result_type=entities.ResultType.CONTINUE, new_query=query - ) + return entities.StageProcessResult(result_type=entities.ResultType.CONTINUE, new_query=query) else: for filter in self.filter_chain: if filter_entities.EnableStage.PRE in filter.enable_stages: @@ -86,13 +84,9 @@ class ContentFilterStage(stage.PipelineStage): elif result.level == filter_entities.ResultLevel.PASS: # 传到下一个 message = result.replacement - query.message_chain = platform_message.MessageChain( - platform_message.Plain(message) - ) + query.message_chain = platform_message.MessageChain(platform_message.Plain(message)) - return entities.StageProcessResult( - result_type=entities.ResultType.CONTINUE, new_query=query - ) + return entities.StageProcessResult(result_type=entities.ResultType.CONTINUE, new_query=query) async def _post_process( self, @@ -103,9 +97,7 @@ class ContentFilterStage(stage.PipelineStage): 只要是 PASS 或者 MASKED 的就通过此 filter,将其 replacement 设置为message,进入下一个 filter """ if query.pipeline_config['safety']['content-filter']['scope'] == 'income-msg': - return entities.StageProcessResult( - result_type=entities.ResultType.CONTINUE, new_query=query - ) + return entities.StageProcessResult(result_type=entities.ResultType.CONTINUE, new_query=query) else: message = message.strip() for filter in self.filter_chain: @@ -127,13 +119,9 @@ class ContentFilterStage(stage.PipelineStage): query.resp_messages[-1].content = message - return entities.StageProcessResult( - result_type=entities.ResultType.CONTINUE, new_query=query - ) + return entities.StageProcessResult(result_type=entities.ResultType.CONTINUE, new_query=query) - async def process( - self, query: core_entities.Query, stage_inst_name: str - ) -> entities.StageProcessResult: + async def process(self, query: core_entities.Query, stage_inst_name: str) -> entities.StageProcessResult: """处理""" if stage_inst_name == 'PreContentFilterStage': contain_non_text = False @@ -147,9 +135,7 @@ class ContentFilterStage(stage.PipelineStage): if contain_non_text: self.ap.logger.debug('消息中包含非文本消息,跳过内容过滤器检查。') - return entities.StageProcessResult( - result_type=entities.ResultType.CONTINUE, new_query=query - ) + return entities.StageProcessResult(result_type=entities.ResultType.CONTINUE, new_query=query) return await self._pre_process(str(query.message_chain).strip(), query) elif stage_inst_name == 'PostContentFilterStage': @@ -162,8 +148,6 @@ class ContentFilterStage(stage.PipelineStage): self.ap.logger.debug( 'resp_messages[-1] 不是 Message 类型或 query.resp_messages[-1].content 不是 str 类型,跳过内容过滤器检查。' ) - return entities.StageProcessResult( - result_type=entities.ResultType.CONTINUE, new_query=query - ) + return entities.StageProcessResult(result_type=entities.ResultType.CONTINUE, new_query=query) else: raise ValueError(f'未知的 stage_inst_name: {stage_inst_name}') diff --git a/pkg/pipeline/cntfilter/filter.py b/pkg/pipeline/cntfilter/filter.py index ae7ceb79..0a3ceaae 100644 --- a/pkg/pipeline/cntfilter/filter.py +++ b/pkg/pipeline/cntfilter/filter.py @@ -60,9 +60,7 @@ class ContentFilter(metaclass=abc.ABCMeta): pass @abc.abstractmethod - async def process( - self, query: core_entities.Query, message: str = None, image_url=None - ) -> entities.FilterResult: + async def process(self, query: core_entities.Query, message: str = None, image_url=None) -> entities.FilterResult: """处理消息 分为前后阶段,具体取决于 enable_stages 的值。 diff --git a/pkg/pipeline/cntfilter/filters/baiduexamine.py b/pkg/pipeline/cntfilter/filters/baiduexamine.py index c3776bc9..9637aec2 100644 --- a/pkg/pipeline/cntfilter/filters/baiduexamine.py +++ b/pkg/pipeline/cntfilter/filters/baiduexamine.py @@ -21,19 +21,13 @@ class BaiduCloudExamine(filter_model.ContentFilter): BAIDU_EXAMINE_TOKEN_URL, params={ 'grant_type': 'client_credentials', - 'client_id': self.ap.pipeline_cfg.data['baidu-cloud-examine'][ - 'api-key' - ], - 'client_secret': self.ap.pipeline_cfg.data['baidu-cloud-examine'][ - 'api-secret' - ], + 'client_id': self.ap.pipeline_cfg.data['baidu-cloud-examine']['api-key'], + 'client_secret': self.ap.pipeline_cfg.data['baidu-cloud-examine']['api-secret'], }, ) as resp: return (await resp.json())['access_token'] - async def process( - self, query: core_entities.Query, message: str - ) -> entities.FilterResult: + async def process(self, query: core_entities.Query, message: str) -> entities.FilterResult: async with aiohttp.ClientSession() as session: async with session.post( BAIDU_EXAMINE_URL.format(await self._get_token()), diff --git a/pkg/pipeline/cntfilter/filters/banwords.py b/pkg/pipeline/cntfilter/filters/banwords.py index 598fa299..916a1bc1 100644 --- a/pkg/pipeline/cntfilter/filters/banwords.py +++ b/pkg/pipeline/cntfilter/filters/banwords.py @@ -13,9 +13,7 @@ class BanWordFilter(filter_model.ContentFilter): async def initialize(self): pass - async def process( - self, query: core_entities.Query, message: str - ) -> entities.FilterResult: + async def process(self, query: core_entities.Query, message: str) -> entities.FilterResult: found = False for word in self.ap.sensitive_meta.data['words']: @@ -31,9 +29,7 @@ class BanWordFilter(filter_model.ContentFilter): self.ap.sensitive_meta.data['mask'] * len(match[i]), ) else: - message = message.replace( - match[i], self.ap.sensitive_meta.data['mask_word'] - ) + message = message.replace(match[i], self.ap.sensitive_meta.data['mask_word']) return entities.FilterResult( level=entities.ResultLevel.MASKED if found else entities.ResultLevel.PASS, diff --git a/pkg/pipeline/cntfilter/filters/cntignore.py b/pkg/pipeline/cntfilter/filters/cntignore.py index cb563593..5e410e31 100644 --- a/pkg/pipeline/cntfilter/filters/cntignore.py +++ b/pkg/pipeline/cntfilter/filters/cntignore.py @@ -16,9 +16,7 @@ class ContentIgnore(filter_model.ContentFilter): entities.EnableStage.PRE, ] - async def process( - self, query: core_entities.Query, message: str - ) -> entities.FilterResult: + async def process(self, query: core_entities.Query, message: str) -> entities.FilterResult: if 'prefix' in query.pipeline_config['trigger']['ignore-rules']: for rule in query.pipeline_config['trigger']['ignore-rules']['prefix']: if message.startswith(rule): diff --git a/pkg/pipeline/controller.py b/pkg/pipeline/controller.py index 2ad1690f..052187a2 100644 --- a/pkg/pipeline/controller.py +++ b/pkg/pipeline/controller.py @@ -16,9 +16,7 @@ class Controller: def __init__(self, ap: app.Application): self.ap = ap - self.semaphore = asyncio.Semaphore( - self.ap.instance_config.data['concurrency']['pipeline'] - ) + self.semaphore = asyncio.Semaphore(self.ap.instance_config.data['concurrency']['pipeline']) async def consumer(self): """事件处理循环""" @@ -32,9 +30,7 @@ class Controller: for query in queries: session = await self.ap.sess_mgr.get_session(query) - self.ap.logger.debug( - f'Checking query {query} session {session}' - ) + self.ap.logger.debug(f'Checking query {query} session {session}') if not session.semaphore.locked(): selected_query = query @@ -55,22 +51,16 @@ class Controller: # find pipeline # Here firstly find the bot, then find the pipeline, in case the bot adapter's config is not the latest one. # Like aiocqhttp, once a client is connected, even the adapter was updated and restarted, the existing client connection will not be affected. - bot = await self.ap.platform_mgr.get_bot_by_uuid( - selected_query.bot_uuid - ) + bot = await self.ap.platform_mgr.get_bot_by_uuid(selected_query.bot_uuid) if bot: - pipeline = ( - await self.ap.pipeline_mgr.get_pipeline_by_uuid( - bot.bot_entity.use_pipeline_uuid - ) + pipeline = await self.ap.pipeline_mgr.get_pipeline_by_uuid( + bot.bot_entity.use_pipeline_uuid ) if pipeline: await pipeline.run(selected_query) async with self.ap.query_pool: - ( - await self.ap.sess_mgr.get_session(selected_query) - ).semaphore.release() + (await self.ap.sess_mgr.get_session(selected_query)).semaphore.release() # 通知其他协程,有新的请求可以处理了 self.ap.query_pool.condition.notify_all() diff --git a/pkg/pipeline/longtext/longtext.py b/pkg/pipeline/longtext/longtext.py index ab20f3eb..5be20650 100644 --- a/pkg/pipeline/longtext/longtext.py +++ b/pkg/pipeline/longtext/longtext.py @@ -47,9 +47,7 @@ class LongTextProcessStage(stage.PipelineStage): '未找到字体文件,且无法使用系统自带字体,更换为转发消息组件以发送长消息,您可以在配置文件中调整相关设置。' ) - pipeline_config['output']['long-text-processing'][ - 'strategy' - ] = 'forward' + pipeline_config['output']['long-text-processing']['strategy'] = 'forward' except Exception: traceback.print_exc() self.ap.logger.error( @@ -58,9 +56,7 @@ class LongTextProcessStage(stage.PipelineStage): ) ) - pipeline_config['output']['long-text-processing']['strategy'] = ( - 'forward' - ) + pipeline_config['output']['long-text-processing']['strategy'] = 'forward' for strategy_cls in strategy.preregistered_strategies: if strategy_cls.name == config['strategy']: @@ -71,9 +67,7 @@ class LongTextProcessStage(stage.PipelineStage): await self.strategy_impl.initialize() - async def process( - self, query: core_entities.Query, stage_inst_name: str - ) -> entities.StageProcessResult: + async def process(self, query: core_entities.Query, stage_inst_name: str) -> entities.StageProcessResult: # 检查是否包含非 Plain 组件 contains_non_plain = False @@ -89,11 +83,7 @@ class LongTextProcessStage(stage.PipelineStage): > query.pipeline_config['output']['long-text-processing']['threshold'] ): query.resp_message_chain[-1] = platform_message.MessageChain( - await self.strategy_impl.process( - str(query.resp_message_chain[-1]), query - ) + await self.strategy_impl.process(str(query.resp_message_chain[-1]), query) ) - return entities.StageProcessResult( - result_type=entities.ResultType.CONTINUE, new_query=query - ) + return entities.StageProcessResult(result_type=entities.ResultType.CONTINUE, new_query=query) diff --git a/pkg/pipeline/longtext/strategies/forward.py b/pkg/pipeline/longtext/strategies/forward.py index 57084d76..6228d580 100644 --- a/pkg/pipeline/longtext/strategies/forward.py +++ b/pkg/pipeline/longtext/strategies/forward.py @@ -13,9 +13,7 @@ Forward = platform_message.Forward @strategy_model.strategy_class('forward') class ForwardComponentStrategy(strategy_model.LongTextStrategy): - async def process( - self, message: str, query: core_entities.Query - ) -> list[platform_message.MessageComponent]: + async def process(self, message: str, query: core_entities.Query) -> list[platform_message.MessageComponent]: display = ForwardMessageDiaplay( title='群聊的聊天记录', brief='[聊天记录]', diff --git a/pkg/pipeline/longtext/strategies/image.py b/pkg/pipeline/longtext/strategies/image.py index 26c4b731..3716e7c2 100644 --- a/pkg/pipeline/longtext/strategies/image.py +++ b/pkg/pipeline/longtext/strategies/image.py @@ -27,18 +27,14 @@ class Text2ImageStrategy(strategy_model.LongTextStrategy): encoding='utf-8', ) - async def process( - self, message: str, query: core_entities.Query - ) -> list[platform_message.MessageComponent]: + async def process(self, message: str, query: core_entities.Query) -> list[platform_message.MessageComponent]: img_path = self.text_to_image( text_str=message, save_as='temp/{}.png'.format(int(time.time())), query=query, ) - compressed_path, size = self.compress_image( - img_path, outfile='temp/{}_compressed.png'.format(int(time.time())) - ) + compressed_path, size = self.compress_image(img_path, outfile='temp/{}_compressed.png'.format(int(time.time()))) with open(compressed_path, 'rb') as f: img = f.read() @@ -165,10 +161,7 @@ class Text2ImageStrategy(strategy_model.LongTextStrategy): numbers = self.indexNumber(rest_text) for number in numbers: - if ( - number[1] < point < number[1] + len(number[0]) - and number[1] != 0 - ): + if number[1] < point < number[1] + len(number[0]) and number[1] != 0: point = number[1] break @@ -181,9 +174,7 @@ class Text2ImageStrategy(strategy_model.LongTextStrategy): else: continue # 准备画布 - img = Image.new( - 'RGBA', (width, max(280, len(final_lines) * 35 + 65)), (255, 255, 255, 255) - ) + img = Image.new('RGBA', (width, max(280, len(final_lines) * 35 + 65)), (255, 255, 255, 255)) draw = ImageDraw.Draw(img, mode='RGBA') self.ap.logger.debug('正在绘制图片...') diff --git a/pkg/pipeline/longtext/strategy.py b/pkg/pipeline/longtext/strategy.py index 4e141045..0ddec0c6 100644 --- a/pkg/pipeline/longtext/strategy.py +++ b/pkg/pipeline/longtext/strategy.py @@ -49,9 +49,7 @@ class LongTextStrategy(metaclass=abc.ABCMeta): pass @abc.abstractmethod - async def process( - self, message: str, query: core_entities.Query - ) -> list[platform_message.MessageComponent]: + async def process(self, message: str, query: core_entities.Query) -> list[platform_message.MessageComponent]: """处理长文本 在 platform.json 中配置 long-text-process 字段,只要 文本长度超过了 threshold 就会调用此方法 diff --git a/pkg/pipeline/msgtrun/msgtrun.py b/pkg/pipeline/msgtrun/msgtrun.py index 2595e289..c64f67fc 100644 --- a/pkg/pipeline/msgtrun/msgtrun.py +++ b/pkg/pipeline/msgtrun/msgtrun.py @@ -29,12 +29,8 @@ class ConversationMessageTruncator(stage.PipelineStage): else: raise ValueError(f'未知的截断器: {use_method}') - async def process( - self, query: core_entities.Query, stage_inst_name: str - ) -> entities.StageProcessResult: + async def process(self, query: core_entities.Query, stage_inst_name: str) -> entities.StageProcessResult: """处理""" query = await self.trun.truncate(query) - return entities.StageProcessResult( - result_type=entities.ResultType.CONTINUE, new_query=query - ) + return entities.StageProcessResult(result_type=entities.ResultType.CONTINUE, new_query=query) diff --git a/pkg/pipeline/pipelinemgr.py b/pkg/pipeline/pipelinemgr.py index 8ca0d592..ee0dd191 100644 --- a/pkg/pipeline/pipelinemgr.py +++ b/pkg/pipeline/pipelinemgr.py @@ -79,26 +79,20 @@ class RuntimePipeline: query.pipeline_config = self.pipeline_entity.config await self.process_query(query) - async def _check_output( - self, query: entities.Query, result: pipeline_entities.StageProcessResult - ): + async def _check_output(self, query: entities.Query, result: pipeline_entities.StageProcessResult): """检查输出""" if result.user_notice: # 处理str类型 if isinstance(result.user_notice, str): - result.user_notice = platform_message.MessageChain( - platform_message.Plain(result.user_notice) - ) + result.user_notice = platform_message.MessageChain(platform_message.Plain(result.user_notice)) elif isinstance(result.user_notice, list): result.user_notice = platform_message.MessageChain(*result.user_notice) if query.pipeline_config['output']['misc']['at-sender'] and isinstance( query.message_event, platform_events.GroupMessage ): - result.user_notice.insert( - 0, platform_message.At(query.message_event.sender.id) - ) + result.user_notice.insert(0, platform_message.At(query.message_event.sender.id)) await query.adapter.reply_message( message_source=query.message_event, @@ -150,37 +144,25 @@ class RuntimePipeline: result = await result if isinstance(result, pipeline_entities.StageProcessResult): # 直接返回结果 - self.ap.logger.debug( - f'Stage {stage_container.inst_name} processed query {query} res {result}' - ) + self.ap.logger.debug(f'Stage {stage_container.inst_name} processed query {query} res {result}') await self._check_output(query, result) if result.result_type == pipeline_entities.ResultType.INTERRUPT: - self.ap.logger.debug( - f'Stage {stage_container.inst_name} interrupted query {query}' - ) + self.ap.logger.debug(f'Stage {stage_container.inst_name} interrupted query {query}') break elif result.result_type == pipeline_entities.ResultType.CONTINUE: query = result.new_query elif isinstance(result, typing.AsyncGenerator): # 生成器 - self.ap.logger.debug( - f'Stage {stage_container.inst_name} processed query {query} gen' - ) + self.ap.logger.debug(f'Stage {stage_container.inst_name} processed query {query} gen') async for sub_result in result: - self.ap.logger.debug( - f'Stage {stage_container.inst_name} processed query {query} res {sub_result}' - ) + self.ap.logger.debug(f'Stage {stage_container.inst_name} processed query {query} res {sub_result}') await self._check_output(query, sub_result) if sub_result.result_type == pipeline_entities.ResultType.INTERRUPT: - self.ap.logger.debug( - f'Stage {stage_container.inst_name} interrupted query {query}' - ) + self.ap.logger.debug(f'Stage {stage_container.inst_name} interrupted query {query}') break - elif ( - sub_result.result_type == pipeline_entities.ResultType.CONTINUE - ): + elif sub_result.result_type == pipeline_entities.ResultType.CONTINUE: query = sub_result.new_query await self._execute_from_stage(i + 1, query) break @@ -214,12 +196,8 @@ class RuntimePipeline: await self._execute_from_stage(0, query) except Exception as e: - inst_name = ( - query.current_stage.inst_name if query.current_stage else 'unknown' - ) - self.ap.logger.error( - f'处理请求时出错 query_id={query.query_id} stage={inst_name} : {e}' - ) + inst_name = query.current_stage.inst_name if query.current_stage else 'unknown' + self.ap.logger.error(f'处理请求时出错 query_id={query.query_id} stage={inst_name} : {e}') self.ap.logger.debug(f'Traceback: {traceback.format_exc()}') finally: self.ap.logger.debug(f'Query {query} processed') @@ -241,18 +219,14 @@ class PipelineManager: self.pipelines = [] async def initialize(self): - self.stage_dict = { - name: cls for name, cls in stage.preregistered_stages.items() - } + self.stage_dict = {name: cls for name, cls in stage.preregistered_stages.items()} await self.load_pipelines_from_db() async def load_pipelines_from_db(self): self.ap.logger.info('Loading pipelines from db...') - result = await self.ap.persistence_mgr.execute_async( - sqlalchemy.select(persistence_pipeline.LegacyPipeline) - ) + result = await self.ap.persistence_mgr.execute_async(sqlalchemy.select(persistence_pipeline.LegacyPipeline)) pipelines = result.all() @@ -267,20 +241,14 @@ class PipelineManager: | dict, ): if isinstance(pipeline_entity, sqlalchemy.Row): - pipeline_entity = persistence_pipeline.LegacyPipeline( - **pipeline_entity._mapping - ) + pipeline_entity = persistence_pipeline.LegacyPipeline(**pipeline_entity._mapping) elif isinstance(pipeline_entity, dict): pipeline_entity = persistence_pipeline.LegacyPipeline(**pipeline_entity) # initialize stage containers according to pipeline_entity.stages stage_containers: list[StageInstContainer] = [] for stage_name in pipeline_entity.stages: - stage_containers.append( - StageInstContainer( - inst_name=stage_name, inst=self.stage_dict[stage_name](self.ap) - ) - ) + stage_containers.append(StageInstContainer(inst_name=stage_name, inst=self.stage_dict[stage_name](self.ap))) for stage_container in stage_containers: await stage_container.inst.initialize(pipeline_entity.config) diff --git a/pkg/pipeline/preproc/preproc.py b/pkg/pipeline/preproc/preproc.py index bab1127d..29371adc 100644 --- a/pkg/pipeline/preproc/preproc.py +++ b/pkg/pipeline/preproc/preproc.py @@ -44,9 +44,7 @@ class PreProcessor(stage.PipelineStage): query.use_llm_model = conversation.use_llm_model query.use_funcs = ( - conversation.use_funcs - if query.use_llm_model.model_entity.abilities.__contains__('tool_call') - else None + conversation.use_funcs if query.use_llm_model.model_entity.abilities.__contains__('tool_call') else None ) query.variables = { @@ -59,10 +57,9 @@ class PreProcessor(stage.PipelineStage): # Check if this model supports vision, if not, remove all images # TODO this checking should be performed in runner, and in this stage, the image should be reserved - if ( - query.pipeline_config['ai']['runner']['runner'] == 'local-agent' - and not query.use_llm_model.model_entity.abilities.__contains__('vision') - ): + if query.pipeline_config['ai']['runner'][ + 'runner' + ] == 'local-agent' and not query.use_llm_model.model_entity.abilities.__contains__('vision'): for msg in query.messages: if isinstance(msg.content, list): for me in msg.content: @@ -78,14 +75,11 @@ class PreProcessor(stage.PipelineStage): content_list.append(llm_entities.ContentElement.from_text(me.text)) plain_text += me.text elif isinstance(me, platform_message.Image): - if ( - query.pipeline_config['ai']['runner']['runner'] != 'local-agent' - or query.use_llm_model.model_entity.abilities.__contains__('vision') - ): + if query.pipeline_config['ai']['runner'][ + 'runner' + ] != 'local-agent' or query.use_llm_model.model_entity.abilities.__contains__('vision'): if me.base64 is not None: - content_list.append( - llm_entities.ContentElement.from_image_base64(me.base64) - ) + content_list.append(llm_entities.ContentElement.from_image_base64(me.base64)) query.variables['user_message_text'] = plain_text @@ -104,6 +98,4 @@ class PreProcessor(stage.PipelineStage): query.prompt.messages = event_ctx.event.default_prompt query.messages = event_ctx.event.prompt - return entities.StageProcessResult( - result_type=entities.ResultType.CONTINUE, new_query=query - ) + return entities.StageProcessResult(result_type=entities.ResultType.CONTINUE, new_query=query) diff --git a/pkg/pipeline/process/handlers/chat.py b/pkg/pipeline/process/handlers/chat.py index 3ad5c43e..35fa1611 100644 --- a/pkg/pipeline/process/handlers/chat.py +++ b/pkg/pipeline/process/handlers/chat.py @@ -49,13 +49,9 @@ class ChatMessageHandler(handler.MessageHandler): query.resp_messages.append(mc) - yield entities.StageProcessResult( - result_type=entities.ResultType.CONTINUE, new_query=query - ) + yield entities.StageProcessResult(result_type=entities.ResultType.CONTINUE, new_query=query) else: - yield entities.StageProcessResult( - result_type=entities.ResultType.INTERRUPT, new_query=query - ) + yield entities.StageProcessResult(result_type=entities.ResultType.INTERRUPT, new_query=query) else: if event_ctx.event.alter is not None: # if isinstance(event_ctx.event, str): # 现在暂时不考虑多模态alter @@ -69,34 +65,24 @@ class ChatMessageHandler(handler.MessageHandler): runner = r(self.ap, query.pipeline_config) break else: - raise ValueError( - f'未找到请求运行器: {query.pipeline_config["ai"]["runner"]["runner"]}' - ) + raise ValueError(f'未找到请求运行器: {query.pipeline_config["ai"]["runner"]["runner"]}') async for result in runner.run(query): query.resp_messages.append(result) - self.ap.logger.info( - f'对话({query.query_id})响应: {self.cut_str(result.readable_str())}' - ) + self.ap.logger.info(f'对话({query.query_id})响应: {self.cut_str(result.readable_str())}') if result.content is not None: text_length += len(result.content) - yield entities.StageProcessResult( - result_type=entities.ResultType.CONTINUE, new_query=query - ) + yield entities.StageProcessResult(result_type=entities.ResultType.CONTINUE, new_query=query) query.session.using_conversation.messages.append(query.user_message) query.session.using_conversation.messages.extend(query.resp_messages) except Exception as e: - self.ap.logger.error( - f'对话({query.query_id})请求失败: {type(e).__name__} {str(e)}' - ) + self.ap.logger.error(f'对话({query.query_id})请求失败: {type(e).__name__} {str(e)}') - hide_exception_info = query.pipeline_config['output']['misc'][ - 'hide-exception' - ] + hide_exception_info = query.pipeline_config['output']['misc']['hide-exception'] yield entities.StageProcessResult( result_type=entities.ResultType.INTERRUPT, diff --git a/pkg/pipeline/process/handlers/command.py b/pkg/pipeline/process/handlers/command.py index af1357b5..cc0e9314 100644 --- a/pkg/pipeline/process/handlers/command.py +++ b/pkg/pipeline/process/handlers/command.py @@ -21,10 +21,7 @@ class CommandHandler(handler.MessageHandler): privilege = 1 - if ( - f'{query.launcher_type.value}_{query.launcher_id}' - in self.ap.instance_config.data['admins'] - ): + if f'{query.launcher_type.value}_{query.launcher_id}' in self.ap.instance_config.data['admins']: privilege = 2 spt = command_text.split(' ') @@ -54,25 +51,17 @@ class CommandHandler(handler.MessageHandler): query.resp_messages.append(mc) - yield entities.StageProcessResult( - result_type=entities.ResultType.CONTINUE, new_query=query - ) + yield entities.StageProcessResult(result_type=entities.ResultType.CONTINUE, new_query=query) else: - yield entities.StageProcessResult( - result_type=entities.ResultType.INTERRUPT, new_query=query - ) + yield entities.StageProcessResult(result_type=entities.ResultType.INTERRUPT, new_query=query) else: if event_ctx.event.alter is not None: - query.message_chain = platform_message.MessageChain( - [platform_message.Plain(event_ctx.event.alter)] - ) + query.message_chain = platform_message.MessageChain([platform_message.Plain(event_ctx.event.alter)]) session = await self.ap.sess_mgr.get_session(query) - async for ret in self.ap.cmd_mgr.execute( - command_text=command_text, query=query, session=session - ): + async for ret in self.ap.cmd_mgr.execute(command_text=command_text, query=query, session=session): if ret.error is not None: query.resp_messages.append( llm_entities.Message( @@ -81,13 +70,9 @@ class CommandHandler(handler.MessageHandler): ) ) - self.ap.logger.info( - f'命令({query.query_id})报错: {self.cut_str(str(ret.error))}' - ) + self.ap.logger.info(f'命令({query.query_id})报错: {self.cut_str(str(ret.error))}') - yield entities.StageProcessResult( - result_type=entities.ResultType.CONTINUE, new_query=query - ) + yield entities.StageProcessResult(result_type=entities.ResultType.CONTINUE, new_query=query) elif ret.text is not None or ret.image_url is not None: content: list[llm_entities.ContentElement] = [] @@ -95,9 +80,7 @@ class CommandHandler(handler.MessageHandler): content.append(llm_entities.ContentElement.from_text(ret.text)) if ret.image_url is not None: - content.append( - llm_entities.ContentElement.from_image_url(ret.image_url) - ) + content.append(llm_entities.ContentElement.from_image_url(ret.image_url)) query.resp_messages.append( llm_entities.Message( @@ -108,10 +91,6 @@ class CommandHandler(handler.MessageHandler): self.ap.logger.info(f'命令返回: {self.cut_str(str(content[0]))}') - yield entities.StageProcessResult( - result_type=entities.ResultType.CONTINUE, new_query=query - ) + yield entities.StageProcessResult(result_type=entities.ResultType.CONTINUE, new_query=query) else: - yield entities.StageProcessResult( - result_type=entities.ResultType.INTERRUPT, new_query=query - ) + yield entities.StageProcessResult(result_type=entities.ResultType.INTERRUPT, new_query=query) diff --git a/pkg/pipeline/ratelimit/algos/fixedwin.py b/pkg/pipeline/ratelimit/algos/fixedwin.py index 32079a97..cc816f73 100644 --- a/pkg/pipeline/ratelimit/algos/fixedwin.py +++ b/pkg/pipeline/ratelimit/algos/fixedwin.py @@ -72,9 +72,7 @@ class FixedWindowAlgo(algo.ReteLimitAlgo): if count >= limitation: if query.pipeline_config['safety']['rate-limit']['strategy'] == 'drop': return False - elif ( - query.pipeline_config['safety']['rate-limit']['strategy'] == 'wait' - ): + elif query.pipeline_config['safety']['rate-limit']['strategy'] == 'wait': # 等待下一窗口 await asyncio.sleep(window_size - time.time() % window_size) diff --git a/pkg/pipeline/respback/respback.py b/pkg/pipeline/respback/respback.py index 42c141c8..39d3abb1 100644 --- a/pkg/pipeline/respback/respback.py +++ b/pkg/pipeline/respback/respback.py @@ -15,9 +15,7 @@ from ...core import entities as core_entities class SendResponseBackStage(stage.PipelineStage): """发送响应消息""" - async def process( - self, query: core_entities.Query, stage_inst_name: str - ) -> entities.StageProcessResult: + async def process(self, query: core_entities.Query, stage_inst_name: str) -> entities.StageProcessResult: """处理""" random_range = ( @@ -34,9 +32,7 @@ class SendResponseBackStage(stage.PipelineStage): if query.pipeline_config['output']['misc']['at-sender'] and isinstance( query.message_event, platform_events.GroupMessage ): - query.resp_message_chain[-1].insert( - 0, platform_message.At(query.message_event.sender.id) - ) + query.resp_message_chain[-1].insert(0, platform_message.At(query.message_event.sender.id)) quote_origin = query.pipeline_config['output']['misc']['quote-origin'] @@ -46,6 +42,4 @@ class SendResponseBackStage(stage.PipelineStage): quote_origin=quote_origin, ) - return entities.StageProcessResult( - result_type=entities.ResultType.CONTINUE, new_query=query - ) + return entities.StageProcessResult(result_type=entities.ResultType.CONTINUE, new_query=query) diff --git a/pkg/pipeline/resprule/resprule.py b/pkg/pipeline/resprule/resprule.py index 99402351..0193f2ce 100644 --- a/pkg/pipeline/resprule/resprule.py +++ b/pkg/pipeline/resprule/resprule.py @@ -32,13 +32,9 @@ class GroupRespondRuleCheckStage(stage.PipelineStage): await rule_inst.initialize() self.rule_matchers.append(rule_inst) - async def process( - self, query: core_entities.Query, stage_inst_name: str - ) -> entities.StageProcessResult: + async def process(self, query: core_entities.Query, stage_inst_name: str) -> entities.StageProcessResult: if query.launcher_type.value != 'group': # 只处理群消息 - return entities.StageProcessResult( - result_type=entities.ResultType.CONTINUE, new_query=query - ) + return entities.StageProcessResult(result_type=entities.ResultType.CONTINUE, new_query=query) rules = query.pipeline_config['trigger']['group-respond-rules'] @@ -49,9 +45,7 @@ class GroupRespondRuleCheckStage(stage.PipelineStage): # use_rule = rules[str(query.launcher_id)] for rule_matcher in self.rule_matchers: # 任意一个匹配就放行 - res = await rule_matcher.match( - str(query.message_chain), query.message_chain, use_rule, query - ) + res = await rule_matcher.match(str(query.message_chain), query.message_chain, use_rule, query) if res.matching: query.message_chain = res.replacement @@ -60,6 +54,4 @@ class GroupRespondRuleCheckStage(stage.PipelineStage): new_query=query, ) - return entities.StageProcessResult( - result_type=entities.ResultType.INTERRUPT, new_query=query - ) + return entities.StageProcessResult(result_type=entities.ResultType.INTERRUPT, new_query=query) diff --git a/pkg/pipeline/resprule/rules/atbot.py b/pkg/pipeline/resprule/rules/atbot.py index 0f4845f8..340b92c7 100644 --- a/pkg/pipeline/resprule/rules/atbot.py +++ b/pkg/pipeline/resprule/rules/atbot.py @@ -16,10 +16,7 @@ class AtBotRule(rule_model.GroupRespondRule): rule_dict: dict, query: core_entities.Query, ) -> entities.RuleJudgeResult: - if ( - message_chain.has(platform_message.At(query.adapter.bot_account_id)) - and rule_dict['at'] - ): + if message_chain.has(platform_message.At(query.adapter.bot_account_id)) and rule_dict['at']: message_chain.remove(platform_message.At(query.adapter.bot_account_id)) if message_chain.has( diff --git a/pkg/pipeline/resprule/rules/random.py b/pkg/pipeline/resprule/rules/random.py index 535bfe6b..d2f782ab 100644 --- a/pkg/pipeline/resprule/rules/random.py +++ b/pkg/pipeline/resprule/rules/random.py @@ -18,6 +18,4 @@ class RandomRespRule(rule_model.GroupRespondRule): ) -> entities.RuleJudgeResult: random_rate = rule_dict['random'] - return entities.RuleJudgeResult( - matching=random.random() < random_rate, replacement=message_chain - ) + return entities.RuleJudgeResult(matching=random.random() < random_rate, replacement=message_chain) diff --git a/pkg/pipeline/wrapper/wrapper.py b/pkg/pipeline/wrapper/wrapper.py index bca02527..3299a226 100644 --- a/pkg/pipeline/wrapper/wrapper.py +++ b/pkg/pipeline/wrapper/wrapper.py @@ -34,29 +34,19 @@ class ResponseWrapper(stage.PipelineStage): if isinstance(query.resp_messages[-1], platform_message.MessageChain): query.resp_message_chain.append(query.resp_messages[-1]) - yield entities.StageProcessResult( - result_type=entities.ResultType.CONTINUE, new_query=query - ) + yield entities.StageProcessResult(result_type=entities.ResultType.CONTINUE, new_query=query) else: if query.resp_messages[-1].role == 'command': query.resp_message_chain.append( - query.resp_messages[-1].get_content_platform_message_chain( - prefix_text='[bot] ' - ) + query.resp_messages[-1].get_content_platform_message_chain(prefix_text='[bot] ') ) - yield entities.StageProcessResult( - result_type=entities.ResultType.CONTINUE, new_query=query - ) + yield entities.StageProcessResult(result_type=entities.ResultType.CONTINUE, new_query=query) elif query.resp_messages[-1].role == 'plugin': - query.resp_message_chain.append( - query.resp_messages[-1].get_content_platform_message_chain() - ) + query.resp_message_chain.append(query.resp_messages[-1].get_content_platform_message_chain()) - yield entities.StageProcessResult( - result_type=entities.ResultType.CONTINUE, new_query=query - ) + yield entities.StageProcessResult(result_type=entities.ResultType.CONTINUE, new_query=query) else: if query.resp_messages[-1].role == 'assistant': result = query.resp_messages[-1] @@ -77,9 +67,7 @@ class ResponseWrapper(stage.PipelineStage): prefix='', response_text=reply_text, finish_reason='stop', - funcs_called=[ - fc.function.name for fc in result.tool_calls - ] + funcs_called=[fc.function.name for fc in result.tool_calls] if result.tool_calls is not None else [], query=query, @@ -92,36 +80,26 @@ class ResponseWrapper(stage.PipelineStage): ) else: if event_ctx.event.reply is not None: - query.resp_message_chain.append( - platform_message.MessageChain(event_ctx.event.reply) - ) + query.resp_message_chain.append(platform_message.MessageChain(event_ctx.event.reply)) else: - query.resp_message_chain.append( - result.get_content_platform_message_chain() - ) + query.resp_message_chain.append(result.get_content_platform_message_chain()) yield entities.StageProcessResult( result_type=entities.ResultType.CONTINUE, new_query=query, ) - if ( - result.tool_calls is not None and len(result.tool_calls) > 0 - ): # 有函数调用 + if result.tool_calls is not None and len(result.tool_calls) > 0: # 有函数调用 function_names = [tc.function.name for tc in result.tool_calls] reply_text = f'调用函数 {".".join(function_names)}...' query.resp_message_chain.append( - platform_message.MessageChain( - [platform_message.Plain(reply_text)] - ) + platform_message.MessageChain([platform_message.Plain(reply_text)]) ) - if query.pipeline_config['output']['misc'][ - 'track-function-calls' - ]: + if query.pipeline_config['output']['misc']['track-function-calls']: event_ctx = await self.ap.plugin_mgr.emit_event( event=events.NormalMessageResponded( launcher_type=query.launcher_type.value, @@ -131,9 +109,7 @@ class ResponseWrapper(stage.PipelineStage): prefix='', response_text=reply_text, finish_reason='stop', - funcs_called=[ - fc.function.name for fc in result.tool_calls - ] + funcs_called=[fc.function.name for fc in result.tool_calls] if result.tool_calls is not None else [], query=query, @@ -148,16 +124,12 @@ class ResponseWrapper(stage.PipelineStage): else: if event_ctx.event.reply is not None: query.resp_message_chain.append( - platform_message.MessageChain( - event_ctx.event.reply - ) + platform_message.MessageChain(event_ctx.event.reply) ) else: query.resp_message_chain.append( - platform_message.MessageChain( - [platform_message.Plain(reply_text)] - ) + platform_message.MessageChain([platform_message.Plain(reply_text)]) ) yield entities.StageProcessResult( diff --git a/pkg/platform/adapter.py b/pkg/platform/adapter.py index 61ff32cd..c0fd15c5 100644 --- a/pkg/platform/adapter.py +++ b/pkg/platform/adapter.py @@ -32,9 +32,7 @@ class MessagePlatformAdapter(metaclass=abc.ABCMeta): self.config = config self.ap = ap - async def send_message( - self, target_type: str, target_id: str, message: platform_message.MessageChain - ): + async def send_message(self, target_type: str, target_id: str, message: platform_message.MessageChain): """主动发送消息 Args: @@ -66,9 +64,7 @@ class MessagePlatformAdapter(metaclass=abc.ABCMeta): def register_listener( self, event_type: typing.Type[platform_message.Event], - callback: typing.Callable[ - [platform_message.Event, MessagePlatformAdapter], None - ], + callback: typing.Callable[[platform_message.Event, MessagePlatformAdapter], None], ): """注册事件监听器 @@ -81,9 +77,7 @@ class MessagePlatformAdapter(metaclass=abc.ABCMeta): def unregister_listener( self, event_type: typing.Type[platform_message.Event], - callback: typing.Callable[ - [platform_message.Event, MessagePlatformAdapter], None - ], + callback: typing.Callable[[platform_message.Event, MessagePlatformAdapter], None], ): """注销事件监听器 diff --git a/pkg/platform/botmgr.py b/pkg/platform/botmgr.py index 507f067e..0af7e394 100644 --- a/pkg/platform/botmgr.py +++ b/pkg/platform/botmgr.py @@ -132,14 +132,10 @@ class PlatformManager: self.adapter_dict = {} async def initialize(self): - self.adapter_components = self.ap.discover.get_components_by_kind( - 'MessagePlatformAdapter' - ) + self.adapter_components = self.ap.discover.get_components_by_kind('MessagePlatformAdapter') adapter_dict: dict[str, type[msadapter.MessagePlatformAdapter]] = {} for component in self.adapter_components: - adapter_dict[component.metadata.name] = ( - component.get_python_component_class() - ) + adapter_dict[component.metadata.name] = component.get_python_component_class() self.adapter_dict = adapter_dict await self.load_bots_from_db() @@ -152,9 +148,7 @@ class PlatformManager: self.bots = [] - result = await self.ap.persistence_mgr.execute_async( - sqlalchemy.select(persistence_bot.Bot) - ) + result = await self.ap.persistence_mgr.execute_async(sqlalchemy.select(persistence_bot.Bot)) bots = result.all() @@ -172,13 +166,9 @@ class PlatformManager: elif isinstance(bot_entity, dict): bot_entity = persistence_bot.Bot(**bot_entity) - adapter_inst = self.adapter_dict[bot_entity.adapter]( - bot_entity.adapter_config, self.ap - ) + adapter_inst = self.adapter_dict[bot_entity.adapter](bot_entity.adapter_config, self.ap) - runtime_bot = RuntimeBot( - ap=self.ap, bot_entity=bot_entity, adapter=adapter_inst - ) + runtime_bot = RuntimeBot(ap=self.ap, bot_entity=bot_entity, adapter=adapter_inst) await runtime_bot.initialize() @@ -209,9 +199,7 @@ class PlatformManager: return component.to_plain_dict() return None - def get_available_adapter_manifest_by_name( - self, name: str - ) -> engine.Component | None: + def get_available_adapter_manifest_by_name(self, name: str) -> engine.Component | None: for component in self.adapter_components: if component.metadata.name == name: return component diff --git a/pkg/platform/sources/aiocqhttp.py b/pkg/platform/sources/aiocqhttp.py index 48116507..bee97f57 100644 --- a/pkg/platform/sources/aiocqhttp.py +++ b/pkg/platform/sources/aiocqhttp.py @@ -58,13 +58,7 @@ class AiocqhttpMessageConverter(adapter.MessageConverter): msg_list.append(aiocqhttp.MessageSegment.record(msg.path)) elif type(msg) is platform_message.Forward: for node in msg.node_list: - msg_list.extend( - ( - await AiocqhttpMessageConverter.yiri2target( - node.message_chain - ) - )[0] - ) + msg_list.extend((await AiocqhttpMessageConverter.yiri2target(node.message_chain))[0]) else: msg_list.append(aiocqhttp.MessageSegment.text(str(msg))) @@ -77,9 +71,7 @@ class AiocqhttpMessageConverter(adapter.MessageConverter): yiri_msg_list = [] - yiri_msg_list.append( - platform_message.Source(id=message_id, time=datetime.datetime.now()) - ) + yiri_msg_list.append(platform_message.Source(id=message_id, time=datetime.datetime.now())) for msg in message: if msg.type == 'at': @@ -94,14 +86,8 @@ class AiocqhttpMessageConverter(adapter.MessageConverter): elif msg.type == 'text': yiri_msg_list.append(platform_message.Plain(text=msg.data['text'])) elif msg.type == 'image': - image_base64, image_format = await image.qq_image_url_to_base64( - msg.data['url'] - ) - yiri_msg_list.append( - platform_message.Image( - base64=f'data:image/{image_format};base64,{image_base64}' - ) - ) + image_base64, image_format = await image.qq_image_url_to_base64(msg.data['url']) + yiri_msg_list.append(platform_message.Image(base64=f'data:image/{image_format};base64,{image_base64}')) chain = platform_message.MessageChain(yiri_msg_list) @@ -115,9 +101,7 @@ class AiocqhttpEventConverter(adapter.EventConverter): @staticmethod async def target2yiri(event: aiocqhttp.Event): - yiri_chain = await AiocqhttpMessageConverter.target2yiri( - event.message, event.message_id - ) + yiri_chain = await AiocqhttpMessageConverter.target2yiri(event.message, event.message_id) if event.message_type == 'group': permission = 'MEMBER' @@ -137,9 +121,7 @@ class AiocqhttpEventConverter(adapter.EventConverter): name=event.sender['nickname'], permission=platform_entities.Permission.Member, ), - special_title=event.sender['title'] - if 'title' in event.sender - else '', + special_title=event.sender['title'] if 'title' in event.sender else '', join_timestamp=0, last_speak_timestamp=0, mute_time_remaining=0, @@ -191,9 +173,7 @@ class AiocqhttpAdapter(adapter.MessagePlatformAdapter): else: self.bot = aiocqhttp.CQHttp() - async def send_message( - self, target_type: str, target_id: str, message: platform_message.MessageChain - ): + async def send_message(self, target_type: str, target_id: str, message: platform_message.MessageChain): aiocq_msg = (await AiocqhttpMessageConverter.yiri2target(message))[0] if target_type == 'group': @@ -207,14 +187,10 @@ class AiocqhttpAdapter(adapter.MessagePlatformAdapter): message: platform_message.MessageChain, quote_origin: bool = False, ): - aiocq_event = await AiocqhttpEventConverter.yiri2target( - message_source, self.bot_account_id - ) + aiocq_event = await AiocqhttpEventConverter.yiri2target(message_source, self.bot_account_id) aiocq_msg = (await AiocqhttpMessageConverter.yiri2target(message))[0] if quote_origin: - aiocq_msg = ( - aiocqhttp.MessageSegment.reply(aiocq_event.message_id) + aiocq_msg - ) + aiocq_msg = aiocqhttp.MessageSegment.reply(aiocq_event.message_id) + aiocq_msg return await self.bot.send(aiocq_event, aiocq_msg) @@ -224,16 +200,12 @@ class AiocqhttpAdapter(adapter.MessagePlatformAdapter): def register_listener( self, event_type: typing.Type[platform_events.Event], - callback: typing.Callable[ - [platform_events.Event, adapter.MessagePlatformAdapter], None - ], + callback: typing.Callable[[platform_events.Event, adapter.MessagePlatformAdapter], None], ): async def on_message(event: aiocqhttp.Event): self.bot_account_id = event.self_id try: - return await callback( - await self.event_converter.target2yiri(event), self - ) + return await callback(await self.event_converter.target2yiri(event), self) except Exception: traceback.print_exc() @@ -245,9 +217,7 @@ class AiocqhttpAdapter(adapter.MessagePlatformAdapter): def unregister_listener( self, event_type: typing.Type[platform_events.Event], - callback: typing.Callable[ - [platform_events.Event, adapter.MessagePlatformAdapter], None - ], + callback: typing.Callable[[platform_events.Event, adapter.MessagePlatformAdapter], None], ): return super().unregister_listener(event_type, callback) diff --git a/pkg/platform/sources/dingtalk.py b/pkg/platform/sources/dingtalk.py index ff9173ef..433ef836 100644 --- a/pkg/platform/sources/dingtalk.py +++ b/pkg/platform/sources/dingtalk.py @@ -22,9 +22,7 @@ class DingTalkMessageConverter(adapter.MessageConverter): async def target2yiri(event: DingTalkEvent, bot_name: str): yiri_msg_list = [] yiri_msg_list.append( - platform_message.Source( - id=event.incoming_message.message_id, time=datetime.datetime.now() - ) + platform_message.Source(id=event.incoming_message.message_id, time=datetime.datetime.now()) ) for atUser in event.incoming_message.at_users: @@ -133,9 +131,7 @@ class DingTalkAdapter(adapter.MessagePlatformAdapter): content = await DingTalkMessageConverter.yiri2target(message) await self.bot.send_message(content, incoming_message) - async def send_message( - self, target_type: str, target_id: str, message: platform_message.MessageChain - ): + async def send_message(self, target_type: str, target_id: str, message: platform_message.MessageChain): content = await DingTalkMessageConverter.yiri2target(message) if target_type == 'person': await self.bot.send_proactive_message_to_one(target_id, content) @@ -145,16 +141,12 @@ class DingTalkAdapter(adapter.MessagePlatformAdapter): def register_listener( self, event_type: typing.Type[platform_events.Event], - callback: typing.Callable[ - [platform_events.Event, adapter.MessagePlatformAdapter], None - ], + callback: typing.Callable[[platform_events.Event, adapter.MessagePlatformAdapter], None], ): async def on_message(event: DingTalkEvent): try: return await callback( - await self.event_converter.target2yiri( - event, self.config['robot_name'] - ), + await self.event_converter.target2yiri(event, self.config['robot_name']), self, ) except Exception: @@ -174,8 +166,6 @@ class DingTalkAdapter(adapter.MessagePlatformAdapter): async def unregister_listener( self, event_type: type, - callback: typing.Callable[ - [platform_events.Event, MessagePlatformAdapter], None - ], + callback: typing.Callable[[platform_events.Event, MessagePlatformAdapter], None], ): return super().unregister_listener(event_type, callback) diff --git a/pkg/platform/sources/discord.py b/pkg/platform/sources/discord.py index 07dd586f..f5be422d 100644 --- a/pkg/platform/sources/discord.py +++ b/pkg/platform/sources/discord.py @@ -45,9 +45,7 @@ class DiscordMessageConverter(adapter.MessageConverter): with open(ele.path, 'rb') as f: image_bytes = f.read() - image_files.append( - discord.File(fp=image_bytes, filename=f'{uuid.uuid4()}.png') - ) + image_files.append(discord.File(fp=image_bytes, filename=f'{uuid.uuid4()}.png')) elif isinstance(ele, platform_message.Plain): text_string += ele.text elif isinstance(ele, platform_message.Forward): @@ -65,9 +63,7 @@ class DiscordMessageConverter(adapter.MessageConverter): async def target2yiri(message: discord.Message) -> platform_message.MessageChain: lb_msg_list = [] - msg_create_time = datetime.datetime.fromtimestamp( - int(message.created_at.timestamp()) - ) + msg_create_time = datetime.datetime.fromtimestamp(int(message.created_at.timestamp())) lb_msg_list.append(platform_message.Source(id=message.id, time=msg_create_time)) @@ -97,11 +93,7 @@ class DiscordMessageConverter(adapter.MessageConverter): else: mid_at_component.append(platform_message.At(target=mid_at[2:-1])) - return ( - text_element_recur(text_split[0]) - + mid_at_component - + text_element_recur(text_split[1]) - ) + return text_element_recur(text_split[0]) + mid_at_component + text_element_recur(text_split[1]) else: return [platform_message.Plain(text=text_ele)] @@ -114,11 +106,7 @@ class DiscordMessageConverter(adapter.MessageConverter): image_data = await response.read() image_base64 = base64.b64encode(image_data).decode('utf-8') image_format = response.headers['Content-Type'] - element_list.append( - platform_message.Image( - base64=f'data:{image_format};base64,{image_base64}' - ) - ) + element_list.append(platform_message.Image(base64=f'data:{image_format};base64,{image_base64}')) return platform_message.MessageChain(element_list) @@ -208,9 +196,7 @@ class DiscordAdapter(adapter.MessagePlatformAdapter): self.bot = MyClient(intents=intents, **args) - async def send_message( - self, target_type: str, target_id: str, message: platform_message.MessageChain - ): + async def send_message(self, target_type: str, target_id: str, message: platform_message.MessageChain): pass async def reply_message( @@ -243,18 +229,14 @@ class DiscordAdapter(adapter.MessagePlatformAdapter): def register_listener( self, event_type: typing.Type[platform_events.Event], - callback: typing.Callable[ - [platform_events.Event, adapter.MessagePlatformAdapter], None - ], + callback: typing.Callable[[platform_events.Event, adapter.MessagePlatformAdapter], None], ): self.listeners[event_type] = callback def unregister_listener( self, event_type: typing.Type[platform_events.Event], - callback: typing.Callable[ - [platform_events.Event, adapter.MessagePlatformAdapter], None - ], + callback: typing.Callable[[platform_events.Event, adapter.MessagePlatformAdapter], None], ): self.listeners.pop(event_type) diff --git a/pkg/platform/sources/gewechat.py b/pkg/platform/sources/gewechat.py index 0555f050..efa58f3d 100644 --- a/pkg/platform/sources/gewechat.py +++ b/pkg/platform/sources/gewechat.py @@ -40,14 +40,10 @@ class GewechatMessageConverter(adapter.MessageConverter): content_list.append({'type': 'image', 'image': component.url}) elif isinstance(component, platform_message.Voice): - content_list.append( - {'type': 'voice', 'url': component.url, 'length': component.length} - ) + content_list.append({'type': 'voice', 'url': component.url, 'length': component.length}) elif isinstance(component, platform_message.Forward): for node in component.node_list: - content_list.extend( - await GewechatMessageConverter.yiri2target(node.message_chain) - ) + content_list.extend(await GewechatMessageConverter.yiri2target(node.message_chain)) content_list.append({'type': 'image', 'image': component.url}) elif isinstance(component, platform_message.WeChatMiniPrograms): content_list.append( @@ -88,44 +84,26 @@ class GewechatMessageConverter(adapter.MessageConverter): } ) elif isinstance(component, platform_message.WeChatForwardLink): - content_list.append( - {'type': 'WeChatForwardLink', 'xml_data': component.xml_data} - ) + content_list.append({'type': 'WeChatForwardLink', 'xml_data': component.xml_data}) elif isinstance(component, platform_message.Voice): - content_list.append( - {'type': 'voice', 'url': component.url, 'length': component.length} - ) + content_list.append({'type': 'voice', 'url': component.url, 'length': component.length}) elif isinstance(component, platform_message.WeChatForwardImage): - content_list.append( - {'type': 'WeChatForwardImage', 'xml_data': component.xml_data} - ) + content_list.append({'type': 'WeChatForwardImage', 'xml_data': component.xml_data}) elif isinstance(component, platform_message.WeChatForwardFile): - content_list.append( - {'type': 'WeChatForwardFile', 'xml_data': component.xml_data} - ) + content_list.append({'type': 'WeChatForwardFile', 'xml_data': component.xml_data}) elif isinstance(component, platform_message.WeChatAppMsg): - content_list.append( - {'type': 'WeChatAppMsg', 'app_msg': component.app_msg} - ) + content_list.append({'type': 'WeChatAppMsg', 'app_msg': component.app_msg}) # 引用消息转发 elif isinstance(component, platform_message.WeChatForwardQuote): - content_list.append( - {'type': 'WeChatAppMsg', 'app_msg': component.app_msg} - ) + content_list.append({'type': 'WeChatAppMsg', 'app_msg': component.app_msg}) elif isinstance(component, platform_message.Forward): for node in component.node_list: if node.message_chain: - content_list.extend( - await GewechatMessageConverter.yiri2target( - node.message_chain - ) - ) + content_list.extend(await GewechatMessageConverter.yiri2target(node.message_chain)) return content_list - async def target2yiri( - self, message: dict, bot_account_id: str - ) -> platform_message.MessageChain: + async def target2yiri(self, message: dict, bot_account_id: str) -> platform_message.MessageChain: """外部消息转平台消息""" # 数据预处理 message_list = [] @@ -163,28 +141,20 @@ class GewechatMessageConverter(adapter.MessageConverter): return platform_message.MessageChain(message_list) - async def _handler_text( - self, message: Optional[dict], content_no_preifx: str - ) -> platform_message.MessageChain: + async def _handler_text(self, message: Optional[dict], content_no_preifx: str) -> platform_message.MessageChain: """处理文本消息 (msg_type=1)""" if message and self._is_group_message(message): pattern = r'@\S{1,20}' content_no_preifx = re.sub(pattern, '', content_no_preifx) - return platform_message.MessageChain( - [platform_message.Plain(content_no_preifx)] - ) + return platform_message.MessageChain([platform_message.Plain(content_no_preifx)]) - async def _handler_image( - self, message: Optional[dict], content_no_preifx: str - ) -> platform_message.MessageChain: + async def _handler_image(self, message: Optional[dict], content_no_preifx: str) -> platform_message.MessageChain: """处理图像消息 (msg_type=3)""" try: image_xml = content_no_preifx if not image_xml: - return platform_message.MessageChain( - [platform_message.Unknown('[图片内容为空]')] - ) + return platform_message.MessageChain([platform_message.Unknown('[图片内容为空]')]) base64_str, image_format = await image.get_gewechat_image_base64( gewechat_url=self.config['gewechat_url'], @@ -196,21 +166,15 @@ class GewechatMessageConverter(adapter.MessageConverter): ) elements = [ - platform_message.Image( - base64=f'data:image/{image_format};base64,{base64_str}' - ), + platform_message.Image(base64=f'data:image/{image_format};base64,{base64_str}'), platform_message.WeChatForwardImage(xml_data=image_xml), # 微信消息转发 ] return platform_message.MessageChain(elements) except Exception as e: print(f'处理图片失败: {str(e)}') - return platform_message.MessageChain( - [platform_message.Unknown('[图片处理失败]')] - ) + return platform_message.MessageChain([platform_message.Unknown('[图片处理失败]')]) - async def _handler_voice( - self, message: Optional[dict], content_no_preifx: str - ) -> platform_message.MessageChain: + async def _handler_voice(self, message: Optional[dict], content_no_preifx: str) -> platform_message.MessageChain: """处理语音消息 (msg_type=34)""" message_List = [] try: @@ -223,9 +187,7 @@ class GewechatMessageConverter(adapter.MessageConverter): return platform_message.MessageChain(message_List) # 转换为平台支持的语音格式(如 Silk 格式) - voice_element = platform_message.Voice( - base64=f'data:audio/silk;base64,{audio_base64}' - ) + voice_element = platform_message.Voice(base64=f'data:audio/silk;base64,{audio_base64}') message_List.append(voice_element) except KeyError as e: @@ -237,9 +199,7 @@ class GewechatMessageConverter(adapter.MessageConverter): return platform_message.MessageChain(message_List) - async def _handler_compound( - self, message: Optional[dict], content_no_preifx: str - ) -> platform_message.MessageChain: + async def _handler_compound(self, message: Optional[dict], content_no_preifx: str) -> platform_message.MessageChain: """处理复合消息 (msg_type=49),根据子类型分派""" try: xml_data = ET.fromstring(content_no_preifx) @@ -254,33 +214,21 @@ class GewechatMessageConverter(adapter.MessageConverter): '6': self._handler_compound_file, '33': self._handler_compound_mini_program, '36': self._handler_compound_mini_program, - '2000': partial( - self._handler_compound_unsupported, text='[转账消息]' - ), - '2001': partial( - self._handler_compound_unsupported, text='[红包消息]' - ), - '51': partial( - self._handler_compound_unsupported, text='[视频号消息]' - ), + '2000': partial(self._handler_compound_unsupported, text='[转账消息]'), + '2001': partial(self._handler_compound_unsupported, text='[红包消息]'), + '51': partial(self._handler_compound_unsupported, text='[视频号消息]'), } - handler = sub_handler_map.get( - data_type, self._handler_compound_unsupported - ) + handler = sub_handler_map.get(data_type, self._handler_compound_unsupported) return await handler( message=message, # 原始msg xml_data=xml_data, # xml数据 ) else: - return platform_message.MessageChain( - [platform_message.Unknown(text=content_no_preifx)] - ) + return platform_message.MessageChain([platform_message.Unknown(text=content_no_preifx)]) except Exception as e: print(f'解析复合消息失败: {str(e)}') - return platform_message.MessageChain( - [platform_message.Unknown(text=content_no_preifx)] - ) + return platform_message.MessageChain([platform_message.Unknown(text=content_no_preifx)]) async def _handler_compound_quote( self, message: Optional[dict], xml_data: ET.Element @@ -296,9 +244,7 @@ class GewechatMessageConverter(adapter.MessageConverter): user_data = appmsg_data.findtext('.//title') or '' quote_data = appmsg_data.find('.//refermsg').findtext('.//content') message_list.append( - platform_message.WeChatForwardQuote( - app_msg=ET.tostring(appmsg_data, encoding='unicode') - ) + platform_message.WeChatForwardQuote(app_msg=ET.tostring(appmsg_data, encoding='unicode')) ) # quote_data原始的消息 if quote_data: @@ -311,22 +257,14 @@ class GewechatMessageConverter(adapter.MessageConverter): # 引用消息展开 quote_data_xml = ET.fromstring(quote_data) if quote_data_xml.find('img'): - quote_data_message_list.extend( - await self._handler_image(None, quote_data) - ) + quote_data_message_list.extend(await self._handler_image(None, quote_data)) elif quote_data_xml.find('voicemsg'): - quote_data_message_list.extend( - await self._handler_voice(None, quote_data) - ) + quote_data_message_list.extend(await self._handler_voice(None, quote_data)) elif quote_data_xml.find('videomsg'): - quote_data_message_list.extend( - await self._handler_default(None, quote_data) - ) # 先不处理 + quote_data_message_list.extend(await self._handler_default(None, quote_data)) # 先不处理 else: # appmsg - quote_data_message_list.extend( - await self._handler_compound(None, quote_data) - ) + quote_data_message_list.extend(await self._handler_compound(None, quote_data)) except Exception as e: print(f'处理引用消息异常 expcetion:{e}') quote_data_message_list.append(platform_message.Plain(quote_data)) @@ -351,18 +289,12 @@ class GewechatMessageConverter(adapter.MessageConverter): # print(f"quote_message_chain plain [msg_type={comp.type}][message={comp.text}]") return platform_message.MessageChain(message_list) - async def _handler_compound_file( - self, message: dict, xml_data: ET.Element - ) -> platform_message.MessageChain: + async def _handler_compound_file(self, message: dict, xml_data: ET.Element) -> platform_message.MessageChain: """处理文件消息 (data_type=6)""" xml_data_str = ET.tostring(xml_data, encoding='unicode') - return platform_message.MessageChain( - [platform_message.WeChatForwardFile(xml_data=xml_data_str)] - ) + return platform_message.MessageChain([platform_message.WeChatForwardFile(xml_data=xml_data_str)]) - async def _handler_compound_link( - self, message: dict, xml_data: ET.Element - ) -> platform_message.MessageChain: + async def _handler_compound_link(self, message: dict, xml_data: ET.Element) -> platform_message.MessageChain: """处理链接消息(如公众号文章、外部网页)""" message_list = [] try: @@ -381,9 +313,7 @@ class GewechatMessageConverter(adapter.MessageConverter): # 转发消息 xml_data_str = ET.tostring(xml_data, encoding='unicode') # print(xml_data_str) - message_list.append( - platform_message.WeChatForwardLink(xml_data=xml_data_str) - ) + message_list.append(platform_message.WeChatForwardLink(xml_data=xml_data_str)) except Exception as e: print(f'解析链接消息失败: {str(e)}') return platform_message.MessageChain(message_list) @@ -393,21 +323,15 @@ class GewechatMessageConverter(adapter.MessageConverter): ) -> platform_message.MessageChain: """处理小程序消息(如小程序卡片、服务通知)""" xml_data_str = ET.tostring(xml_data, encoding='unicode') - return platform_message.MessageChain( - [platform_message.WeChatForwardMiniPrograms(xml_data=xml_data_str)] - ) + return platform_message.MessageChain([platform_message.WeChatForwardMiniPrograms(xml_data=xml_data_str)]) - async def _handler_default( - self, message: Optional[dict], content_no_preifx: str - ) -> platform_message.MessageChain: + async def _handler_default(self, message: Optional[dict], content_no_preifx: str) -> platform_message.MessageChain: """处理未知消息类型""" if message: msg_type = message['Data']['MsgType'] else: msg_type = '' - return platform_message.MessageChain( - [platform_message.Unknown(text=f'[未知消息类型 msg_type:{msg_type}]')] - ) + return platform_message.MessageChain([platform_message.Unknown(text=f'[未知消息类型 msg_type:{msg_type}]')]) def _handler_compound_unsupported( self, message: dict, xml_data: str, text: Optional[str] = None @@ -416,11 +340,7 @@ class GewechatMessageConverter(adapter.MessageConverter): if not text: text = f'[xml_data={xml_data}]' content_list = [] - content_list.append( - platform_message.Unknown( - text=f'[处理未支持复合消息类型[msg_type=49]|{text}' - ) - ) + content_list.append(platform_message.Unknown(text=f'[处理未支持复合消息类型[msg_type=49]|{text}')) return platform_message.MessageChain(content_list) @@ -448,9 +368,7 @@ class GewechatMessageConverter(adapter.MessageConverter): appmsg_data = xml_data.find('.//appmsg') tousername = message['Wxid'] if appmsg_data: # 接收方: 所属微信的wxid - quote_id = appmsg_data.find('.//refermsg').findtext( - './/chatusr' - ) # 引用消息的原发送者 + quote_id = appmsg_data.find('.//refermsg').findtext('.//chatusr') # 引用消息的原发送者 ats_bot = ats_bot or (quote_id == tousername) except Exception as e: print(f'_ats_bot got except: {e}') @@ -458,9 +376,7 @@ class GewechatMessageConverter(adapter.MessageConverter): return ats_bot # 提取一下content前面的sender_id, 和去掉前缀的内容 - def _extract_content_and_sender( - self, raw_content: str - ) -> Tuple[str, Optional[str]]: + def _extract_content_and_sender(self, raw_content: str) -> Tuple[str, Optional[str]]: try: # 检查消息开头,如果有 wxid_sbitaz0mt65n22:\n 则删掉 # add: 有些用户的wxid不是上述格式。换成user_name: @@ -490,21 +406,17 @@ class GewechatEventConverter(adapter.EventConverter): async def yiri2target(event: platform_events.MessageEvent) -> dict: pass - async def target2yiri( - self, event: dict, bot_account_id: str - ) -> platform_events.MessageEvent: + async def target2yiri(self, event: dict, bot_account_id: str) -> platform_events.MessageEvent: # print(event) # 排除自己发消息回调回答问题 if event['Wxid'] == event['Data']['FromUserName']['string']: return None # 排除公众号以及微信团队消息 - if event['Data']['FromUserName']['string'].startswith('gh_') or event['Data'][ - 'FromUserName' - ]['string'].startswith('weixin'): + if event['Data']['FromUserName']['string'].startswith('gh_') or event['Data']['FromUserName'][ + 'string' + ].startswith('weixin'): return None - message_chain = await self.message_converter.target2yiri( - copy.deepcopy(event), bot_account_id - ) + message_chain = await self.message_converter.target2yiri(copy.deepcopy(event), bot_account_id) if not message_chain: return None @@ -589,9 +501,7 @@ class GeWeChatAdapter(adapter.MessagePlatformAdapter): return 'ok' elif 'TypeName' in data and data['TypeName'] == 'AddMsg': try: - event = await self.event_converter.target2yiri( - data.copy(), self.bot_account_id - ) + event = await self.event_converter.target2yiri(data.copy(), self.bot_account_id) except Exception: traceback.print_exc() @@ -600,9 +510,7 @@ class GeWeChatAdapter(adapter.MessagePlatformAdapter): return 'ok' - async def _handle_message( - self, message: platform_message.MessageChain, target_id: str - ): + async def _handle_message(self, message: platform_message.MessageChain, target_id: str): """统一消息处理核心逻辑""" content_list = await self.message_converter.yiri2target(message) at_targets = [item['target'] for item in content_list if item['type'] == 'at'] @@ -611,9 +519,9 @@ class GeWeChatAdapter(adapter.MessagePlatformAdapter): at_targets = at_targets or [] member_info = [] if at_targets: - member_info = self.bot.get_chatroom_member_detail( - self.config['app_id'], target_id, at_targets[::-1] - )['data'] + member_info = self.bot.get_chatroom_member_detail(self.config['app_id'], target_id, at_targets[::-1])[ + 'data' + ] # 处理消息组件 for msg in content_list: @@ -694,9 +602,7 @@ class GeWeChatAdapter(adapter.MessagePlatformAdapter): self.ap.logger.warning(f'未处理的消息类型: {msg["type"]}') continue - async def send_message( - self, target_type: str, target_id: str, message: platform_message.MessageChain - ): + async def send_message(self, target_type: str, target_id: str, message: platform_message.MessageChain): """主动发送消息""" return await self._handle_message(message, target_id) @@ -708,9 +614,7 @@ class GeWeChatAdapter(adapter.MessagePlatformAdapter): ): """回复消息""" if message_source.source_platform_object: - target_id = message_source.source_platform_object['Data']['FromUserName'][ - 'string' - ] + target_id = message_source.source_platform_object['Data']['FromUserName']['string'] return await self._handle_message(message, target_id) async def is_muted(self, group_id: int) -> bool: @@ -719,18 +623,14 @@ class GeWeChatAdapter(adapter.MessagePlatformAdapter): def register_listener( self, event_type: typing.Type[platform_events.Event], - callback: typing.Callable[ - [platform_events.Event, adapter.MessagePlatformAdapter], None - ], + callback: typing.Callable[[platform_events.Event, adapter.MessagePlatformAdapter], None], ): self.listeners[event_type] = callback def unregister_listener( self, event_type: typing.Type[platform_events.Event], - callback: typing.Callable[ - [platform_events.Event, adapter.MessagePlatformAdapter], None - ], + callback: typing.Callable[[platform_events.Event, adapter.MessagePlatformAdapter], None], ): pass @@ -742,14 +642,10 @@ class GeWeChatAdapter(adapter.MessagePlatformAdapter): json={'app_id': self.config['app_id']}, ) as response: if response.status != 200: - raise Exception( - f'获取gewechat token失败: {await response.text()}' - ) + raise Exception(f'获取gewechat token失败: {await response.text()}') self.config['token'] = (await response.json())['data'] - self.bot = gewechat_client.GewechatClient( - f'{self.config["gewechat_url"]}/v2/api', self.config['token'] - ) + self.bot = gewechat_client.GewechatClient(f'{self.config["gewechat_url"]}/v2/api', self.config['token']) def gewechat_login_process(): app_id, error_msg = self.bot.login(self.config['app_id']) diff --git a/pkg/platform/sources/lark.py b/pkg/platform/sources/lark.py index ae6e89ee..0bf19a23 100644 --- a/pkg/platform/sources/lark.py +++ b/pkg/platform/sources/lark.py @@ -71,14 +71,10 @@ class LarkMessageConverter(adapter.MessageConverter): pending_paragraph.append({'tag': 'md', 'text': text}) except UnicodeError: # If still fails, replace invalid characters - text = msg.text.encode('utf-8', errors='replace').decode( - 'utf-8' - ) + text = msg.text.encode('utf-8', errors='replace').decode('utf-8') pending_paragraph.append({'tag': 'md', 'text': text}) elif isinstance(msg, platform_message.At): - pending_paragraph.append( - {'tag': 'at', 'user_id': msg.target, 'style': []} - ) + pending_paragraph.append({'tag': 'at', 'user_id': msg.target, 'style': []}) elif isinstance(msg, platform_message.AtAll): pending_paragraph.append({'tag': 'at', 'user_id': 'all', 'style': []}) elif isinstance(msg, platform_message.Image): @@ -166,11 +162,7 @@ class LarkMessageConverter(adapter.MessageConverter): os.unlink(temp_file.name) elif isinstance(msg, platform_message.Forward): for node in msg.node_list: - message_elements.extend( - await LarkMessageConverter.yiri2target( - node.message_chain, api_client - ) - ) + message_elements.extend(await LarkMessageConverter.yiri2target(node.message_chain, api_client)) if pending_paragraph: message_elements.append(pending_paragraph) @@ -186,13 +178,9 @@ class LarkMessageConverter(adapter.MessageConverter): lb_msg_list = [] - msg_create_time = datetime.datetime.fromtimestamp( - int(message.create_time) / 1000 - ) + msg_create_time = datetime.datetime.fromtimestamp(int(message.create_time) / 1000) - lb_msg_list.append( - platform_message.Source(id=message.message_id, time=msg_create_time) - ) + lb_msg_list.append(platform_message.Source(id=message.message_id, time=msg_create_time)) if message.message_type == 'text': element_list = [] @@ -222,9 +210,7 @@ class LarkMessageConverter(adapter.MessageConverter): left_text = text_split[0] right_text = text_split[1] - new_list.extend( - text_element_recur({'tag': 'text', 'text': left_text, 'style': []}) - ) + new_list.extend(text_element_recur({'tag': 'text', 'text': left_text, 'style': []})) new_list.append( { @@ -235,15 +221,11 @@ class LarkMessageConverter(adapter.MessageConverter): } ) - new_list.extend( - text_element_recur({'tag': 'text', 'text': right_text, 'style': []}) - ) + new_list.extend(text_element_recur({'tag': 'text', 'text': right_text, 'style': []})) return new_list - element_list = text_element_recur( - {'tag': 'text', 'text': message_content['text'], 'style': []} - ) + element_list = text_element_recur({'tag': 'text', 'text': message_content['text'], 'style': []}) message_content = {'title': '', 'content': element_list} @@ -258,9 +240,7 @@ class LarkMessageConverter(adapter.MessageConverter): message_content['content'] = new_list elif message.message_type == 'image': - message_content['content'] = [ - {'tag': 'img', 'image_key': message_content['image_key'], 'style': []} - ] + message_content['content'] = [{'tag': 'img', 'image_key': message_content['image_key'], 'style': []}] for ele in message_content['content']: if ele['tag'] == 'text': @@ -278,9 +258,7 @@ class LarkMessageConverter(adapter.MessageConverter): .build() ) - response: GetMessageResourceResponse = ( - await api_client.im.v1.message_resource.aget(request) - ) + response: GetMessageResourceResponse = await api_client.im.v1.message_resource.aget(request) if not response.success(): raise Exception( @@ -292,11 +270,7 @@ class LarkMessageConverter(adapter.MessageConverter): image_format = response.raw.headers['content-type'] - lb_msg_list.append( - platform_message.Image( - base64=f'data:{image_format};base64,{image_base64}' - ) - ) + lb_msg_list.append(platform_message.Image(base64=f'data:{image_format};base64,{image_base64}')) return platform_message.MessageChain(lb_msg_list) @@ -312,9 +286,7 @@ class LarkEventConverter(adapter.EventConverter): async def target2yiri( event: lark_oapi.im.v1.P2ImMessageReceiveV1, api_client: lark_oapi.Client ) -> platform_events.Event: - message_chain = await LarkMessageConverter.target2yiri( - event.event.message, api_client - ) + message_chain = await LarkMessageConverter.target2yiri(event.event.message, api_client) if event.event.message.chat_type == 'p2p': return platform_events.FriendMessage( @@ -402,9 +374,7 @@ class LarkAdapter(adapter.MessagePlatformAdapter): p2v1.schema = context.schema if 'im.message.receive_v1' == type: try: - event = await self.event_converter.target2yiri( - p2v1, self.api_client - ) + event = await self.event_converter.target2yiri(p2v1, self.api_client) except Exception: traceback.print_exc() @@ -425,26 +395,15 @@ class LarkAdapter(adapter.MessagePlatformAdapter): asyncio.create_task(on_message(event)) event_handler = ( - lark_oapi.EventDispatcherHandler.builder('', '') - .register_p2_im_message_receive_v1(sync_on_message) - .build() + lark_oapi.EventDispatcherHandler.builder('', '').register_p2_im_message_receive_v1(sync_on_message).build() ) self.bot_account_id = config['bot_name'] - self.bot = lark_oapi.ws.Client( - config['app_id'], config['app_secret'], event_handler=event_handler - ) - self.api_client = ( - lark_oapi.Client.builder() - .app_id(config['app_id']) - .app_secret(config['app_secret']) - .build() - ) + self.bot = lark_oapi.ws.Client(config['app_id'], config['app_secret'], event_handler=event_handler) + self.api_client = lark_oapi.Client.builder().app_id(config['app_id']).app_secret(config['app_secret']).build() - async def send_message( - self, target_type: str, target_id: str, message: platform_message.MessageChain - ): + async def send_message(self, target_type: str, target_id: str, message: platform_message.MessageChain): pass async def reply_message( @@ -455,9 +414,7 @@ class LarkAdapter(adapter.MessagePlatformAdapter): ): # 不再需要了,因为message_id已经被包含到message_chain中 # lark_event = await self.event_converter.yiri2target(message_source) - lark_message = await self.message_converter.yiri2target( - message, self.api_client - ) + lark_message = await self.message_converter.yiri2target(message, self.api_client) final_content = { 'zh_cn': { @@ -480,9 +437,7 @@ class LarkAdapter(adapter.MessagePlatformAdapter): .build() ) - response: ReplyMessageResponse = await self.api_client.im.v1.message.areply( - request - ) + response: ReplyMessageResponse = await self.api_client.im.v1.message.areply(request) if not response.success(): raise Exception( @@ -495,18 +450,14 @@ class LarkAdapter(adapter.MessagePlatformAdapter): def register_listener( self, event_type: typing.Type[platform_events.Event], - callback: typing.Callable[ - [platform_events.Event, adapter.MessagePlatformAdapter], None - ], + callback: typing.Callable[[platform_events.Event, adapter.MessagePlatformAdapter], None], ): self.listeners[event_type] = callback def unregister_listener( self, event_type: typing.Type[platform_events.Event], - callback: typing.Callable[ - [platform_events.Event, adapter.MessagePlatformAdapter], None - ], + callback: typing.Callable[[platform_events.Event, adapter.MessagePlatformAdapter], None], ): self.listeners.pop(event_type) diff --git a/pkg/platform/sources/nakuru.py b/pkg/platform/sources/nakuru.py index 7038af1d..44e2d301 100644 --- a/pkg/platform/sources/nakuru.py +++ b/pkg/platform/sources/nakuru.py @@ -29,9 +29,7 @@ class NakuruProjectMessageConverter(adapter_model.MessageConverter): elif type(message_chain) is str: msg_list = [platform_message.Plain(message_chain)] else: - raise Exception( - 'Unknown message type: ' + str(message_chain) + str(type(message_chain)) - ) + raise Exception('Unknown message type: ' + str(message_chain) + str(type(message_chain))) nakuru_msg_list = [] @@ -63,9 +61,7 @@ class NakuruProjectMessageConverter(adapter_model.MessageConverter): # 遍历并转换 for yiri_forward_node in yiri_forward_node_list: try: - content_list = NakuruProjectMessageConverter.yiri2target( - yiri_forward_node.message_chain - ) + content_list = NakuruProjectMessageConverter.yiri2target(yiri_forward_node.message_chain) nakuru_forward_node = nkc.Node( name=yiri_forward_node.sender_name, uin=yiri_forward_node.sender_id, @@ -87,9 +83,7 @@ class NakuruProjectMessageConverter(adapter_model.MessageConverter): return nakuru_msg_list @staticmethod - def target2yiri( - message_chain: typing.Any, message_id: int = -1 - ) -> platform_message.MessageChain: + def target2yiri(message_chain: typing.Any, message_id: int = -1) -> platform_message.MessageChain: """将Yiri的消息链转换为YiriMirai的消息链""" assert type(message_chain) is list @@ -97,9 +91,7 @@ class NakuruProjectMessageConverter(adapter_model.MessageConverter): import datetime # 添加Source组件以标记message_id等信息 - yiri_msg_list.append( - platform_message.Source(id=message_id, time=datetime.datetime.now()) - ) + yiri_msg_list.append(platform_message.Source(id=message_id, time=datetime.datetime.now())) for component in message_chain: if type(component) is nkc.Plain: yiri_msg_list.append(platform_message.Plain(text=component.text)) @@ -130,9 +122,7 @@ class NakuruProjectEventConverter(adapter_model.EventConverter): @staticmethod def target2yiri(event: typing.Any) -> platform_events.Event: - yiri_chain = NakuruProjectMessageConverter.target2yiri( - event.message, event.message_id - ) + yiri_chain = NakuruProjectMessageConverter.target2yiri(event.message, event.message_id) if type(event) is nakuru.FriendMessage: # 私聊消息事件 return platform_events.FriendMessage( sender=platform_entities.Friend( @@ -206,9 +196,7 @@ class NakuruAdapter(adapter_model.MessagePlatformAdapter): ): task = None - converted_msg = ( - self.message_converter.yiri2target(message) if not converted else message - ) + converted_msg = self.message_converter.yiri2target(message) if not converted else message # 检查是否有转发消息 has_forward = False @@ -250,13 +238,9 @@ class NakuruAdapter(adapter_model.MessagePlatformAdapter): ), ) if type(message_source) is platform_events.GroupMessage: - await self.send_message( - 'group', message_source.sender.group.id, message, converted=True - ) + await self.send_message('group', message_source.sender.group.id, message, converted=True) elif type(message_source) is platform_events.FriendMessage: - await self.send_message( - 'person', message_source.sender.id, message, converted=True - ) + await self.send_message('person', message_source.sender.id, message, converted=True) else: raise Exception('Unknown message source type: ' + str(type(message_source))) @@ -264,17 +248,13 @@ class NakuruAdapter(adapter_model.MessagePlatformAdapter): import time # 检查是否被禁言 - group_member_info = asyncio.run( - self.bot.getGroupMemberInfo(group_id, self.bot_account_id) - ) + group_member_info = asyncio.run(self.bot.getGroupMemberInfo(group_id, self.bot_account_id)) return group_member_info.shut_up_timestamp > int(time.time()) def register_listener( self, event_type: typing.Type[platform_events.Event], - callback: typing.Callable[ - [platform_events.Event, adapter_model.MessagePlatformAdapter], None - ], + callback: typing.Callable[[platform_events.Event, adapter_model.MessagePlatformAdapter], None], ): try: source_cls = NakuruProjectEventConverter.yiri2target(event_type) @@ -301,9 +281,7 @@ class NakuruAdapter(adapter_model.MessagePlatformAdapter): def unregister_listener( self, event_type: typing.Type[platform_events.Event], - callback: typing.Callable[ - [platform_events.Event, adapter_model.MessagePlatformAdapter], None - ], + callback: typing.Callable[[platform_events.Event, adapter_model.MessagePlatformAdapter], None], ): nakuru_event_name = self.event_converter.yiri2target(event_type).__name__ @@ -312,10 +290,7 @@ class NakuruAdapter(adapter_model.MessagePlatformAdapter): # 从本对象的监听器列表中查找并删除 target_wrapper = None for listener in self.listener_list: - if ( - listener['event_type'] == event_type - and listener['callable'] == callback - ): + if listener['event_type'] == event_type and listener['callable'] == callback: target_wrapper = listener['wrapper'] self.listener_list.remove(listener) break @@ -334,14 +309,8 @@ class NakuruAdapter(adapter_model.MessagePlatformAdapter): import requests resp = requests.get( - url='http://{}:{}/get_login_info'.format( - self.cfg['host'], self.cfg['http_port'] - ), - headers={ - 'Authorization': 'Bearer ' + self.cfg['token'] - if 'token' in self.cfg - else '' - }, + url='http://{}:{}/get_login_info'.format(self.cfg['host'], self.cfg['http_port']), + headers={'Authorization': 'Bearer ' + self.cfg['token'] if 'token' in self.cfg else ''}, timeout=5, proxies=None, ) @@ -349,9 +318,7 @@ class NakuruAdapter(adapter_model.MessagePlatformAdapter): raise Exception('go-cqhttp拒绝访问,请检查配置文件中nakuru适配器的配置') self.bot_account_id = int(resp.json()['data']['user_id']) except Exception: - raise Exception( - '获取go-cqhttp账号信息失败, 请检查是否已启动go-cqhttp并配置正确' - ) + raise Exception('获取go-cqhttp账号信息失败, 请检查是否已启动go-cqhttp并配置正确') await self.bot._run() self.ap.logger.info('运行 Nakuru 适配器') while True: diff --git a/pkg/platform/sources/officialaccount.py b/pkg/platform/sources/officialaccount.py index 6e7eaf2f..8c7831a5 100644 --- a/pkg/platform/sources/officialaccount.py +++ b/pkg/platform/sources/officialaccount.py @@ -25,9 +25,7 @@ class OAMessageConverter(adapter.MessageConverter): @staticmethod async def target2yiri(message: str, message_id=-1): yiri_msg_list = [] - yiri_msg_list.append( - platform_message.Source(id=message_id, time=datetime.datetime.now()) - ) + yiri_msg_list.append(platform_message.Source(id=message_id, time=datetime.datetime.now())) yiri_msg_list.append(platform_message.Plain(text=message)) chain = platform_message.MessageChain(yiri_msg_list) @@ -39,9 +37,7 @@ class OAEventConverter(adapter.EventConverter): @staticmethod async def target2yiri(event: OAEvent): if event.type == 'text': - yiri_chain = await OAMessageConverter.target2yiri( - event.message, event.message_id - ) + yiri_chain = await OAMessageConverter.target2yiri(event.message, event.message_id) friend = platform_entities.Friend( id=event.user_id, @@ -81,9 +77,7 @@ class OfficialAccountAdapter(adapter.MessagePlatformAdapter): ] missing_keys = [key for key in required_keys if key not in config] if missing_keys: - raise ParamNotEnoughError( - '微信公众号缺少相关配置项,请查看文档或联系管理员' - ) + raise ParamNotEnoughError('微信公众号缺少相关配置项,请查看文档或联系管理员') if self.config['Mode'] == 'drop': self.bot = OAClient( @@ -114,28 +108,20 @@ class OfficialAccountAdapter(adapter.MessagePlatformAdapter): await self.bot.set_message(message_source.message_chain.message_id, content) elif isinstance(self.bot, OAClientForLongerResponse): from_user = message_source.sender.id - await self.bot.set_message( - from_user, message_source.message_chain.message_id, content - ) + await self.bot.set_message(from_user, message_source.message_chain.message_id, content) - async def send_message( - self, target_type: str, target_id: str, message: platform_message.MessageChain - ): + async def send_message(self, target_type: str, target_id: str, message: platform_message.MessageChain): pass def register_listener( self, event_type: type, - callback: typing.Callable[ - [platform_events.Event, MessagePlatformAdapter], None - ], + callback: typing.Callable[[platform_events.Event, MessagePlatformAdapter], None], ): async def on_message(event: OAEvent): self.bot_account_id = event.receiver_id try: - return await callback( - await self.event_converter.target2yiri(event), self - ) + return await callback(await self.event_converter.target2yiri(event), self) except Exception: traceback.print_exc() @@ -161,8 +147,6 @@ class OfficialAccountAdapter(adapter.MessagePlatformAdapter): async def unregister_listener( self, event_type: type, - callback: typing.Callable[ - [platform_events.Event, MessagePlatformAdapter], None - ], + callback: typing.Callable[[platform_events.Event, MessagePlatformAdapter], None], ): return super().unregister_listener(event_type, callback) diff --git a/pkg/platform/sources/qqbotpy.py b/pkg/platform/sources/qqbotpy.py index a91f86dd..74699961 100644 --- a/pkg/platform/sources/qqbotpy.py +++ b/pkg/platform/sources/qqbotpy.py @@ -147,9 +147,7 @@ class OfficialMessageConverter(adapter_model.MessageConverter): elif type(message_chain) is str: msg_list = [platform_message.Plain(text=message_chain)] else: - raise Exception( - 'Unknown message type: ' + str(message_chain) + str(type(message_chain)) - ) + raise Exception('Unknown message type: ' + str(message_chain) + str(type(message_chain))) offcial_messages: list[dict] = [] """ @@ -172,19 +170,13 @@ class OfficialMessageConverter(adapter_model.MessageConverter): if component.url is not None: offcial_messages.append({'type': 'image', 'content': component.url}) elif component.path is not None: - offcial_messages.append( - {'type': 'file_image', 'content': component.path} - ) + offcial_messages.append({'type': 'file_image', 'content': component.path}) elif type(component) is platform_message.At: offcial_messages.append({'type': 'at', 'content': ''}) elif type(component) is platform_message.AtAll: - print( - '上层组件要求发送 AtAll 消息,但 QQ 官方 API 不支持此消息类型,忽略此消息。' - ) + print('上层组件要求发送 AtAll 消息,但 QQ 官方 API 不支持此消息类型,忽略此消息。') elif type(component) is platform_message.Voice: - print( - '上层组件要求发送 Voice 消息,但 QQ 官方 API 不支持此消息类型,忽略此消息。' - ) + print('上层组件要求发送 Voice 消息,但 QQ 官方 API 不支持此消息类型,忽略此消息。') elif type(component) is forward.Forward: # 转发消息 yiri_forward_node_list = component.node_list @@ -195,9 +187,7 @@ class OfficialMessageConverter(adapter_model.MessageConverter): message_chain = yiri_forward_node.message_chain # 平铺 - offcial_messages.extend( - OfficialMessageConverter.yiri2target(message_chain) - ) + offcial_messages.extend(OfficialMessageConverter.yiri2target(message_chain)) except Exception: import traceback @@ -219,11 +209,7 @@ class OfficialMessageConverter(adapter_model.MessageConverter): yiri_msg_list = [] # 存id - yiri_msg_list.append( - platform_message.Source( - id=save_msg_id(message_id), time=datetime.datetime.now() - ) - ) + yiri_msg_list.append(platform_message.Source(id=save_msg_id(message_id), time=datetime.datetime.now())) if type(message) not in [botpy_message.DirectMessage, botpy_message.C2CMessage]: yiri_msg_list.append(platform_message.At(target=bot_account_id)) @@ -239,9 +225,7 @@ class OfficialMessageConverter(adapter_model.MessageConverter): if attachment.content_type.startswith('image'): yiri_msg_list.append(platform_message.Image(url=attachment.url)) else: - logging.warning( - '不支持的附件类型:' + attachment.content_type + ',忽略此附件。' - ) + logging.warning('不支持的附件类型:' + attachment.content_type + ',忽略此附件。') content = re.sub(r'<@!\d+>', '', str(message.content)) if content.strip() != '': @@ -264,9 +248,7 @@ class OfficialEventConverter(adapter_model.EventConverter): elif event == platform_events.FriendMessage: return botpy_message.DirectMessage else: - raise Exception( - '未支持转换的事件类型(YiriMirai -> Official): ' + str(event) - ) + raise Exception('未支持转换的事件类型(YiriMirai -> Official): ' + str(event)) def target2yiri( self, @@ -297,21 +279,13 @@ class OfficialEventConverter(adapter_model.EventConverter): ), special_title='', join_timestamp=int( - datetime.datetime.strptime( - event.member.joined_at, '%Y-%m-%dT%H:%M:%S%z' - ).timestamp() + datetime.datetime.strptime(event.member.joined_at, '%Y-%m-%dT%H:%M:%S%z').timestamp() ), last_speak_timestamp=datetime.datetime.now().timestamp(), mute_time_remaining=0, ), - message_chain=OfficialMessageConverter.extract_message_chain_from_obj( - event, event.id - ), - time=int( - datetime.datetime.strptime( - event.timestamp, '%Y-%m-%dT%H:%M:%S%z' - ).timestamp() - ), + message_chain=OfficialMessageConverter.extract_message_chain_from_obj(event, event.id), + time=int(datetime.datetime.strptime(event.timestamp, '%Y-%m-%dT%H:%M:%S%z').timestamp()), ) elif isinstance(event, botpy_message.DirectMessage): # 频道私聊,转私聊事件 return platform_events.FriendMessage( @@ -320,14 +294,8 @@ class OfficialEventConverter(adapter_model.EventConverter): nickname=event.author.username, remark=event.author.username, ), - message_chain=OfficialMessageConverter.extract_message_chain_from_obj( - event, event.id - ), - time=int( - datetime.datetime.strptime( - event.timestamp, '%Y-%m-%dT%H:%M:%S%z' - ).timestamp() - ), + message_chain=OfficialMessageConverter.extract_message_chain_from_obj(event, event.id), + time=int(datetime.datetime.strptime(event.timestamp, '%Y-%m-%dT%H:%M:%S%z').timestamp()), ) elif isinstance(event, botpy_message.GroupMessage): # 群聊,转群聊事件 author_member_id = event.author.member_openid @@ -347,14 +315,8 @@ class OfficialEventConverter(adapter_model.EventConverter): last_speak_timestamp=datetime.datetime.now().timestamp(), mute_time_remaining=0, ), - message_chain=OfficialMessageConverter.extract_message_chain_from_obj( - event, event.id - ), - time=int( - datetime.datetime.strptime( - event.timestamp, '%Y-%m-%dT%H:%M:%S%z' - ).timestamp() - ), + message_chain=OfficialMessageConverter.extract_message_chain_from_obj(event, event.id), + time=int(datetime.datetime.strptime(event.timestamp, '%Y-%m-%dT%H:%M:%S%z').timestamp()), ) elif isinstance(event, botpy_message.C2CMessage): # 私聊,转私聊事件 user_id_alter = event.author.user_openid @@ -365,14 +327,8 @@ class OfficialEventConverter(adapter_model.EventConverter): nickname=user_id_alter, remark=user_id_alter, ), - message_chain=OfficialMessageConverter.extract_message_chain_from_obj( - event, event.id - ), - time=int( - datetime.datetime.strptime( - event.timestamp, '%Y-%m-%dT%H:%M:%S%z' - ).timestamp() - ), + message_chain=OfficialMessageConverter.extract_message_chain_from_obj(event, event.id), + time=int(datetime.datetime.strptime(event.timestamp, '%Y-%m-%dT%H:%M:%S%z').timestamp()), ) @@ -420,9 +376,7 @@ class OfficialAdapter(adapter_model.MessagePlatformAdapter): self.bot = botpy.Client(intents=intents) - async def send_message( - self, target_type: str, target_id: str, message: platform_message.MessageChain - ): + async def send_message(self, target_type: str, target_id: str, message: platform_message.MessageChain): message_list = self.message_converter.yiri2target(message) for msg in message_list: @@ -468,22 +422,16 @@ class OfficialAdapter(adapter_model.MessagePlatformAdapter): if quote_origin: args['message_reference'] = botpy_message_type.Reference( - message_id=cached_message_ids[ - str(message_source.message_chain.message_id) - ] + message_id=cached_message_ids[str(message_source.message_chain.message_id)] ) if isinstance(message_source, platform_events.GroupMessage): args['channel_id'] = str(message_source.sender.group.id) - args['msg_id'] = cached_message_ids[ - str(message_source.message_chain.message_id) - ] + args['msg_id'] = cached_message_ids[str(message_source.message_chain.message_id)] await self.bot.api.post_message(**args) elif isinstance(message_source, platform_events.FriendMessage): args['guild_id'] = str(message_source.sender.id) - args['msg_id'] = cached_message_ids[ - str(message_source.message_chain.message_id) - ] + args['msg_id'] = cached_message_ids[str(message_source.message_chain.message_id)] await self.bot.api.post_dms(**args) elif isinstance(message_source, OfficialGroupMessage): if 'file_image' in args: # 暂不支持发送文件图片 @@ -502,9 +450,7 @@ class OfficialAdapter(adapter_model.MessagePlatformAdapter): args['media'] = uploadMedia args['msg_type'] = 7 - args['msg_id'] = cached_message_ids[ - str(message_source.message_chain.message_id) - ] + args['msg_id'] = cached_message_ids[str(message_source.message_chain.message_id)] args['msg_seq'] = self.group_msg_seq self.group_msg_seq += 1 @@ -523,9 +469,7 @@ class OfficialAdapter(adapter_model.MessagePlatformAdapter): args['media'] = uploadMedia args['msg_type'] = 7 - args['msg_id'] = cached_message_ids[ - str(message_source.message_chain.message_id) - ] + args['msg_id'] = cached_message_ids[str(message_source.message_chain.message_id)] args['msg_seq'] = self.c2c_msg_seq self.c2c_msg_seq += 1 @@ -538,9 +482,7 @@ class OfficialAdapter(adapter_model.MessagePlatformAdapter): def register_listener( self, event_type: typing.Type[platform_events.Event], - callback: typing.Callable[ - [platform_events.Event, adapter_model.MessagePlatformAdapter], None - ], + callback: typing.Callable[[platform_events.Event, adapter_model.MessagePlatformAdapter], None], ): try: @@ -563,9 +505,7 @@ class OfficialAdapter(adapter_model.MessagePlatformAdapter): def unregister_listener( self, event_type: typing.Type[platform_events.Event], - callback: typing.Callable[ - [platform_events.Event, adapter_model.MessagePlatformAdapter], None - ], + callback: typing.Callable[[platform_events.Event, adapter_model.MessagePlatformAdapter], None], ): delattr(self.bot, event_handler_mapping[event_type]) diff --git a/pkg/platform/sources/qqofficial.py b/pkg/platform/sources/qqofficial.py index 06893485..f9795bcd 100644 --- a/pkg/platform/sources/qqofficial.py +++ b/pkg/platform/sources/qqofficial.py @@ -35,13 +35,9 @@ class QQOfficialMessageConverter(adapter.MessageConverter): @staticmethod async def target2yiri(message: str, message_id: str, pic_url: str, content_type): yiri_msg_list = [] - yiri_msg_list.append( - platform_message.Source(id=message_id, time=datetime.datetime.now()) - ) + yiri_msg_list.append(platform_message.Source(id=message_id, time=datetime.datetime.now())) if pic_url is not None: - base64_url = await image.get_qq_official_image_base64( - pic_url=pic_url, content_type=content_type - ) + base64_url = await image.get_qq_official_image_base64(pic_url=pic_url, content_type=content_type) yiri_msg_list.append(platform_message.Image(base64=base64_url)) yiri_msg_list.append(platform_message.Plain(text=message)) @@ -75,11 +71,7 @@ class QQOfficialEventConverter(adapter.EventConverter): return platform_events.FriendMessage( sender=friend, message_chain=yiri_chain, - time=int( - datetime.datetime.strptime( - event.timestamp, '%Y-%m-%dT%H:%M:%S%z' - ).timestamp() - ), + time=int(datetime.datetime.strptime(event.timestamp, '%Y-%m-%dT%H:%M:%S%z').timestamp()), source_platform_object=event, ) @@ -89,9 +81,7 @@ class QQOfficialEventConverter(adapter.EventConverter): nickname=event.t, remark='', ) - return platform_events.FriendMessage( - sender=friend, message_chain=yiri_chain, source_platform_object=event - ) + return platform_events.FriendMessage(sender=friend, message_chain=yiri_chain, source_platform_object=event) if event.t == 'GROUP_AT_MESSAGE_CREATE': yiri_chain.insert(0, platform_message.At(target='justbot')) @@ -109,11 +99,7 @@ class QQOfficialEventConverter(adapter.EventConverter): last_speak_timestamp=0, mute_time_remaining=0, ) - time = int( - datetime.datetime.strptime( - event.timestamp, '%Y-%m-%dT%H:%M:%S%z' - ).timestamp() - ) + time = int(datetime.datetime.strptime(event.timestamp, '%Y-%m-%dT%H:%M:%S%z').timestamp()) return platform_events.GroupMessage( sender=sender, message_chain=yiri_chain, @@ -136,11 +122,7 @@ class QQOfficialEventConverter(adapter.EventConverter): last_speak_timestamp=0, mute_time_remaining=0, ) - time = int( - datetime.datetime.strptime( - event.timestamp, '%Y-%m-%dT%H:%M:%S%z' - ).timestamp() - ) + time = int(datetime.datetime.strptime(event.timestamp, '%Y-%m-%dT%H:%M:%S%z').timestamp()) return platform_events.GroupMessage( sender=sender, message_chain=yiri_chain, @@ -167,9 +149,7 @@ class QQOfficialAdapter(adapter.MessagePlatformAdapter): ] missing_keys = [key for key in required_keys if key not in config] if missing_keys: - raise ParamNotEnoughError( - 'QQ官方机器人缺少相关配置项,请查看文档或联系管理员' - ) + raise ParamNotEnoughError('QQ官方机器人缺少相关配置项,请查看文档或联系管理员') self.bot = QQOfficialClient( app_id=config['appid'], @@ -229,24 +209,18 @@ class QQOfficialAdapter(adapter.MessagePlatformAdapter): qq_official_event.d_id, ) - async def send_message( - self, target_type: str, target_id: str, message: platform_message.MessageChain - ): + async def send_message(self, target_type: str, target_id: str, message: platform_message.MessageChain): pass def register_listener( self, event_type: typing.Type[platform_events.Event], - callback: typing.Callable[ - [platform_events.Event, adapter.MessagePlatformAdapter], None - ], + callback: typing.Callable[[platform_events.Event, adapter.MessagePlatformAdapter], None], ): async def on_message(event: QQOfficialEvent): self.bot_account_id = 'justbot' try: - return await callback( - await self.event_converter.target2yiri(event), self - ) + return await callback(await self.event_converter.target2yiri(event), self) except Exception: traceback.print_exc() @@ -274,8 +248,6 @@ class QQOfficialAdapter(adapter.MessagePlatformAdapter): def unregister_listener( self, event_type: type, - callback: typing.Callable[ - [platform_events.Event, MessagePlatformAdapter], None - ], + callback: typing.Callable[[platform_events.Event, MessagePlatformAdapter], None], ): return super().unregister_listener(event_type, callback) diff --git a/pkg/platform/sources/slack.py b/pkg/platform/sources/slack.py index bc4e4d8e..62ef4137 100644 --- a/pkg/platform/sources/slack.py +++ b/pkg/platform/sources/slack.py @@ -11,37 +11,32 @@ from pkg.platform.types import events as platform_events, message as platform_me from libs.slack_api.slackevent import SlackEvent from pkg.core import app from .. import adapter -from ...core import app -from ..types import message as platform_message -from ..types import events as platform_events from ..types import entities as platform_entities from ...command.errors import ParamNotEnoughError from ...utils import image + class SlackMessageConverter(adapter.MessageConverter): - @staticmethod - async def yiri2target(message_chain:platform_message.MessageChain): + async def yiri2target(message_chain: platform_message.MessageChain): content_list = [] for msg in message_chain: if type(msg) is platform_message.Plain: - content_list.append({ - "content":msg.text, - }) - + content_list.append( + { + 'content': msg.text, + } + ) + return content_list @staticmethod - async def target2yiri(message:str,message_id:str,pic_url:str,bot:SlackClient): + async def target2yiri(message: str, message_id: str, pic_url: str, bot: SlackClient): yiri_msg_list = [] - yiri_msg_list.append( - platform_message.Source(id=message_id,time=datetime.datetime.now()) - ) + yiri_msg_list.append(platform_message.Source(id=message_id, time=datetime.datetime.now())) if pic_url is not None: - base64_url = await image.get_slack_image_to_base64(pic_url=pic_url,bot_token=bot.bot_token) - yiri_msg_list.append( - platform_message.Image(base64=base64_url) - ) + base64_url = await image.get_slack_image_to_base64(pic_url=pic_url, bot_token=bot.bot_token) + yiri_msg_list.append(platform_message.Image(base64=base64_url)) yiri_msg_list.append(platform_message.Plain(text=message)) chain = platform_message.MessageChain(yiri_msg_list) @@ -49,55 +44,43 @@ class SlackMessageConverter(adapter.MessageConverter): class SlackEventConverter(adapter.EventConverter): - @staticmethod - async def yiri2target(event:platform_events.MessageEvent) -> SlackEvent: + async def yiri2target(event: platform_events.MessageEvent) -> SlackEvent: return event.source_platform_object - + @staticmethod - async def target2yiri(event:SlackEvent,bot:SlackClient): + async def target2yiri(event: SlackEvent, bot: SlackClient): yiri_chain = await SlackMessageConverter.target2yiri( - message=event.text,message_id=event.message_id,pic_url=event.pic_url,bot=bot + message=event.text, message_id=event.message_id, pic_url=event.pic_url, bot=bot ) if event.type == 'channel': - yiri_chain.insert(0, platform_message.At(target="SlackBot")) + yiri_chain.insert(0, platform_message.At(target='SlackBot')) sender = platform_entities.GroupMember( - id = event.user_id, - member_name= str(event.sender_name), - permission= 'MEMBER', - group = platform_entities.Group( - id = event.channel_id, - name = 'MEMBER', - permission= platform_entities.Permission.Member + id=event.user_id, + member_name=str(event.sender_name), + permission='MEMBER', + group=platform_entities.Group( + id=event.channel_id, name='MEMBER', permission=platform_entities.Permission.Member ), special_title='', join_timestamp=0, last_speak_timestamp=0, - mute_time_remaining=0 + mute_time_remaining=0, ) time = int(datetime.datetime.utcnow().timestamp()) return platform_events.GroupMessage( - sender = sender, - message_chain=yiri_chain, - time = time, - source_platform_object=event + sender=sender, message_chain=yiri_chain, time=time, source_platform_object=event ) if event.type == 'im': return platform_events.FriendMessage( - sender=platform_entities.Friend( - id=event.user_id, - nickname = event.sender_name, - remark="" - ), - message_chain = yiri_chain, - time = float(datetime.datetime.now().timestamp()), + sender=platform_entities.Friend(id=event.user_id, nickname=event.sender_name, remark=''), + message_chain=yiri_chain, + time=float(datetime.datetime.now().timestamp()), source_platform_object=event, ) - - class SlackAdapter(adapter.MessagePlatformAdapter): @@ -108,21 +91,18 @@ class SlackAdapter(adapter.MessagePlatformAdapter): event_converter: SlackEventConverter = SlackEventConverter() config: dict - def __init__(self,config:dict,ap:app.Application): + def __init__(self, config: dict, ap: app.Application): self.config = config self.ap = ap required_keys = [ - "bot_token", - "signing_secret", + 'bot_token', + 'signing_secret', ] missing_keys = [key for key in required_keys if key not in config] if missing_keys: - raise ParamNotEnoughError("Slack机器人缺少相关配置项,请查看文档或联系管理员") + raise ParamNotEnoughError('Slack机器人缺少相关配置项,请查看文档或联系管理员') - self.bot = SlackClient( - bot_token=self.config["bot_token"], - signing_secret=self.config["signing_secret"] - ) + self.bot = SlackClient(bot_token=self.config['bot_token'], signing_secret=self.config['signing_secret']) async def reply_message( self, @@ -130,52 +110,40 @@ class SlackAdapter(adapter.MessagePlatformAdapter): message: platform_message.MessageChain, quote_origin: bool = False, ): - slack_event = await SlackEventConverter.yiri2target( - message_source - ) + slack_event = await SlackEventConverter.yiri2target(message_source) - content_list = await SlackMessageConverter.yiri2target(message) + content_list = await SlackMessageConverter.yiri2target(message) for content in content_list: if slack_event.type == 'channel': - await self.bot.send_message_to_channel( - content['content'],slack_event.channel_id - ) + await self.bot.send_message_to_channel(content['content'], slack_event.channel_id) if slack_event.type == 'im': - await self.bot.send_message_to_one( - content['content'],slack_event.user_id - ) - + await self.bot.send_message_to_one(content['content'], slack_event.user_id) + async def send_message(self, target_type: str, target_id: str, message: platform_message.MessageChain): content_list = await SlackMessageConverter.yiri2target(message) for content in content_list: if target_type == 'person': - await self.bot.send_message_to_one(content['content'],target_id) + await self.bot.send_message_to_one(content['content'], target_id) if target_type == 'group': - await self.bot.send_message_to_channel(content['content'],target_id) - + await self.bot.send_message_to_channel(content['content'], target_id) def register_listener( self, event_type: typing.Type[platform_events.Event], - callback: typing.Callable[ - [platform_events.Event, adapter.MessagePlatformAdapter], None - ], + callback: typing.Callable[[platform_events.Event, adapter.MessagePlatformAdapter], None], ): - async def on_message(event:SlackEvent): + async def on_message(event: SlackEvent): self.bot_account_id = 'SlackBot' try: - return await callback( - await self.event_converter.target2yiri(event,self.bot),self - ) + return await callback(await self.event_converter.target2yiri(event, self.bot), self) except: traceback.print_exc() - - if event_type == platform_events.FriendMessage: - self.bot.on_message("im")(on_message) - elif event_type == platform_events.GroupMessage: - self.bot.on_message("channel")(on_message) + if event_type == platform_events.FriendMessage: + self.bot.on_message('im')(on_message) + elif event_type == platform_events.GroupMessage: + self.bot.on_message('channel')(on_message) async def run_async(self): async def shutdown_trigger_placeholder(): @@ -183,8 +151,8 @@ class SlackAdapter(adapter.MessagePlatformAdapter): await asyncio.sleep(1) await self.bot.run_task( - host="0.0.0.0", - port=self.config["port"], + host='0.0.0.0', + port=self.config['port'], shutdown_trigger=shutdown_trigger_placeholder, ) @@ -197,8 +165,3 @@ class SlackAdapter(adapter.MessagePlatformAdapter): callback: typing.Callable[[platform_events.Event, MessagePlatformAdapter], None], ): return super().unregister_listener(event_type, callback) - - - - - diff --git a/pkg/platform/sources/telegram.py b/pkg/platform/sources/telegram.py index 584d77f3..5d318cbb 100644 --- a/pkg/platform/sources/telegram.py +++ b/pkg/platform/sources/telegram.py @@ -21,9 +21,7 @@ from ..types import entities as platform_entities class TelegramMessageConverter(adapter.MessageConverter): @staticmethod - async def yiri2target( - message_chain: platform_message.MessageChain, bot: telegram.Bot - ) -> list[dict]: + async def yiri2target(message_chain: platform_message.MessageChain, bot: telegram.Bot) -> list[dict]: components = [] for component in message_chain: @@ -45,18 +43,12 @@ class TelegramMessageConverter(adapter.MessageConverter): components.append({'type': 'photo', 'photo': photo_bytes}) elif isinstance(component, platform_message.Forward): for node in component.node_list: - components.extend( - await TelegramMessageConverter.yiri2target( - node.message_chain, bot - ) - ) + components.extend(await TelegramMessageConverter.yiri2target(node.message_chain, bot)) return components @staticmethod - async def target2yiri( - message: telegram.Message, bot: telegram.Bot, bot_account_id: str - ): + async def target2yiri(message: telegram.Message, bot: telegram.Bot, bot_account_id: str): message_components = [] def parse_message_text(text: str) -> list[platform_message.MessageComponent]: @@ -103,9 +95,7 @@ class TelegramEventConverter(adapter.EventConverter): @staticmethod async def target2yiri(event: Update, bot: telegram.Bot, bot_account_id: str): - lb_message = await TelegramMessageConverter.target2yiri( - event.message, bot, bot_account_id - ) + lb_message = await TelegramMessageConverter.target2yiri(event.message, bot, bot_account_id) if event.effective_chat.type == 'private': return platform_events.FriendMessage( @@ -166,9 +156,7 @@ class TelegramAdapter(adapter.MessagePlatformAdapter): return try: - lb_event = await self.event_converter.target2yiri( - update, self.bot, self.bot_account_id - ) + lb_event = await self.event_converter.target2yiri(update, self.bot, self.bot_account_id) await self.listeners[type(lb_event)](lb_event, self) except Exception: print(traceback.format_exc()) @@ -176,14 +164,10 @@ class TelegramAdapter(adapter.MessagePlatformAdapter): self.application = ApplicationBuilder().token(self.config['token']).build() self.bot = self.application.bot self.application.add_handler( - MessageHandler( - filters.TEXT | (filters.COMMAND) | filters.PHOTO, telegram_callback - ) + MessageHandler(filters.TEXT | (filters.COMMAND) | filters.PHOTO, telegram_callback) ) - async def send_message( - self, target_type: str, target_id: str, message: platform_message.MessageChain - ): + async def send_message(self, target_type: str, target_id: str, message: platform_message.MessageChain): pass async def reply_message( @@ -210,9 +194,7 @@ class TelegramAdapter(adapter.MessagePlatformAdapter): if self.config['markdown_card'] is True: args['parse_mode'] = 'MarkdownV2' if quote_origin: - args['reply_to_message_id'] = ( - message_source.source_platform_object.message.id - ) + args['reply_to_message_id'] = message_source.source_platform_object.message.id await self.bot.send_message(**args) @@ -222,18 +204,14 @@ class TelegramAdapter(adapter.MessagePlatformAdapter): def register_listener( self, event_type: typing.Type[platform_events.Event], - callback: typing.Callable[ - [platform_events.Event, adapter.MessagePlatformAdapter], None - ], + callback: typing.Callable[[platform_events.Event, adapter.MessagePlatformAdapter], None], ): self.listeners[event_type] = callback def unregister_listener( self, event_type: typing.Type[platform_events.Event], - callback: typing.Callable[ - [platform_events.Event, adapter.MessagePlatformAdapter], None - ], + callback: typing.Callable[[platform_events.Event, adapter.MessagePlatformAdapter], None], ): self.listeners.pop(event_type) diff --git a/pkg/platform/sources/wecom.py b/pkg/platform/sources/wecom.py index 53878062..5c02a632 100644 --- a/pkg/platform/sources/wecom.py +++ b/pkg/platform/sources/wecom.py @@ -18,9 +18,7 @@ from ...utils import image class WecomMessageConverter(adapter.MessageConverter): @staticmethod - async def yiri2target( - message_chain: platform_message.MessageChain, bot: WecomClient - ): + async def yiri2target(message_chain: platform_message.MessageChain, bot: WecomClient): content_list = [] for msg in message_chain: @@ -40,13 +38,7 @@ class WecomMessageConverter(adapter.MessageConverter): ) elif type(msg) is platform_message.Forward: for node in msg.node_list: - content_list.extend( - ( - await WecomMessageConverter.yiri2target( - node.message_chain, bot - ) - ) - ) + content_list.extend((await WecomMessageConverter.yiri2target(node.message_chain, bot))) else: content_list.append( { @@ -60,9 +52,7 @@ class WecomMessageConverter(adapter.MessageConverter): @staticmethod async def target2yiri(message: str, message_id: int = -1): yiri_msg_list = [] - yiri_msg_list.append( - platform_message.Source(id=message_id, time=datetime.datetime.now()) - ) + yiri_msg_list.append(platform_message.Source(id=message_id, time=datetime.datetime.now())) yiri_msg_list.append(platform_message.Plain(text=message)) chain = platform_message.MessageChain(yiri_msg_list) @@ -72,15 +62,9 @@ class WecomMessageConverter(adapter.MessageConverter): @staticmethod async def target2yiri_image(picurl: str, message_id: int = -1): yiri_msg_list = [] - yiri_msg_list.append( - platform_message.Source(id=message_id, time=datetime.datetime.now()) - ) + yiri_msg_list.append(platform_message.Source(id=message_id, time=datetime.datetime.now())) image_base64, image_format = await image.get_wecom_image_base64(pic_url=picurl) - yiri_msg_list.append( - platform_message.Image( - base64=f'data:image/{image_format};base64,{image_base64}' - ) - ) + yiri_msg_list.append(platform_message.Image(base64=f'data:image/{image_format};base64,{image_base64}')) chain = platform_message.MessageChain(yiri_msg_list) return chain @@ -88,9 +72,7 @@ class WecomMessageConverter(adapter.MessageConverter): class WecomEventConverter: @staticmethod - async def yiri2target( - event: platform_events.Event, bot_account_id: int, bot: WecomClient - ) -> WecomEvent: + async def yiri2target(event: platform_events.Event, bot_account_id: int, bot: WecomClient) -> WecomEvent: # only for extracting user information if type(event) is platform_events.GroupMessage: @@ -124,18 +106,14 @@ class WecomEventConverter: """ # 转换消息链 if event.type == 'text': - yiri_chain = await WecomMessageConverter.target2yiri( - event.message, event.message_id - ) + yiri_chain = await WecomMessageConverter.target2yiri(event.message, event.message_id) friend = platform_entities.Friend( id=f'u{event.user_id}', nickname=str(event.agent_id), remark='', ) - return platform_events.FriendMessage( - sender=friend, message_chain=yiri_chain, time=event.timestamp - ) + return platform_events.FriendMessage(sender=friend, message_chain=yiri_chain, time=event.timestamp) elif event.type == 'image': friend = platform_entities.Friend( id=f'u{event.user_id}', @@ -143,13 +121,9 @@ class WecomEventConverter: remark='', ) - yiri_chain = await WecomMessageConverter.target2yiri_image( - picurl=event.picurl, message_id=event.message_id - ) + yiri_chain = await WecomMessageConverter.target2yiri_image(picurl=event.picurl, message_id=event.message_id) - return platform_events.FriendMessage( - sender=friend, message_chain=yiri_chain, time=event.timestamp - ) + return platform_events.FriendMessage(sender=friend, message_chain=yiri_chain, time=event.timestamp) class WecomAdapter(adapter.MessagePlatformAdapter): @@ -190,26 +164,18 @@ class WecomAdapter(adapter.MessagePlatformAdapter): message: platform_message.MessageChain, quote_origin: bool = False, ): - Wecom_event = await WecomEventConverter.yiri2target( - message_source, self.bot_account_id, self.bot - ) + Wecom_event = await WecomEventConverter.yiri2target(message_source, self.bot_account_id, self.bot) content_list = await WecomMessageConverter.yiri2target(message, self.bot) fixed_user_id = Wecom_event.user_id # 删掉开头的u fixed_user_id = fixed_user_id[1:] for content in content_list: if content['type'] == 'text': - await self.bot.send_private_msg( - fixed_user_id, Wecom_event.agent_id, content['content'] - ) + await self.bot.send_private_msg(fixed_user_id, Wecom_event.agent_id, content['content']) elif content['type'] == 'image': - await self.bot.send_image( - fixed_user_id, Wecom_event.agent_id, content['media_id'] - ) + await self.bot.send_image(fixed_user_id, Wecom_event.agent_id, content['media_id']) - async def send_message( - self, target_type: str, target_id: str, message: platform_message.MessageChain - ): + async def send_message(self, target_type: str, target_id: str, message: platform_message.MessageChain): """企业微信目前只有发送给个人的方法, 构造target_id的方式为前半部分为账户id,后半部分为agent_id,中间使用“|”符号隔开。 """ @@ -220,25 +186,19 @@ class WecomAdapter(adapter.MessagePlatformAdapter): if target_type == 'person': for content in content_list: if content['type'] == 'text': - await self.bot.send_private_msg( - user_id, agent_id, content['content'] - ) + await self.bot.send_private_msg(user_id, agent_id, content['content']) if content['type'] == 'image': await self.bot.send_image(user_id, agent_id, content['media']) def register_listener( self, event_type: typing.Type[platform_events.Event], - callback: typing.Callable[ - [platform_events.Event, adapter.MessagePlatformAdapter], None - ], + callback: typing.Callable[[platform_events.Event, adapter.MessagePlatformAdapter], None], ): async def on_message(event: WecomEvent): self.bot_account_id = event.receiver_id try: - return await callback( - await self.event_converter.target2yiri(event), self - ) + return await callback(await self.event_converter.target2yiri(event), self) except Exception: traceback.print_exc() @@ -265,8 +225,6 @@ class WecomAdapter(adapter.MessagePlatformAdapter): async def unregister_listener( self, event_type: type, - callback: typing.Callable[ - [platform_events.Event, MessagePlatformAdapter], None - ], + callback: typing.Callable[[platform_events.Event, MessagePlatformAdapter], None], ): return super().unregister_listener(event_type, callback) diff --git a/pkg/platform/sources/wecomcs.py b/pkg/platform/sources/wecomcs.py index 532d7470..94d0e450 100644 --- a/pkg/platform/sources/wecomcs.py +++ b/pkg/platform/sources/wecomcs.py @@ -11,49 +11,47 @@ from pkg.platform.types import events as platform_events, message as platform_me from libs.wecom_customer_service_api.wecomcsevent import WecomCSEvent from pkg.core import app from .. import adapter -from ...core import app -from ..types import message as platform_message -from ..types import events as platform_events from ..types import entities as platform_entities from ...command.errors import ParamNotEnoughError -from ...utils import image + class WecomMessageConverter(adapter.MessageConverter): - @staticmethod - async def yiri2target( - message_chain: platform_message.MessageChain, bot: WecomCSClient - ): + async def yiri2target(message_chain: platform_message.MessageChain, bot: WecomCSClient): content_list = [] for msg in message_chain: if type(msg) is platform_message.Plain: - content_list.append({ - "type": "text", - "content": msg.text, - }) + content_list.append( + { + 'type': 'text', + 'content': msg.text, + } + ) elif type(msg) is platform_message.Image: - content_list.append({ - "type": "image", - "media_id": await bot.get_media_id(msg), - }) + content_list.append( + { + 'type': 'image', + 'media_id': await bot.get_media_id(msg), + } + ) elif type(msg) is platform_message.Forward: for node in msg.node_list: content_list.extend((await WecomMessageConverter.yiri2target(node.message_chain, bot))) else: - content_list.append({ - "type": "text", - "content": str(msg), - }) + content_list.append( + { + 'type': 'text', + 'content': str(msg), + } + ) return content_list @staticmethod async def target2yiri(message: str, message_id: int = -1): yiri_msg_list = [] - yiri_msg_list.append( - platform_message.Source(id=message_id, time=datetime.datetime.now()) - ) + yiri_msg_list.append(platform_message.Source(id=message_id, time=datetime.datetime.now())) yiri_msg_list.append(platform_message.Plain(text=message)) chain = platform_message.MessageChain(yiri_msg_list) @@ -63,21 +61,16 @@ class WecomMessageConverter(adapter.MessageConverter): @staticmethod async def target2yiri_image(picurl: str, message_id: int = -1): yiri_msg_list = [] - yiri_msg_list.append( - platform_message.Source(id=message_id, time=datetime.datetime.now()) - ) + yiri_msg_list.append(platform_message.Source(id=message_id, time=datetime.datetime.now())) yiri_msg_list.append(platform_message.Image(base64=picurl)) chain = platform_message.MessageChain(yiri_msg_list) - + return chain class WecomEventConverter: - @staticmethod - async def yiri2target( - event: platform_events.Event, bot_account_id: int, bot: WecomCSClient - ) -> WecomCSEvent: + async def yiri2target(event: platform_events.Event, bot_account_id: int, bot: WecomCSClient) -> WecomCSEvent: # only for extracting user information if type(event) is platform_events.GroupMessage: @@ -98,29 +91,25 @@ class WecomEventConverter: platform_events.FriendMessage: 转换后的 FriendMessage 对象。 """ # 转换消息链 - if event.type == "text": - yiri_chain = await WecomMessageConverter.target2yiri( - event.message, event.message_id - ) + if event.type == 'text': + yiri_chain = await WecomMessageConverter.target2yiri(event.message, event.message_id) friend = platform_entities.Friend( - id=f"u{event.user_id}", + id=f'u{event.user_id}', nickname=str(event.user_id), - remark="", + remark='', ) return platform_events.FriendMessage( sender=friend, message_chain=yiri_chain, time=event.timestamp, source_platform_object=event ) - elif event.type == "image": + elif event.type == 'image': friend = platform_entities.Friend( - id=f"u{event.user_id}", + id=f'u{event.user_id}', nickname=str(event.user_id), - remark="", + remark='', ) - yiri_chain = await WecomMessageConverter.target2yiri_image( - picurl=event.picurl, message_id=event.message_id - ) + yiri_chain = await WecomMessageConverter.target2yiri_image(picurl=event.picurl, message_id=event.message_id) return platform_events.FriendMessage( sender=friend, message_chain=yiri_chain, time=event.timestamp, source_platform_object=event @@ -128,7 +117,6 @@ class WecomEventConverter: class WecomCSAdapter(adapter.MessagePlatformAdapter): - bot: WecomCSClient ap: app.Application bot_account_id: str @@ -142,20 +130,20 @@ class WecomCSAdapter(adapter.MessagePlatformAdapter): self.ap = ap required_keys = [ - "corpid", - "secret", - "token", - "EncodingAESKey", + 'corpid', + 'secret', + 'token', + 'EncodingAESKey', ] missing_keys = [key for key in required_keys if key not in config] if missing_keys: - raise ParamNotEnoughError("企业微信客服缺少相关配置项,请查看文档或联系管理员") + raise ParamNotEnoughError('企业微信客服缺少相关配置项,请查看文档或联系管理员') self.bot = WecomCSClient( - corpid=config["corpid"], - secret=config["secret"], - token=config["token"], - EncodingAESKey=config["EncodingAESKey"], + corpid=config['corpid'], + secret=config['secret'], + token=config['token'], + EncodingAESKey=config['EncodingAESKey'], ) async def reply_message( @@ -164,40 +152,36 @@ class WecomCSAdapter(adapter.MessagePlatformAdapter): message: platform_message.MessageChain, quote_origin: bool = False, ): - - Wecom_event = await WecomEventConverter.yiri2target( - message_source, self.bot_account_id, self.bot - ) + Wecom_event = await WecomEventConverter.yiri2target(message_source, self.bot_account_id, self.bot) content_list = await WecomMessageConverter.yiri2target(message, self.bot) - + for content in content_list: - if content["type"] == "text": - await self.bot.send_text_msg(open_kfid=Wecom_event.receiver_id,external_userid=Wecom_event.user_id,msgid=Wecom_event.message_id,content=content["content"]) - - async def send_message( - self, target_type: str, target_id: str, message: platform_message.MessageChain - ): + if content['type'] == 'text': + await self.bot.send_text_msg( + open_kfid=Wecom_event.receiver_id, + external_userid=Wecom_event.user_id, + msgid=Wecom_event.message_id, + content=content['content'], + ) + + async def send_message(self, target_type: str, target_id: str, message: platform_message.MessageChain): pass def register_listener( self, event_type: typing.Type[platform_events.Event], - callback: typing.Callable[ - [platform_events.Event, adapter.MessagePlatformAdapter], None - ], + callback: typing.Callable[[platform_events.Event, adapter.MessagePlatformAdapter], None], ): async def on_message(event: WecomCSEvent): self.bot_account_id = event.receiver_id try: - return await callback( - await self.event_converter.target2yiri(event), self - ) + return await callback(await self.event_converter.target2yiri(event), self) except: traceback.print_exc() if event_type == platform_events.FriendMessage: - self.bot.on_message("text")(on_message) - self.bot.on_message("image")(on_message) + self.bot.on_message('text')(on_message) + self.bot.on_message('image')(on_message) elif event_type == platform_events.GroupMessage: pass @@ -207,8 +191,8 @@ class WecomCSAdapter(adapter.MessagePlatformAdapter): await asyncio.sleep(1) await self.bot.run_task( - host="0.0.0.0", - port=self.config["port"], + host='0.0.0.0', + port=self.config['port'], shutdown_trigger=shutdown_trigger_placeholder, ) @@ -220,4 +204,4 @@ class WecomCSAdapter(adapter.MessagePlatformAdapter): event_type: type, callback: typing.Callable[[platform_events.Event, MessagePlatformAdapter], None], ): - return super().unregister_listener(event_type, callback) \ No newline at end of file + return super().unregister_listener(event_type, callback) diff --git a/pkg/platform/types/base.py b/pkg/platform/types/base.py index 9e31bafe..da58d4ed 100644 --- a/pkg/platform/types/base.py +++ b/pkg/platform/types/base.py @@ -31,10 +31,7 @@ class PlatformBaseModel(BaseModel, metaclass=PlatformMetaclass): def __repr__(self) -> str: return ( - self.__class__.__name__ - + '(' - + ', '.join((f'{k}={repr(v)}' for k, v in self.__dict__.items() if v)) - + ')' + self.__class__.__name__ + '(' + ', '.join((f'{k}={repr(v)}' for k, v in self.__dict__.items() if v)) + ')' ) class Config: diff --git a/pkg/platform/types/events.py b/pkg/platform/types/events.py index 1a724beb..5ffccb9b 100644 --- a/pkg/platform/types/events.py +++ b/pkg/platform/types/events.py @@ -25,13 +25,7 @@ class Event(pydantic.BaseModel): return ( self.__class__.__name__ + '(' - + ', '.join( - ( - f'{k}={repr(v)}' - for k, v in self.__dict__.items() - if k != 'type' and v - ) - ) + + ', '.join((f'{k}={repr(v)}' for k, v in self.__dict__.items() if k != 'type' and v)) + ')' ) diff --git a/pkg/platform/types/message.py b/pkg/platform/types/message.py index 3693b76f..8412e8a4 100644 --- a/pkg/platform/types/message.py +++ b/pkg/platform/types/message.py @@ -51,13 +51,7 @@ class MessageComponent(PlatformIndexedModel, metaclass=MessageComponentMetaclass return ( self.__class__.__name__ + '(' - + ', '.join( - ( - f'{k}={repr(v)}' - for k, v in self.__dict__.items() - if k != 'type' and v - ) - ) + + ', '.join((f'{k}={repr(v)}' for k, v in self.__dict__.items() if k != 'type' and v)) + ')' ) @@ -65,14 +59,10 @@ class MessageComponent(PlatformIndexedModel, metaclass=MessageComponentMetaclass # 解析参数列表,将位置参数转化为具名参数 parameter_names = self.__parameter_names__ if len(args) > len(parameter_names): - raise TypeError( - f'`{self.type}`需要{len(parameter_names)}个参数,但传入了{len(args)}个。' - ) + raise TypeError(f'`{self.type}`需要{len(parameter_names)}个参数,但传入了{len(args)}个。') for name, value in zip(parameter_names, args): if name in kwargs: - raise TypeError( - f'在 `{self.type}` 中,具名参数 `{name}` 与位置参数重复。' - ) + raise TypeError(f'在 `{self.type}` 中,具名参数 `{name}` 与位置参数重复。') kwargs[name] = value super().__init__(**kwargs) @@ -140,9 +130,7 @@ class MessageChain(PlatformBaseModel): elif isinstance(msg, str): result.append(Plain(msg)) else: - raise TypeError( - f'消息链中元素需为 dict 或 str 或 MessageComponent,当前类型:{type(msg)}' - ) + raise TypeError(f'消息链中元素需为 dict 或 str 或 MessageComponent,当前类型:{type(msg)}') return result @pydantic.validator('__root__', always=True, pre=True) @@ -175,9 +163,7 @@ class MessageChain(PlatformBaseModel): def __iter__(self): yield from self.__root__ - def get_first( - self, t: typing.Type[TMessageComponent] - ) -> typing.Optional[TMessageComponent]: + def get_first(self, t: typing.Type[TMessageComponent]) -> typing.Optional[TMessageComponent]: """获取消息链中第一个符合类型的消息组件。""" for component in self: if isinstance(component, t): @@ -191,9 +177,7 @@ class MessageChain(PlatformBaseModel): def __getitem__(self, index: slice) -> typing.List[MessageComponent]: ... @typing.overload - def __getitem__( - self, index: typing.Type[TMessageComponent] - ) -> typing.List[TMessageComponent]: ... + def __getitem__(self, index: typing.Type[TMessageComponent]) -> typing.List[TMessageComponent]: ... @typing.overload def __getitem__( @@ -208,17 +192,13 @@ class MessageChain(PlatformBaseModel): typing.Type[TMessageComponent], typing.Tuple[typing.Type[TMessageComponent], int], ], - ) -> typing.Union[ - MessageComponent, typing.List[MessageComponent], typing.List[TMessageComponent] - ]: + ) -> typing.Union[MessageComponent, typing.List[MessageComponent], typing.List[TMessageComponent]]: return self.get(index) def __setitem__( self, key: typing.Union[int, slice], - value: typing.Union[ - MessageComponent, str, typing.Iterable[typing.Union[MessageComponent, str]] - ], + value: typing.Union[MessageComponent, str, typing.Iterable[typing.Union[MessageComponent, str]]], ): if isinstance(value, str): value = Plain(value) @@ -234,9 +214,7 @@ class MessageChain(PlatformBaseModel): def has( self, - sub: typing.Union[ - MessageComponent, typing.Type[MessageComponent], 'MessageChain', str - ], + sub: typing.Union[MessageComponent, typing.Type[MessageComponent], 'MessageChain', str], ) -> bool: """判断消息链中: 1. 是否有某个消息组件。 @@ -271,9 +249,7 @@ class MessageChain(PlatformBaseModel): def __len__(self) -> int: return len(self.__root__) - def __add__( - self, other: typing.Union['MessageChain', MessageComponent, str] - ) -> 'MessageChain': + def __add__(self, other: typing.Union['MessageChain', MessageComponent, str]) -> 'MessageChain': if isinstance(other, MessageChain): return self.__class__(self.__root__ + other.__root__) if isinstance(other, str): @@ -286,9 +262,7 @@ class MessageChain(PlatformBaseModel): if isinstance(other, MessageComponent): return self.__class__([other] + self.__root__) if isinstance(other, str): - return self.__class__( - [typing.cast(MessageComponent, Plain(other))] + self.__root__ - ) + return self.__class__([typing.cast(MessageComponent, Plain(other))] + self.__root__) return NotImplemented def __mul__(self, other: int): @@ -346,9 +320,7 @@ class MessageChain(PlatformBaseModel): return self.__root__.index(x, i, j) raise TypeError(f'类型不匹配,当前类型:{type(x)}') - def count( - self, x: typing.Union[MessageComponent, typing.Type[MessageComponent]] - ) -> int: + def count(self, x: typing.Union[MessageComponent, typing.Type[MessageComponent]]) -> int: """返回消息链中 x 出现的次数。 Args: @@ -443,9 +415,7 @@ class MessageChain(PlatformBaseModel): @classmethod def join(cls, *args: typing.Iterable[typing.Union[str, MessageComponent]]): - return cls( - Plain(c) if isinstance(c, str) else c for c in itertools.chain(*args) - ) + return cls(Plain(c) if isinstance(c, str) else c for c in itertools.chain(*args)) @property def source(self) -> typing.Optional['Source']: @@ -557,11 +527,7 @@ class Image(MessageComponent): """图片的 Base64 编码。""" def __eq__(self, other): - return ( - isinstance(other, Image) - and self.type == other.type - and self.uuid == other.uuid - ) + return isinstance(other, Image) and self.type == other.type and self.uuid == other.uuid def __str__(self): return '[图片]' @@ -818,9 +784,7 @@ class ForwardMessageNode(pydantic.BaseModel): Returns: ForwardMessageNode: 生成的一条消息。 """ - return ForwardMessageNode( - sender_id=sender.id, sender_name=sender.get_name(), message_chain=message - ) + return ForwardMessageNode(sender_id=sender.id, sender_name=sender.get_name(), message_chain=message) class ForwardMessageDiaplay(pydantic.BaseModel): diff --git a/pkg/plugin/context.py b/pkg/plugin/context.py index cc95adaa..dfd691f3 100644 --- a/pkg/plugin/context.py +++ b/pkg/plugin/context.py @@ -165,9 +165,7 @@ class APIHost: langbot_version = '' try: - langbot_version = ( - self.ap.ver_mgr.get_current_version() - ) # 从updater模块获取版本号 + langbot_version = self.ap.ver_mgr.get_current_version() # 从updater模块获取版本号 except Exception: return False @@ -237,9 +235,7 @@ class EventContext: message_source=self.event.query.message_event, message=message_chain ) - async def send_message( - self, target_type: str, target_id: str, message: platform_message.MessageChain - ): + async def send_message(self, target_type: str, target_id: str, message: platform_message.MessageChain): """主动发送消息 Args: @@ -247,9 +243,7 @@ class EventContext: target_id (str): 目标ID message (platform.types.MessageChain): 源平台的消息链,若用户使用的不是源平台适配器,程序也能自动转换为目标平台消息链 """ - await self.event.query.adapter.send_message( - target_type=target_type, target_id=target_id, message=message - ) + await self.event.query.adapter.send_message(target_type=target_type, target_id=target_id, message=message) def prevent_postorder(self): """阻止后续插件执行""" @@ -378,8 +372,7 @@ class RuntimeContainer(pydantic.BaseModel): 'priority': self.priority, 'config_schema': self.config_schema, 'event_handlers': { - event_name.__name__: handler.__name__ - for event_name, handler in self.event_handlers.items() + event_name.__name__: handler.__name__ for event_name, handler in self.event_handlers.items() }, 'tools': [ { diff --git a/pkg/plugin/installers/github.py b/pkg/plugin/installers/github.py index a867b04d..df247219 100644 --- a/pkg/plugin/installers/github.py +++ b/pkg/plugin/installers/github.py @@ -58,9 +58,7 @@ class GitHubRepoInstaller(installer.PluginInstaller): ssl=ssl_context, # 使用自定义SSL上下文来验证证书 ) as resp: if resp.status != 200: - raise errors.PluginInstallerError( - f'下载源码失败: {await resp.text()}' - ) + raise errors.PluginInstallerError(f'下载源码失败: {await resp.text()}') zip_resp = await resp.read() if await aiofiles_os.path.exists('temp/' + target_path): @@ -101,9 +99,7 @@ class GitHubRepoInstaller(installer.PluginInstaller): ): """安装插件""" task_context.trace('下载插件源码...', 'install-plugin') - repo_label = await self.download_plugin_source_code( - plugin_source, 'plugins/', task_context - ) + repo_label = await self.download_plugin_source_code(plugin_source, 'plugins/', task_context) task_context.trace('安装插件依赖...', 'install-plugin') await self.install_requirements('plugins/' + repo_label) task_context.trace('完成.', 'install-plugin') diff --git a/pkg/plugin/loaders/classic.py b/pkg/plugin/loaders/classic.py index 857a7b9c..8aa7382b 100644 --- a/pkg/plugin/loaders/classic.py +++ b/pkg/plugin/loaders/classic.py @@ -35,16 +35,12 @@ class PluginLoader(loader.PluginLoader): def register( self, name: str, description: str, version: str, author: str - ) -> typing.Callable[ - [typing.Type[context.BasePlugin]], typing.Type[context.BasePlugin] - ]: + ) -> typing.Callable[[typing.Type[context.BasePlugin]], typing.Type[context.BasePlugin]]: self.ap.logger.debug(f'注册插件 {name} {version} by {author}') container = context.RuntimeContainer( plugin_name=name, plugin_label=discover_engine.I18nString(en_US=name, zh_CN=name), - plugin_description=discover_engine.I18nString( - en_US=description, zh_CN=description - ), + plugin_description=discover_engine.I18nString(en_US=description, zh_CN=description), plugin_version=version, plugin_author=author, plugin_repository='', @@ -64,16 +60,12 @@ class PluginLoader(loader.PluginLoader): # 过时 # 最早将于 v3.4 版本移除 - def on( - self, event: typing.Type[events.BaseEventModel] - ) -> typing.Callable[[typing.Callable], typing.Callable]: + def on(self, event: typing.Type[events.BaseEventModel]) -> typing.Callable[[typing.Callable], typing.Callable]: """注册过时的事件处理器""" self.ap.logger.debug(f'注册事件处理器 {event.__name__}') def wrapper(func: typing.Callable) -> typing.Callable: - async def handler( - plugin: context.BasePlugin, ctx: context.EventContext - ) -> None: + async def handler(plugin: context.BasePlugin, ctx: context.EventContext) -> None: args = { 'host': ctx.host, 'event': ctx, @@ -104,15 +96,9 @@ class PluginLoader(loader.PluginLoader): def wrapper(func: typing.Callable) -> typing.Callable: function_schema = funcschema.get_func_schema(func) - function_name = ( - self._current_container.plugin_name - + '-' - + (func.__name__ if name is None else name) - ) + function_name = self._current_container.plugin_name + '-' + (func.__name__ if name is None else name) - async def handler( - plugin: context.BasePlugin, query: core_entities.Query, *args, **kwargs - ): + async def handler(plugin: context.BasePlugin, query: core_entities.Query, *args, **kwargs): return func(*args, **kwargs) llm_function = tools_entities.LLMFunction( @@ -129,9 +115,7 @@ class PluginLoader(loader.PluginLoader): return wrapper - def handler( - self, event: typing.Type[events.BaseEventModel] - ) -> typing.Callable[[typing.Callable], typing.Callable]: + def handler(self, event: typing.Type[events.BaseEventModel]) -> typing.Callable[[typing.Callable], typing.Callable]: """注册事件处理器""" self.ap.logger.debug(f'注册事件处理器 {event.__name__}') @@ -161,11 +145,7 @@ class PluginLoader(loader.PluginLoader): return func function_schema = funcschema.get_func_schema(func) - function_name = ( - self._current_container.plugin_name - + '-' - + (func.__name__ if name is None else name) - ) + function_name = self._current_container.plugin_name + '-' + (func.__name__ if name is None else name) llm_function = tools_entities.LLMFunction( name=function_name, @@ -193,9 +173,7 @@ class PluginLoader(loader.PluginLoader): else: try: self._current_pkg_path = 'plugins/' + path_prefix - self._current_module_path = ( - 'plugins/' + path_prefix + item.name + '.py' - ) + self._current_module_path = 'plugins/' + path_prefix + item.name + '.py' self._current_container = None @@ -205,9 +183,7 @@ class PluginLoader(loader.PluginLoader): self.plugins.append(self._current_container) self.ap.logger.debug(f'插件 {self._current_container} 已加载') except Exception: - self.ap.logger.error( - f'加载插件模块 {prefix + item.name} 时发生错误' - ) + self.ap.logger.error(f'加载插件模块 {prefix + item.name} 时发生错误') traceback.print_exc() async def load_plugins(self): diff --git a/pkg/plugin/loaders/manifest.py b/pkg/plugin/loaders/manifest.py index b634c5b5..cce6c9e3 100644 --- a/pkg/plugin/loaders/manifest.py +++ b/pkg/plugin/loaders/manifest.py @@ -19,9 +19,7 @@ class PluginManifestLoader(loader.PluginLoader): def __init__(self, ap: app.Application): super().__init__(ap) - def handler( - self, event: typing.Type[events.BaseEventModel] - ) -> typing.Callable[[typing.Callable], typing.Callable]: + def handler(self, event: typing.Type[events.BaseEventModel]) -> typing.Callable[[typing.Callable], typing.Callable]: """注册事件处理器""" self.ap.logger.debug(f'注册事件处理器 {event.__name__}') @@ -41,11 +39,7 @@ class PluginManifestLoader(loader.PluginLoader): def wrapper(func: typing.Callable) -> typing.Callable: function_schema = funcschema.get_func_schema(func) - function_name = ( - self._current_container.plugin_name - + '-' - + (func.__name__ if name is None else name) - ) + function_name = self._current_container.plugin_name + '-' + (func.__name__ if name is None else name) llm_function = tools_entities.LLMFunction( name=function_name, @@ -70,11 +64,7 @@ class PluginManifestLoader(loader.PluginLoader): for plugin_manifest in plugin_manifests: try: - config_schema = ( - plugin_manifest.spec['config'] - if 'config' in plugin_manifest.spec - else [] - ) + config_schema = plugin_manifest.spec['config'] if 'config' in plugin_manifest.spec else [] current_plugin_container = context.RuntimeContainer( plugin_name=plugin_manifest.metadata.name, @@ -83,9 +73,7 @@ class PluginManifestLoader(loader.PluginLoader): plugin_version=plugin_manifest.metadata.version, plugin_author=plugin_manifest.metadata.author, plugin_repository=plugin_manifest.metadata.repository, - main_file=os.path.join( - plugin_manifest.rel_dir, plugin_manifest.execution.python.path - ), + main_file=os.path.join(plugin_manifest.rel_dir, plugin_manifest.execution.python.path), pkg_path=plugin_manifest.rel_dir, config_schema=config_schema, event_handlers={}, @@ -104,7 +92,5 @@ class PluginManifestLoader(loader.PluginLoader): self.plugins.append(current_plugin_container) except Exception: - self.ap.logger.error( - f'加载插件 {plugin_manifest.metadata.name} 时发生错误' - ) + self.ap.logger.error(f'加载插件 {plugin_manifest.metadata.name} 时发生错误') traceback.print_exc() diff --git a/pkg/plugin/manager.py b/pkg/plugin/manager.py index 9dc4cb26..f813d2e2 100644 --- a/pkg/plugin/manager.py +++ b/pkg/plugin/manager.py @@ -83,20 +83,12 @@ class PluginManager: self.ap.logger.debug(f'优先级排序后的插件列表 {self.plugin_containers}') - async def load_plugin_settings( - self, plugin_containers: list[context.RuntimeContainer] - ): + async def load_plugin_settings(self, plugin_containers: list[context.RuntimeContainer]): for plugin_container in plugin_containers: result = await self.ap.persistence_mgr.execute_async( sqlalchemy.select(persistence_plugin.PluginSetting) - .where( - persistence_plugin.PluginSetting.plugin_author - == plugin_container.plugin_author - ) - .where( - persistence_plugin.PluginSetting.plugin_name - == plugin_container.plugin_name - ) + .where(persistence_plugin.PluginSetting.plugin_author == plugin_container.plugin_author) + .where(persistence_plugin.PluginSetting.plugin_name == plugin_container.plugin_name) ) setting = result.first() @@ -111,9 +103,7 @@ class PluginManager: } await self.ap.persistence_mgr.execute_async( - sqlalchemy.insert(persistence_plugin.PluginSetting).values( - **new_setting_data - ) + sqlalchemy.insert(persistence_plugin.PluginSetting).values(**new_setting_data) ) continue else: @@ -121,20 +111,12 @@ class PluginManager: plugin_container.priority = setting.priority plugin_container.plugin_config = setting.config - async def dump_plugin_container_setting( - self, plugin_container: context.RuntimeContainer - ): + async def dump_plugin_container_setting(self, plugin_container: context.RuntimeContainer): """保存单个插件容器的设置到数据库""" await self.ap.persistence_mgr.execute_async( sqlalchemy.update(persistence_plugin.PluginSetting) - .where( - persistence_plugin.PluginSetting.plugin_author - == plugin_container.plugin_author - ) - .where( - persistence_plugin.PluginSetting.plugin_name - == plugin_container.plugin_name - ) + .where(persistence_plugin.PluginSetting.plugin_author == plugin_container.plugin_author) + .where(persistence_plugin.PluginSetting.plugin_name == plugin_container.plugin_name) .values( enabled=plugin_container.enabled, priority=plugin_container.priority, @@ -247,20 +229,14 @@ class PluginManager: emitted_plugins: list[context.RuntimeContainer] = [] - for plugin in self.plugins( - enabled=True, status=context.RuntimeContainerStatus.INITIALIZED - ): + for plugin in self.plugins(enabled=True, status=context.RuntimeContainerStatus.INITIALIZED): if event.__class__ in plugin.event_handlers: - self.ap.logger.debug( - f'插件 {plugin.plugin_name} 处理事件 {event.__class__.__name__}' - ) + self.ap.logger.debug(f'插件 {plugin.plugin_name} 处理事件 {event.__class__.__name__}') is_prevented_default_before_call = ctx.is_prevented_default() try: - await plugin.event_handlers[event.__class__]( - plugin.plugin_inst, ctx - ) + await plugin.event_handlers[event.__class__](plugin.plugin_inst, ctx) except Exception as e: self.ap.logger.error( f'插件 {plugin.plugin_name} 处理事件 {event.__class__.__name__} 时发生错误: {e}' @@ -270,23 +246,17 @@ class PluginManager: emitted_plugins.append(plugin) if not is_prevented_default_before_call and ctx.is_prevented_default(): - self.ap.logger.debug( - f'插件 {plugin.plugin_name} 阻止了默认行为执行' - ) + self.ap.logger.debug(f'插件 {plugin.plugin_name} 阻止了默认行为执行') if ctx.is_prevented_postorder(): - self.ap.logger.debug( - f'插件 {plugin.plugin_name} 阻止了后序插件的执行' - ) + self.ap.logger.debug(f'插件 {plugin.plugin_name} 阻止了后序插件的执行') break for key in ctx.__return_value__.keys(): if hasattr(ctx.event, key): setattr(ctx.event, key, ctx.__return_value__[key][0]) - self.ap.logger.debug( - f'事件 {event.__class__.__name__}({ctx.eid}) 处理完成,返回值 {ctx.__return_value__}' - ) + self.ap.logger.debug(f'事件 {event.__class__.__name__}({ctx.eid}) 处理完成,返回值 {ctx.__return_value__}') # TODO statistics @@ -330,9 +300,7 @@ class PluginManager: for plugin in self.plugin_containers: await self.dump_plugin_container_setting(plugin) - async def set_plugin_config( - self, plugin_container: context.RuntimeContainer, new_config: dict - ): + async def set_plugin_config(self, plugin_container: context.RuntimeContainer, new_config: dict): plugin_container.plugin_config = new_config plugin_container.plugin_inst.config = new_config diff --git a/pkg/provider/entities.py b/pkg/provider/entities.py index ff95b128..94b812d9 100644 --- a/pkg/provider/entities.py +++ b/pkg/provider/entities.py @@ -80,17 +80,13 @@ class Message(pydantic.BaseModel): def readable_str(self) -> str: if self.content is not None: - return ( - str(self.role) + ': ' + str(self.get_content_platform_message_chain()) - ) + return str(self.role) + ': ' + str(self.get_content_platform_message_chain()) elif self.tool_calls is not None: return f'调用工具: {self.tool_calls[0].id}' else: return '未知消息' - def get_content_platform_message_chain( - self, prefix_text: str = '' - ) -> platform_message.MessageChain | None: + def get_content_platform_message_chain(self, prefix_text: str = '') -> platform_message.MessageChain | None: """将内容转换为平台消息 MessageChain 对象 Args: @@ -100,9 +96,7 @@ class Message(pydantic.BaseModel): if self.content is None: return None elif isinstance(self.content, str): - return platform_message.MessageChain( - [platform_message.Plain(prefix_text + self.content)] - ) + return platform_message.MessageChain([platform_message.Plain(prefix_text + self.content)]) elif isinstance(self.content, list): mc = [] for ce in self.content: diff --git a/pkg/provider/modelmgr/modelmgr.py b/pkg/provider/modelmgr/modelmgr.py index 25a79fec..e37e21cb 100644 --- a/pkg/provider/modelmgr/modelmgr.py +++ b/pkg/provider/modelmgr/modelmgr.py @@ -43,16 +43,12 @@ class ModelManager: self.requester_dict = {} async def initialize(self): - self.requester_components = self.ap.discover.get_components_by_kind( - 'LLMAPIRequester' - ) + self.requester_components = self.ap.discover.get_components_by_kind('LLMAPIRequester') # forge requester class dict requester_dict: dict[str, type[requester.LLMAPIRequester]] = {} for component in self.requester_components: - requester_dict[component.metadata.name] = ( - component.get_python_component_class() - ) + requester_dict[component.metadata.name] = component.get_python_component_class() self.requester_dict = requester_dict @@ -65,9 +61,7 @@ class ModelManager: self.llm_models = [] # llm models - result = await self.ap.persistence_mgr.execute_async( - sqlalchemy.select(persistence_model.LLMModel) - ) + result = await self.ap.persistence_mgr.execute_async(sqlalchemy.select(persistence_model.LLMModel)) llm_models = result.all() @@ -77,9 +71,7 @@ class ModelManager: async def load_llm_model( self, - model_info: persistence_model.LLMModel - | sqlalchemy.Row[persistence_model.LLMModel] - | dict, + model_info: persistence_model.LLMModel | sqlalchemy.Row[persistence_model.LLMModel] | dict, ): """加载模型""" @@ -88,9 +80,7 @@ class ModelManager: elif isinstance(model_info, dict): model_info = persistence_model.LLMModel(**model_info) - requester_inst = self.requester_dict[model_info.requester]( - ap=self.ap, config=model_info.requester_config - ) + requester_inst = self.requester_dict[model_info.requester](ap=self.ap, config=model_info.requester_config) await requester_inst.initialize() @@ -136,9 +126,7 @@ class ModelManager: return component.to_plain_dict() return None - def get_available_requester_manifest_by_name( - self, name: str - ) -> engine.Component | None: + def get_available_requester_manifest_by_name(self, name: str) -> engine.Component | None: """通过名称获取请求器清单""" for component in self.requester_components: if component.metadata.name == name: diff --git a/pkg/provider/modelmgr/requesters/anthropicmsgs.py b/pkg/provider/modelmgr/requesters/anthropicmsgs.py index 2472da04..38573854 100644 --- a/pkg/provider/modelmgr/requesters/anthropicmsgs.py +++ b/pkg/provider/modelmgr/requesters/anthropicmsgs.py @@ -73,9 +73,7 @@ class AnthropicMessages(requester.LLMAPIRequester): if system_role_message: messages.pop(i) - if isinstance(system_role_message, llm_entities.Message) and isinstance( - system_role_message.content, str - ): + if isinstance(system_role_message, llm_entities.Message) and isinstance(system_role_message.content, str): args['system'] = system_role_message.content req_messages = [] @@ -106,9 +104,7 @@ class AnthropicMessages(requester.LLMAPIRequester): elif isinstance(m.content, list): for i, ce in enumerate(m.content): if ce.type == 'image_base64': - image_b64, image_format = await image.extract_b64_and_format( - ce.image_base64 - ) + image_b64, image_format = await image.extract_b64_and_format(ce.image_base64) alter_image_ele = { 'type': 'image', @@ -156,9 +152,7 @@ class AnthropicMessages(requester.LLMAPIRequester): for block in resp.content: if block.type == 'thinking': - args['content'] = ( - '' + block.thinking + '\n' + args['content'] - ) + args['content'] = '' + block.thinking + '\n' + args['content'] elif block.type == 'text': args['content'] += block.text elif block.type == 'tool_use': @@ -166,9 +160,7 @@ class AnthropicMessages(requester.LLMAPIRequester): tool_call = llm_entities.ToolCall( id=block.id, type='function', - function=llm_entities.FunctionCall( - name=block.name, arguments=json.dumps(block.input) - ), + function=llm_entities.FunctionCall(name=block.name, arguments=json.dumps(block.input)), ) if 'tool_calls' not in args: args['tool_calls'] = [] diff --git a/pkg/provider/modelmgr/requesters/chatcmpl.py b/pkg/provider/modelmgr/requesters/chatcmpl.py index de350739..513086e5 100644 --- a/pkg/provider/modelmgr/requesters/chatcmpl.py +++ b/pkg/provider/modelmgr/requesters/chatcmpl.py @@ -28,9 +28,7 @@ class OpenAIChatCompletions(requester.LLMAPIRequester): api_key='', base_url=self.requester_cfg['base_url'].replace(' ', ''), timeout=self.requester_cfg['timeout'], - http_client=httpx.AsyncClient( - trust_env=True, timeout=self.requester_cfg['timeout'] - ), + http_client=httpx.AsyncClient(trust_env=True, timeout=self.requester_cfg['timeout']), ) async def _req( @@ -50,20 +48,11 @@ class OpenAIChatCompletions(requester.LLMAPIRequester): if 'role' not in chatcmpl_message or chatcmpl_message['role'] is None: chatcmpl_message['role'] = 'assistant' - reasoning_content = ( - chatcmpl_message['reasoning_content'] - if 'reasoning_content' in chatcmpl_message - else None - ) + reasoning_content = chatcmpl_message['reasoning_content'] if 'reasoning_content' in chatcmpl_message else None # deepseek的reasoner模型 if reasoning_content is not None: - chatcmpl_message['content'] = ( - '\n' - + reasoning_content - + '\n\n' - + chatcmpl_message['content'] - ) + chatcmpl_message['content'] = '\n' + reasoning_content + '\n\n' + chatcmpl_message['content'] message = llm_entities.Message(**chatcmpl_message) @@ -124,10 +113,7 @@ class OpenAIChatCompletions(requester.LLMAPIRequester): content = msg_dict.get('content') if isinstance(content, list): # 检查 content 列表中是否每个部分都是文本 - if all( - isinstance(part, dict) and part.get('type') == 'text' - for part in content - ): + if all(isinstance(part, dict) and part.get('type') == 'text' for part in content): # 将所有文本部分合并为一个字符串 msg_dict['content'] = '\n'.join(part['text'] for part in content) req_messages.append(msg_dict) diff --git a/pkg/provider/modelmgr/requesters/modelscopechatcmpl.py b/pkg/provider/modelmgr/requesters/modelscopechatcmpl.py index 8f51241e..c8be8a01 100644 --- a/pkg/provider/modelmgr/requesters/modelscopechatcmpl.py +++ b/pkg/provider/modelmgr/requesters/modelscopechatcmpl.py @@ -1,23 +1,17 @@ from __future__ import annotations - + import asyncio import typing -import json -import base64 -from typing import AsyncGenerator import openai import openai.types.chat.chat_completion as chat_completion import openai.types.chat.chat_completion_message_tool_call as chat_completion_message_tool_call import httpx -import aiohttp -import async_lru from .. import entities, errors, requester from ....core import entities as core_entities, app from ... import entities as llm_entities from ...tools import entities as tools_entities -from ....utils import image class ModelScopeChatCompletions(requester.LLMAPIRequester): @@ -33,26 +27,22 @@ class ModelScopeChatCompletions(requester.LLMAPIRequester): self.requester_cfg = self.ap.provider_cfg.data['requester']['modelscope-chat-completions'] async def initialize(self): - self.client = openai.AsyncClient( - api_key="", + api_key='', base_url=self.requester_cfg['base-url'], timeout=self.requester_cfg['timeout'], - http_client=httpx.AsyncClient( - trust_env=True, - timeout=self.requester_cfg['timeout'] - ) + http_client=httpx.AsyncClient(trust_env=True, timeout=self.requester_cfg['timeout']), ) async def _req( self, args: dict, ) -> chat_completion.ChatCompletion: - args["stream"] = True + args['stream'] = True chunk = None - pending_content = "" + pending_content = '' tool_calls = [] @@ -74,7 +64,7 @@ class ModelScopeChatCompletions(requester.LLMAPIRequester): break else: tool_calls.append(tool_call) - + if chunk.choices[0].finish_reason is not None: break @@ -82,36 +72,41 @@ class ModelScopeChatCompletions(requester.LLMAPIRequester): for tc in tool_calls: function = chat_completion_message_tool_call.Function( - name=tc.function.name, - arguments=tc.function.arguments + name=tc.function.name, arguments=tc.function.arguments ) - real_tool_calls.append(chat_completion_message_tool_call.ChatCompletionMessageToolCall( - id=tc.id, - function=function, - type="function" - )) - - return chat_completion.ChatCompletion( - id=chunk.id, - object="chat.completion", - created=chunk.created, - choices=[ - chat_completion.Choice( - index=0, - message=chat_completion.ChatCompletionMessage( - role="assistant", - content=pending_content, - tool_calls=real_tool_calls if len(real_tool_calls) > 0 else None - ), - finish_reason=chunk.choices[0].finish_reason if hasattr(chunk.choices[0], 'finish_reason') and chunk.choices[0].finish_reason is not None else 'stop', - logprobs=chunk.choices[0].logprobs, + real_tool_calls.append( + chat_completion_message_tool_call.ChatCompletionMessageToolCall( + id=tc.id, function=function, type='function' ) - ], - model=chunk.model, - service_tier=chunk.service_tier if hasattr(chunk, 'service_tier') else None, - system_fingerprint=chunk.system_fingerprint if hasattr(chunk, 'system_fingerprint') else None, - usage=chunk.usage if hasattr(chunk, 'usage') else None - ) if chunk else None + ) + + return ( + chat_completion.ChatCompletion( + id=chunk.id, + object='chat.completion', + created=chunk.created, + choices=[ + chat_completion.Choice( + index=0, + message=chat_completion.ChatCompletionMessage( + role='assistant', + content=pending_content, + tool_calls=real_tool_calls if len(real_tool_calls) > 0 else None, + ), + finish_reason=chunk.choices[0].finish_reason + if hasattr(chunk.choices[0], 'finish_reason') and chunk.choices[0].finish_reason is not None + else 'stop', + logprobs=chunk.choices[0].logprobs, + ) + ], + model=chunk.model, + service_tier=chunk.service_tier if hasattr(chunk, 'service_tier') else None, + system_fingerprint=chunk.system_fingerprint if hasattr(chunk, 'system_fingerprint') else None, + usage=chunk.usage if hasattr(chunk, 'usage') else None, + ) + if chunk + else None + ) return await self.client.chat.completions.create(**args) async def _make_msg( @@ -138,29 +133,27 @@ class ModelScopeChatCompletions(requester.LLMAPIRequester): self.client.api_key = use_model.token_mgr.get_token() args = self.requester_cfg['args'].copy() - args["model"] = use_model.name if use_model.model_name is None else use_model.model_name + args['model'] = use_model.name if use_model.model_name is None else use_model.model_name if use_funcs: tools = await self.ap.tool_mgr.generate_tools_for_openai(use_funcs) if tools: - args["tools"] = tools + args['tools'] = tools # 设置此次请求中的messages messages = req_messages.copy() # 检查vision for msg in messages: - if 'content' in msg and isinstance(msg["content"], list): - for me in msg["content"]: - if me["type"] == "image_base64": - me["image_url"] = { - "url": me["image_base64"] - } - me["type"] = "image_url" - del me["image_base64"] + if 'content' in msg and isinstance(msg['content'], list): + for me in msg['content']: + if me['type'] == 'image_base64': + me['image_url'] = {'url': me['image_base64']} + me['type'] = 'image_url' + del me['image_base64'] - args["messages"] = messages + args['messages'] = messages # 发送请求 resp = await self._req(args) @@ -180,12 +173,12 @@ class ModelScopeChatCompletions(requester.LLMAPIRequester): req_messages = [] # req_messages 仅用于类内,外部同步由 query.messages 进行 for m in messages: msg_dict = m.dict(exclude_none=True) - content = msg_dict.get("content") + content = msg_dict.get('content') if isinstance(content, list): # 检查 content 列表中是否每个部分都是文本 - if all(isinstance(part, dict) and part.get("type") == "text" for part in content): + if all(isinstance(part, dict) and part.get('type') == 'text' for part in content): # 将所有文本部分合并为一个字符串 - msg_dict["content"] = "\n".join(part["text"] for part in content) + msg_dict['content'] = '\n'.join(part['text'] for part in content) req_messages.append(msg_dict) try: @@ -204,4 +197,4 @@ class ModelScopeChatCompletions(requester.LLMAPIRequester): except openai.RateLimitError as e: raise errors.RequesterError(f'请求过于频繁或余额不足: {e.message}') except openai.APIError as e: - raise errors.RequesterError(f'请求错误: {e.message}') \ No newline at end of file + raise errors.RequesterError(f'请求错误: {e.message}') diff --git a/pkg/provider/modelmgr/requesters/ollamachat.py b/pkg/provider/modelmgr/requesters/ollamachat.py index 995dd855..00793f82 100644 --- a/pkg/provider/modelmgr/requesters/ollamachat.py +++ b/pkg/provider/modelmgr/requesters/ollamachat.py @@ -61,13 +61,9 @@ class OllamaChatCompletions(requester.LLMAPIRequester): msg['content'] = '\n'.join(text_content) msg['images'] = [url.split(',')[1] for url in image_urls] - if ( - 'tool_calls' in msg - ): # LangBot 内部以 str 存储 tool_calls 的参数,这里需要转换为 dict + if 'tool_calls' in msg: # LangBot 内部以 str 存储 tool_calls 的参数,这里需要转换为 dict for tool_call in msg['tool_calls']: - tool_call['function']['arguments'] = json.loads( - tool_call['function']['arguments'] - ) + tool_call['function']['arguments'] = json.loads(tool_call['function']['arguments']) args['messages'] = messages args['tools'] = [] @@ -80,9 +76,7 @@ class OllamaChatCompletions(requester.LLMAPIRequester): message: llm_entities.Message = await self._make_msg(resp) return message - async def _make_msg( - self, chat_completions: ollama.ChatResponse - ) -> llm_entities.Message: + async def _make_msg(self, chat_completions: ollama.ChatResponse) -> llm_entities.Message: message: ollama.Message = chat_completions.message if message is None: raise ValueError("chat_completions must contain a 'message' field") @@ -122,10 +116,7 @@ class OllamaChatCompletions(requester.LLMAPIRequester): msg_dict: dict = m.dict(exclude_none=True) content: Any = msg_dict.get('content') if isinstance(content, list): - if all( - isinstance(part, dict) and part.get('type') == 'text' - for part in content - ): + if all(isinstance(part, dict) and part.get('type') == 'text' for part in content): msg_dict['content'] = '\n'.join(part['text'] for part in content) req_messages.append(msg_dict) try: diff --git a/pkg/provider/modelmgr/requesters/ppiochatcmpl.py b/pkg/provider/modelmgr/requesters/ppiochatcmpl.py index d0149a80..67c1701a 100644 --- a/pkg/provider/modelmgr/requesters/ppiochatcmpl.py +++ b/pkg/provider/modelmgr/requesters/ppiochatcmpl.py @@ -1,12 +1,11 @@ - from __future__ import annotations import openai -from . import chatcmpl, modelscopechatcmpl -from .. import requester +from . import chatcmpl from ....core import app + class PPIOChatCompletions(chatcmpl.OpenAIChatCompletions): """欧派云 ChatCompletion API 请求器""" @@ -17,4 +16,4 @@ class PPIOChatCompletions(chatcmpl.OpenAIChatCompletions): def __init__(self, ap: app.Application): self.ap = ap - self.requester_cfg = self.ap.provider_cfg.data['requester']['ppio-chat-completions'] \ No newline at end of file + self.requester_cfg = self.ap.provider_cfg.data['requester']['ppio-chat-completions'] diff --git a/pkg/provider/runner.py b/pkg/provider/runner.py index ccfcee73..a74a2dc5 100644 --- a/pkg/provider/runner.py +++ b/pkg/provider/runner.py @@ -35,8 +35,6 @@ class RequestRunner(abc.ABC): self.pipeline_config = pipeline_config @abc.abstractmethod - async def run( - self, query: core_entities.Query - ) -> typing.AsyncGenerator[llm_entities.Message, None]: + async def run(self, query: core_entities.Query) -> typing.AsyncGenerator[llm_entities.Message, None]: """运行请求""" pass diff --git a/pkg/provider/runners/dashscopeapi.py b/pkg/provider/runners/dashscopeapi.py index 92a1eb18..02cb0b51 100644 --- a/pkg/provider/runners/dashscopeapi.py +++ b/pkg/provider/runners/dashscopeapi.py @@ -26,7 +26,9 @@ class DashScopeAPIRunner(runner.RequestRunner): app_type: str # 应用类型 app_id: str # 应用ID api_key: str # API Key - references_quote: str # 引用资料提示(当展示回答来源功能开启时,这个变量会作为引用资料名前的提示,可在provider.json中配置) + references_quote: ( + str # 引用资料提示(当展示回答来源功能开启时,这个变量会作为引用资料名前的提示,可在provider.json中配置) + ) def __init__(self, ap: app.Application, pipeline_config: dict): """初始化""" @@ -42,9 +44,7 @@ class DashScopeAPIRunner(runner.RequestRunner): # 初始化Dashscope 参数配置 self.app_id = self.pipeline_config['ai']['dashscope-app-api']['app-id'] self.api_key = self.pipeline_config['ai']['dashscope-app-api']['api-key'] - self.references_quote = self.pipeline_config['ai']['dashscope-app-api'][ - 'references_quote' - ] + self.references_quote = self.pipeline_config['ai']['dashscope-app-api']['references_quote'] def _replace_references(self, text, references_dict): """阿里云百炼平台的自定义应用支持资料引用,此函数可以将引用标签替换为参考资料""" @@ -65,9 +65,7 @@ class DashScopeAPIRunner(runner.RequestRunner): # 使用 re.sub() 进行替换 return pattern.sub(replacement, text) - async def _preprocess_user_message( - self, query: core_entities.Query - ) -> tuple[str, list[str]]: + async def _preprocess_user_message(self, query: core_entities.Query) -> tuple[str, list[str]]: """预处理用户消息,提取纯文本,阿里云提供的上传文件方法过于复杂,暂不支持上传文件(包括图片)""" plain_text = '' image_ids = [] @@ -91,9 +89,7 @@ class DashScopeAPIRunner(runner.RequestRunner): return plain_text, image_ids - async def _agent_messages( - self, query: core_entities.Query - ) -> typing.AsyncGenerator[llm_entities.Message, None]: + async def _agent_messages(self, query: core_entities.Query) -> typing.AsyncGenerator[llm_entities.Message, None]: """Dashscope 智能体对话请求""" # 局部变量 @@ -151,9 +147,7 @@ class DashScopeAPIRunner(runner.RequestRunner): content=pending_content, ) - async def _workflow_messages( - self, query: core_entities.Query - ) -> typing.AsyncGenerator[llm_entities.Message, None]: + async def _workflow_messages(self, query: core_entities.Query) -> typing.AsyncGenerator[llm_entities.Message, None]: """Dashscope 工作流对话请求""" # 局部变量 @@ -216,9 +210,7 @@ class DashScopeAPIRunner(runner.RequestRunner): content=pending_content, ) - async def run( - self, query: core_entities.Query - ) -> typing.AsyncGenerator[llm_entities.Message, None]: + async def run(self, query: core_entities.Query) -> typing.AsyncGenerator[llm_entities.Message, None]: """运行""" if self.app_type == 'agent': async for msg in self._agent_messages(query): diff --git a/pkg/provider/runners/difysvapi.py b/pkg/provider/runners/difysvapi.py index 0f14533a..26556851 100644 --- a/pkg/provider/runners/difysvapi.py +++ b/pkg/provider/runners/difysvapi.py @@ -26,10 +26,7 @@ class DifyServiceAPIRunner(runner.RequestRunner): self.pipeline_config = pipeline_config valid_app_types = ['chat', 'agent', 'workflow'] - if ( - self.pipeline_config['ai']['dify-service-api']['app-type'] - not in valid_app_types - ): + if self.pipeline_config['ai']['dify-service-api']['app-type'] not in valid_app_types: raise errors.DifyAPIError( f'不支持的 Dify 应用类型: {self.pipeline_config["ai"]["dify-service-api"]["app-type"]}' ) @@ -48,16 +45,10 @@ class DifyServiceAPIRunner(runner.RequestRunner): ): return resp_text - if ( - self.pipeline_config['ai']['dify-service-api']['thinking-convert'] - == 'original' - ): + if self.pipeline_config['ai']['dify-service-api']['thinking-convert'] == 'original': return resp_text - if ( - self.pipeline_config['ai']['dify-service-api']['thinking-convert'] - == 'remove' - ): + if self.pipeline_config['ai']['dify-service-api']['thinking-convert'] == 'remove': return re.sub( r'
Thinking... .*?
', '', @@ -65,18 +56,13 @@ class DifyServiceAPIRunner(runner.RequestRunner): flags=re.DOTALL, ) - if ( - self.pipeline_config['ai']['dify-service-api']['thinking-convert'] - == 'plain' - ): + if self.pipeline_config['ai']['dify-service-api']['thinking-convert'] == 'plain': pattern = r'
Thinking... (.*?)
' thinking_text = re.search(pattern, resp_text, flags=re.DOTALL) content_text = re.sub(pattern, '', resp_text, flags=re.DOTALL) return f'{thinking_text.group(1)}\n{content_text}' - async def _preprocess_user_message( - self, query: core_entities.Query - ) -> tuple[str, list[str]]: + async def _preprocess_user_message(self, query: core_entities.Query) -> tuple[str, list[str]]: """预处理用户消息,提取纯文本,并将图片上传到 Dify 服务 Returns: @@ -90,9 +76,7 @@ class DifyServiceAPIRunner(runner.RequestRunner): if ce.type == 'text': plain_text += ce.text elif ce.type == 'image_base64': - image_b64, image_format = await image.extract_b64_and_format( - ce.image_base64 - ) + image_b64, image_format = await image.extract_b64_and_format(ce.image_base64) file_bytes = base64.b64decode(image_b64) file = ('img.png', file_bytes, f'image/{image_format}') file_upload_resp = await self.dify_client.upload_file( @@ -106,9 +90,7 @@ class DifyServiceAPIRunner(runner.RequestRunner): return plain_text, image_ids - async def _chat_messages( - self, query: core_entities.Query - ) -> typing.AsyncGenerator[llm_entities.Message, None]: + async def _chat_messages(self, query: core_entities.Query) -> typing.AsyncGenerator[llm_entities.Message, None]: """调用聊天助手""" cov_id = query.session.using_conversation.uuid or '' @@ -151,9 +133,7 @@ class DifyServiceAPIRunner(runner.RequestRunner): if chunk['data']['node_type'] == 'answer': yield llm_entities.Message( role='assistant', - content=self._try_convert_thinking( - chunk['data']['outputs']['answer'] - ), + content=self._try_convert_thinking(chunk['data']['outputs']['answer']), ) elif mode == 'basic': if chunk['event'] == 'message': @@ -166,9 +146,7 @@ class DifyServiceAPIRunner(runner.RequestRunner): basic_mode_pending_chunk = '' if chunk is None: - raise errors.DifyAPIError( - 'Dify API 没有返回任何响应,请检查网络连接和API配置' - ) + raise errors.DifyAPIError('Dify API 没有返回任何响应,请检查网络连接和API配置') query.session.using_conversation.uuid = chunk['conversation_id'] @@ -217,9 +195,7 @@ class DifyServiceAPIRunner(runner.RequestRunner): pending_agent_message += chunk['answer'] else: if pending_agent_message.strip() != '': - pending_agent_message = pending_agent_message.replace( - 'Action:', '' - ) + pending_agent_message = pending_agent_message.replace('Action:', '') yield llm_entities.Message( role='assistant', content=self._try_convert_thinking(pending_agent_message), @@ -227,9 +203,7 @@ class DifyServiceAPIRunner(runner.RequestRunner): pending_agent_message = '' if chunk['event'] == 'agent_thought': - if ( - chunk['tool'] != '' and chunk['observation'] != '' - ): # 工具调用结果,跳过 + if chunk['tool'] != '' and chunk['observation'] != '': # 工具调用结果,跳过 continue if chunk['tool']: @@ -258,23 +232,17 @@ class DifyServiceAPIRunner(runner.RequestRunner): yield llm_entities.Message( role='assistant', - content=[ - llm_entities.ContentElement.from_image_url(image_url) - ], + content=[llm_entities.ContentElement.from_image_url(image_url)], ) if chunk['event'] == 'error': raise errors.DifyAPIError('dify 服务错误: ' + chunk['message']) if chunk is None: - raise errors.DifyAPIError( - 'Dify API 没有返回任何响应,请检查网络连接和API配置' - ) + raise errors.DifyAPIError('Dify API 没有返回任何响应,请检查网络连接和API配置') query.session.using_conversation.uuid = chunk['conversation_id'] - async def _workflow_messages( - self, query: core_entities.Query - ) -> typing.AsyncGenerator[llm_entities.Message, None]: + async def _workflow_messages(self, query: core_entities.Query) -> typing.AsyncGenerator[llm_entities.Message, None]: """调用工作流""" if not query.session.using_conversation.uuid: @@ -315,10 +283,7 @@ class DifyServiceAPIRunner(runner.RequestRunner): continue if chunk['event'] == 'node_started': - if ( - chunk['data']['node_type'] == 'start' - or chunk['data']['node_type'] == 'end' - ): + if chunk['data']['node_type'] == 'start' or chunk['data']['node_type'] == 'end': continue msg = llm_entities.Message( @@ -349,9 +314,7 @@ class DifyServiceAPIRunner(runner.RequestRunner): yield msg - async def run( - self, query: core_entities.Query - ) -> typing.AsyncGenerator[llm_entities.Message, None]: + async def run(self, query: core_entities.Query) -> typing.AsyncGenerator[llm_entities.Message, None]: """运行请求""" if self.pipeline_config['ai']['dify-service-api']['app-type'] == 'chat': async for msg in self._chat_messages(query): diff --git a/pkg/provider/runners/localagent.py b/pkg/provider/runners/localagent.py index 175f06db..7d5e04c5 100644 --- a/pkg/provider/runners/localagent.py +++ b/pkg/provider/runners/localagent.py @@ -12,15 +12,11 @@ from .. import entities as llm_entities class LocalAgentRunner(runner.RequestRunner): """本地Agent请求运行器""" - async def run( - self, query: core_entities.Query - ) -> typing.AsyncGenerator[llm_entities.Message, None]: + async def run(self, query: core_entities.Query) -> typing.AsyncGenerator[llm_entities.Message, None]: """运行请求""" pending_tool_calls = [] - req_messages = ( - query.prompt.messages.copy() + query.messages.copy() + [query.user_message] - ) + req_messages = query.prompt.messages.copy() + query.messages.copy() + [query.user_message] # 首次请求 msg = await query.use_llm_model.requester.invoke_llm( @@ -45,9 +41,7 @@ class LocalAgentRunner(runner.RequestRunner): parameters = json.loads(func.arguments) - func_ret = await self.ap.tool_mgr.execute_func_call( - query, func.name, parameters - ) + func_ret = await self.ap.tool_mgr.execute_func_call(query, func.name, parameters) msg = llm_entities.Message( role='tool', @@ -60,9 +54,7 @@ class LocalAgentRunner(runner.RequestRunner): req_messages.append(msg) except Exception as e: # 工具调用出错,添加一个报错信息到 req_messages - err_msg = llm_entities.Message( - role='tool', content=f'err: {e}', tool_call_id=tool_call.id - ) + err_msg = llm_entities.Message(role='tool', content=f'err: {e}', tool_call_id=tool_call.id) yield err_msg diff --git a/pkg/provider/session/sessionmgr.py b/pkg/provider/session/sessionmgr.py index a0a582ad..91bec826 100644 --- a/pkg/provider/session/sessionmgr.py +++ b/pkg/provider/session/sessionmgr.py @@ -23,10 +23,7 @@ class SessionManager: async def get_session(self, query: core_entities.Query) -> core_entities.Session: """获取会话""" for session in self.session_list: - if ( - query.launcher_type == session.launcher_type - and query.launcher_id == session.launcher_id - ): + if query.launcher_type == session.launcher_type and query.launcher_id == session.launcher_id: return session session_concurrency = self.ap.instance_config.data['concurrency']['session'] diff --git a/pkg/provider/tools/loader.py b/pkg/provider/tools/loader.py index 25bb13eb..76b7d248 100644 --- a/pkg/provider/tools/loader.py +++ b/pkg/provider/tools/loader.py @@ -45,9 +45,7 @@ class ToolLoader(abc.ABC): pass @abc.abstractmethod - async def invoke_tool( - self, query: core_entities.Query, name: str, parameters: dict - ) -> typing.Any: + async def invoke_tool(self, query: core_entities.Query, name: str, parameters: dict) -> typing.Any: """执行工具调用""" pass diff --git a/pkg/provider/tools/loaders/mcp.py b/pkg/provider/tools/loaders/mcp.py index 5377709f..5c030994 100644 --- a/pkg/provider/tools/loaders/mcp.py +++ b/pkg/provider/tools/loaders/mcp.py @@ -43,15 +43,11 @@ class RuntimeMCPSession: env=self.server_config['env'], ) - stdio_transport = await self.exit_stack.enter_async_context( - stdio_client(server_params) - ) + stdio_transport = await self.exit_stack.enter_async_context(stdio_client(server_params)) stdio, write = stdio_transport - self.session = await self.exit_stack.enter_async_context( - ClientSession(stdio, write) - ) + self.session = await self.exit_stack.enter_async_context(ClientSession(stdio, write)) await self.session.initialize() @@ -66,25 +62,19 @@ class RuntimeMCPSession: sseio, write = sse_transport - self.session = await self.exit_stack.enter_async_context( - ClientSession(sseio, write) - ) + self.session = await self.exit_stack.enter_async_context(ClientSession(sseio, write)) await self.session.initialize() async def initialize(self): - self.ap.logger.debug( - f'初始化 MCP 会话: {self.server_name} {self.server_config}' - ) + self.ap.logger.debug(f'初始化 MCP 会话: {self.server_name} {self.server_config}') if self.server_config['mode'] == 'stdio': await self._init_stdio_python_server() elif self.server_config['mode'] == 'sse': await self._init_sse_server() else: - raise ValueError( - f'无法识别 MCP 服务器类型: {self.server_name}: {self.server_config}' - ) + raise ValueError(f'无法识别 MCP 服务器类型: {self.server_name}: {self.server_config}') tools = await self.session.list_tools() @@ -132,9 +122,7 @@ class MCPLoader(loader.ToolLoader): self._last_listed_functions = [] async def initialize(self): - for server_config in self.ap.instance_config.data.get('mcp', {}).get( - 'servers', [] - ): + for server_config in self.ap.instance_config.data.get('mcp', {}).get('servers', []): if not server_config['enable']: continue session = RuntimeMCPSession(server_config['name'], server_config, self.ap) @@ -155,9 +143,7 @@ class MCPLoader(loader.ToolLoader): async def has_tool(self, name: str) -> bool: return name in [f.name for f in self._last_listed_functions] - async def invoke_tool( - self, query: core_entities.Query, name: str, parameters: dict - ) -> typing.Any: + async def invoke_tool(self, query: core_entities.Query, name: str, parameters: dict) -> typing.Any: for server_name, session in self.sessions.items(): for function in session.functions: if function.name == name: diff --git a/pkg/provider/tools/loaders/plugin.py b/pkg/provider/tools/loaders/plugin.py index 2d7cda20..b7df2d67 100644 --- a/pkg/provider/tools/loaders/plugin.py +++ b/pkg/provider/tools/loaders/plugin.py @@ -48,9 +48,7 @@ class PluginToolLoader(loader.ToolLoader): return function, plugin.plugin_inst return None, None - async def invoke_tool( - self, query: core_entities.Query, name: str, parameters: dict - ) -> typing.Any: + async def invoke_tool(self, query: core_entities.Query, name: str, parameters: dict) -> typing.Any: try: function, plugin = await self._get_function_and_plugin(name) if function is None: diff --git a/pkg/provider/tools/toolmgr.py b/pkg/provider/tools/toolmgr.py index 0f6fdac0..b1d43d08 100644 --- a/pkg/provider/tools/toolmgr.py +++ b/pkg/provider/tools/toolmgr.py @@ -28,9 +28,7 @@ class ToolManager: await loader_inst.initialize() self.loaders.append(loader_inst) - async def get_all_functions( - self, plugin_enabled: bool = None - ) -> list[entities.LLMFunction]: + async def get_all_functions(self, plugin_enabled: bool = None) -> list[entities.LLMFunction]: """获取所有函数""" all_functions: list[entities.LLMFunction] = [] @@ -39,9 +37,7 @@ class ToolManager: return all_functions - async def generate_tools_for_openai( - self, use_funcs: list[entities.LLMFunction] - ) -> list: + async def generate_tools_for_openai(self, use_funcs: list[entities.LLMFunction]) -> list: """生成函数列表""" tools = [] @@ -58,9 +54,7 @@ class ToolManager: return tools - async def generate_tools_for_anthropic( - self, use_funcs: list[entities.LLMFunction] - ) -> list: + async def generate_tools_for_anthropic(self, use_funcs: list[entities.LLMFunction]) -> list: """为anthropic生成函数列表 e.g. @@ -95,9 +89,7 @@ class ToolManager: return tools - async def execute_func_call( - self, query: core_entities.Query, name: str, parameters: dict - ) -> typing.Any: + async def execute_func_call(self, query: core_entities.Query, name: str, parameters: dict) -> typing.Any: """执行函数调用""" for loader in self.loaders: diff --git a/pkg/utils/announce.py b/pkg/utils/announce.py index c78f078e..7108a08c 100644 --- a/pkg/utils/announce.py +++ b/pkg/utils/announce.py @@ -59,9 +59,7 @@ class AnnouncementManager: async def fetch_saved(self) -> list[Announcement]: if not os.path.exists('data/labels/announcement_saved.json'): - with open( - 'data/labels/announcement_saved.json', 'w', encoding='utf-8' - ) as f: + with open('data/labels/announcement_saved.json', 'w', encoding='utf-8') as f: f.write('[]') with open('data/labels/announcement_saved.json', 'r', encoding='utf-8') as f: @@ -74,11 +72,7 @@ class AnnouncementManager: async def write_saved(self, content: list[Announcement]): with open('data/labels/announcement_saved.json', 'w', encoding='utf-8') as f: - f.write( - json.dumps( - [item.to_dict() for item in content], indent=4, ensure_ascii=False - ) - ) + f.write(json.dumps([item.to_dict() for item in content], indent=4, ensure_ascii=False)) async def fetch_new(self) -> list[Announcement]: """获取新公告""" diff --git a/pkg/utils/image.py b/pkg/utils/image.py index b8100273..86230df8 100644 --- a/pkg/utils/image.py +++ b/pkg/utils/image.py @@ -57,9 +57,7 @@ async def get_gewechat_image_base64( ) as response: if response.status != 200: # print(response) - raise Exception( - f'获取gewechat图片下载失败: {await response.text()}' - ) + raise Exception(f'获取gewechat图片下载失败: {await response.text()}') resp_data = await response.json() if resp_data.get('ret') != 200: @@ -79,9 +77,7 @@ async def get_gewechat_image_base64( try: async with session.get(download_url) as img_response: if img_response.status != 200: - raise Exception( - f'下载图片失败: {await img_response.text()}, URL: {download_url}' - ) + raise Exception(f'下载图片失败: {await img_response.text()}, URL: {download_url}') image_data = await img_response.read() @@ -128,9 +124,7 @@ async def get_wecom_image_base64(pic_url: str) -> tuple[str, str]: return image_base64, image_format -async def get_qq_official_image_base64( - pic_url: str, content_type: str -) -> tuple[str, str]: +async def get_qq_official_image_base64(pic_url: str, content_type: str) -> tuple[str, str]: """ 下载QQ官方图片, 并且转换为base64格式 diff --git a/pkg/utils/importutil.py b/pkg/utils/importutil.py index 87ca652a..ad93e9f7 100644 --- a/pkg/utils/importutil.py +++ b/pkg/utils/importutil.py @@ -29,9 +29,7 @@ def import_dir(path: str): for file in os.listdir(path): if file.endswith('.py') and file != '__init__.py': full_path = os.path.join(path, file) - rel_path = full_path.replace( - os.path.dirname(os.path.dirname(os.path.dirname(__file__))), '' - ) + rel_path = full_path.replace(os.path.dirname(os.path.dirname(os.path.dirname(__file__))), '') rel_path = rel_path[1:] rel_path = rel_path.replace('/', '.')[:-3] importlib.import_module(rel_path) diff --git a/pkg/utils/ip.py b/pkg/utils/ip.py index 56a12086..c67fe687 100644 --- a/pkg/utils/ip.py +++ b/pkg/utils/ip.py @@ -3,9 +3,7 @@ import aiohttp async def get_myip() -> str: try: - async with aiohttp.ClientSession( - timeout=aiohttp.ClientTimeout(total=10) - ) as session: + async with aiohttp.ClientSession(timeout=aiohttp.ClientTimeout(total=10)) as session: async with session.get('https://ip.useragentinfo.com/myip') as response: return await response.text() except Exception: diff --git a/pkg/utils/proxy.py b/pkg/utils/proxy.py index 4f3f7dec..04160082 100644 --- a/pkg/utils/proxy.py +++ b/pkg/utils/proxy.py @@ -23,20 +23,10 @@ class ProxyManager: 'https://': os.getenv('HTTPS_PROXY') or os.getenv('https_proxy'), } - if ( - 'http' in self.ap.instance_config.data['proxy'] - and self.ap.instance_config.data['proxy']['http'] - ): - self.forward_proxies['http://'] = self.ap.instance_config.data['proxy'][ - 'http' - ] - if ( - 'https' in self.ap.instance_config.data['proxy'] - and self.ap.instance_config.data['proxy']['https'] - ): - self.forward_proxies['https://'] = self.ap.instance_config.data['proxy'][ - 'https' - ] + if 'http' in self.ap.instance_config.data['proxy'] and self.ap.instance_config.data['proxy']['http']: + self.forward_proxies['http://'] = self.ap.instance_config.data['proxy']['http'] + if 'https' in self.ap.instance_config.data['proxy'] and self.ap.instance_config.data['proxy']['https']: + self.forward_proxies['https://'] = self.ap.instance_config.data['proxy']['https'] # 设置到环境变量 os.environ['HTTP_PROXY'] = self.forward_proxies['http://'] or '' diff --git a/pkg/utils/version.py b/pkg/utils/version.py index 0e0516eb..46c1aad6 100644 --- a/pkg/utils/version.py +++ b/pkg/utils/version.py @@ -60,18 +60,14 @@ class VersionManager: latest_rls = rls self.ap.logger.info('更新日志: {}'.format(rls_notes)) - if latest_rls == {} and not self.is_newer( - latest_tag_name, current_tag - ): # 没有新版本 + if latest_rls == {} and not self.is_newer(latest_tag_name, current_tag): # 没有新版本 return False # 下载最新版本的zip到temp目录 self.ap.logger.info('开始下载最新版本: {}'.format(latest_rls['zipball_url'])) zip_url = latest_rls['zipball_url'] - zip_resp = requests.get( - url=zip_url, proxies=self.ap.proxy_mgr.get_forward_proxies() - ) + zip_resp = requests.get(url=zip_url, proxies=self.ap.proxy_mgr.get_forward_proxies()) zip_data = zip_resp.content # 检查temp/updater目录 @@ -82,11 +78,7 @@ class VersionManager: with open('temp/updater/{}.zip'.format(latest_rls['tag_name']), 'wb') as f: f.write(zip_data) - self.ap.logger.info( - '下载最新版本完成: {}'.format( - 'temp/updater/{}.zip'.format(latest_rls['tag_name']) - ) - ) + self.ap.logger.info('下载最新版本完成: {}'.format('temp/updater/{}.zip'.format(latest_rls['tag_name']))) # 解压zip到temp/updater// import zipfile @@ -97,17 +89,13 @@ class VersionManager: shutil.rmtree('temp/updater/{}'.format(latest_rls['tag_name'])) os.mkdir('temp/updater/{}'.format(latest_rls['tag_name'])) - with zipfile.ZipFile( - 'temp/updater/{}.zip'.format(latest_rls['tag_name']), 'r' - ) as zip_ref: + with zipfile.ZipFile('temp/updater/{}.zip'.format(latest_rls['tag_name']), 'r') as zip_ref: zip_ref.extractall('temp/updater/{}'.format(latest_rls['tag_name'])) # 覆盖源码 source_root = '' # 找到temp/updater//中的第一个子目录路径 - for root, dirs, files in os.walk( - 'temp/updater/{}'.format(latest_rls['tag_name']) - ): + for root, dirs, files in os.walk('temp/updater/{}'.format(latest_rls['tag_name'])): if root != 'temp/updater/{}'.format(latest_rls['tag_name']): source_root = root break