mirror of
https://github.com/langbot-app/LangBot.git
synced 2026-06-02 03:55:55 +00:00
style: restrict line-length
This commit is contained in:
@@ -1,3 +1,6 @@
|
||||
|
||||
line-length = 120
|
||||
|
||||
[lint]
|
||||
|
||||
ignore = [
|
||||
|
||||
@@ -25,9 +25,7 @@ class DingTalkClient:
|
||||
self.secret = client_secret
|
||||
# 在 DingTalkClient 中传入自己作为参数,避免循环导入
|
||||
self.EchoTextHandler = EchoTextHandler(self)
|
||||
self.client.register_callback_handler(
|
||||
dingtalk_stream.chatbot.ChatbotMessage.TOPIC, self.EchoTextHandler
|
||||
)
|
||||
self.client.register_callback_handler(dingtalk_stream.chatbot.ChatbotMessage.TOPIC, self.EchoTextHandler)
|
||||
self._message_handlers = {
|
||||
'example': [],
|
||||
}
|
||||
@@ -86,9 +84,7 @@ class DingTalkClient:
|
||||
|
||||
if response.status_code == 200:
|
||||
file_bytes = response.content
|
||||
base64_str = base64.b64encode(file_bytes).decode(
|
||||
'utf-8'
|
||||
) # 返回字符串格式
|
||||
base64_str = base64.b64encode(file_bytes).decode('utf-8') # 返回字符串格式
|
||||
return base64_str
|
||||
else:
|
||||
raise Exception('获取文件失败')
|
||||
@@ -151,9 +147,7 @@ class DingTalkClient:
|
||||
for handler in self._message_handlers[msg_type]:
|
||||
await handler(event)
|
||||
|
||||
async def get_message(
|
||||
self, incoming_message: dingtalk_stream.chatbot.ChatbotMessage
|
||||
):
|
||||
async def get_message(self, incoming_message: dingtalk_stream.chatbot.ChatbotMessage):
|
||||
try:
|
||||
# print(json.dumps(incoming_message.to_dict(), indent=4, ensure_ascii=False))
|
||||
message_data = {
|
||||
@@ -170,9 +164,7 @@ class DingTalkClient:
|
||||
if 'text' in item:
|
||||
message_data['Content'] = item['text']
|
||||
if incoming_message.get_image_list()[0]:
|
||||
message_data['Picture'] = await self.download_image(
|
||||
incoming_message.get_image_list()[0]
|
||||
)
|
||||
message_data['Picture'] = await self.download_image(incoming_message.get_image_list()[0])
|
||||
message_data['Type'] = 'text'
|
||||
|
||||
elif incoming_message.message_type == 'text':
|
||||
@@ -180,15 +172,11 @@ class DingTalkClient:
|
||||
|
||||
message_data['Type'] = 'text'
|
||||
elif incoming_message.message_type == 'picture':
|
||||
message_data['Picture'] = await self.download_image(
|
||||
incoming_message.get_image_list()[0]
|
||||
)
|
||||
message_data['Picture'] = await self.download_image(incoming_message.get_image_list()[0])
|
||||
|
||||
message_data['Type'] = 'image'
|
||||
elif incoming_message.message_type == 'audio':
|
||||
message_data['Audio'] = await self.get_audio_url(
|
||||
incoming_message.to_dict()['content']['downloadCode']
|
||||
)
|
||||
message_data['Audio'] = await self.get_audio_url(incoming_message.to_dict()['content']['downloadCode'])
|
||||
|
||||
message_data['Type'] = 'audio'
|
||||
|
||||
|
||||
@@ -68,9 +68,7 @@ class OAClient:
|
||||
elif request.method == 'POST':
|
||||
encryt_msg = await request.data
|
||||
wxcpt = WXBizMsgCrypt(self.token, self.aes, self.appid)
|
||||
ret, xml_msg = wxcpt.DecryptMsg(
|
||||
encryt_msg, msg_signature, timestamp, nonce
|
||||
)
|
||||
ret, xml_msg = wxcpt.DecryptMsg(encryt_msg, msg_signature, timestamp, nonce)
|
||||
xml_msg = xml_msg.decode('utf-8')
|
||||
|
||||
if ret != 0:
|
||||
@@ -112,9 +110,7 @@ class OAClient:
|
||||
# create_time=int(time.time()),
|
||||
# content = "请求失效:暂不支持公众号超过15秒的请求,如有需求,请联系 LangBot 团队。"
|
||||
# )
|
||||
print(
|
||||
'请求失效:暂不支持公众号超过15秒的请求,如有需求,请联系 LangBot 团队。'
|
||||
)
|
||||
print('请求失效:暂不支持公众号超过15秒的请求,如有需求,请联系 LangBot 团队。')
|
||||
return ''
|
||||
|
||||
except Exception:
|
||||
@@ -128,12 +124,8 @@ class OAClient:
|
||||
'FromUserName': root.find('FromUserName').text,
|
||||
'CreateTime': int(root.find('CreateTime').text),
|
||||
'MsgType': root.find('MsgType').text,
|
||||
'Content': root.find('Content').text
|
||||
if root.find('Content') is not None
|
||||
else None,
|
||||
'MsgId': int(root.find('MsgId').text)
|
||||
if root.find('MsgId') is not None
|
||||
else None,
|
||||
'Content': root.find('Content').text if root.find('Content') is not None else None,
|
||||
'MsgId': int(root.find('MsgId').text) if root.find('MsgId') is not None else None,
|
||||
}
|
||||
|
||||
return message_data
|
||||
@@ -225,9 +217,7 @@ class OAClientForLongerResponse:
|
||||
elif request.method == 'POST':
|
||||
encryt_msg = await request.data
|
||||
wxcpt = WXBizMsgCrypt(self.token, self.aes, self.appid)
|
||||
ret, xml_msg = wxcpt.DecryptMsg(
|
||||
encryt_msg, msg_signature, timestamp, nonce
|
||||
)
|
||||
ret, xml_msg = wxcpt.DecryptMsg(encryt_msg, msg_signature, timestamp, nonce)
|
||||
xml_msg = xml_msg.decode('utf-8')
|
||||
|
||||
if ret != 0:
|
||||
@@ -238,18 +228,12 @@ class OAClientForLongerResponse:
|
||||
from_user = root.find('FromUserName').text
|
||||
to_user = root.find('ToUserName').text
|
||||
|
||||
if (
|
||||
self.msg_queue.get(from_user)
|
||||
and self.msg_queue[from_user][0]['content']
|
||||
):
|
||||
if self.msg_queue.get(from_user) and self.msg_queue[from_user][0]['content']:
|
||||
queue_top = self.msg_queue[from_user].pop(0)
|
||||
queue_content = queue_top['content']
|
||||
|
||||
# 弹出用户消息
|
||||
if (
|
||||
self.user_msg_queue.get(from_user)
|
||||
and self.user_msg_queue[from_user]
|
||||
):
|
||||
if self.user_msg_queue.get(from_user) and self.user_msg_queue[from_user]:
|
||||
self.user_msg_queue[from_user].pop(0)
|
||||
|
||||
response_xml = xml_template.format(
|
||||
@@ -268,10 +252,7 @@ class OAClientForLongerResponse:
|
||||
content=self.loading_message,
|
||||
)
|
||||
|
||||
if (
|
||||
self.user_msg_queue.get(from_user)
|
||||
and self.user_msg_queue[from_user][0]['content']
|
||||
):
|
||||
if self.user_msg_queue.get(from_user) and self.user_msg_queue[from_user][0]['content']:
|
||||
return response_xml
|
||||
else:
|
||||
message_data = await self.get_message(xml_msg)
|
||||
@@ -299,12 +280,8 @@ class OAClientForLongerResponse:
|
||||
'FromUserName': root.find('FromUserName').text,
|
||||
'CreateTime': int(root.find('CreateTime').text),
|
||||
'MsgType': root.find('MsgType').text,
|
||||
'Content': root.find('Content').text
|
||||
if root.find('Content') is not None
|
||||
else None,
|
||||
'MsgId': int(root.find('MsgId').text)
|
||||
if root.find('MsgId') is not None
|
||||
else None,
|
||||
'Content': root.find('Content').text if root.find('Content') is not None else None,
|
||||
'MsgId': int(root.find('MsgId').text) if root.find('MsgId') is not None else None,
|
||||
}
|
||||
|
||||
return message_data
|
||||
|
||||
@@ -144,15 +144,9 @@ class QQOfficialClient:
|
||||
'group_openid': msg.get('d', {}).get('group_openid', {}),
|
||||
}
|
||||
attachments = msg.get('d', {}).get('attachments', [])
|
||||
image_attachments = [
|
||||
attachment['url']
|
||||
for attachment in attachments
|
||||
if await self.is_image(attachment)
|
||||
]
|
||||
image_attachments = [attachment['url'] for attachment in attachments if await self.is_image(attachment)]
|
||||
image_attachments_type = [
|
||||
attachment['content_type']
|
||||
for attachment in attachments
|
||||
if await self.is_image(attachment)
|
||||
attachment['content_type'] for attachment in attachments if await self.is_image(attachment)
|
||||
]
|
||||
if image_attachments:
|
||||
message_data['image_attachments'] = image_attachments[0]
|
||||
@@ -211,9 +205,7 @@ class QQOfficialClient:
|
||||
else:
|
||||
raise Exception(response.read().decode())
|
||||
|
||||
async def send_channle_group_text_msg(
|
||||
self, channel_id: str, content: str, msg_id: str
|
||||
):
|
||||
async def send_channle_group_text_msg(self, channel_id: str, content: str, msg_id: str):
|
||||
"""发送频道群聊消息"""
|
||||
if not await self.check_access_token():
|
||||
await self.get_access_token()
|
||||
@@ -235,9 +227,7 @@ class QQOfficialClient:
|
||||
else:
|
||||
raise Exception(response)
|
||||
|
||||
async def send_channle_private_text_msg(
|
||||
self, guild_id: str, content: str, msg_id: str
|
||||
):
|
||||
async def send_channle_private_text_msg(self, guild_id: str, content: str, msg_id: str):
|
||||
"""发送频道私聊消息"""
|
||||
if not await self.check_access_token():
|
||||
await self.get_access_token()
|
||||
|
||||
@@ -1,59 +1,57 @@
|
||||
import json
|
||||
from quart import Quart, jsonify,request
|
||||
from quart import Quart, jsonify, request
|
||||
from slack_sdk.web.async_client import AsyncWebClient
|
||||
from .slackevent import SlackEvent
|
||||
from typing import Callable, Dict, Any
|
||||
from pkg.platform.types import events as platform_events, message as platform_message
|
||||
|
||||
class SlackClient():
|
||||
|
||||
def __init__(self,bot_token:str,signing_secret:str):
|
||||
|
||||
self.bot_token = bot_token
|
||||
self.signing_secret = signing_secret
|
||||
self.app = Quart(__name__)
|
||||
self.client = AsyncWebClient(self.bot_token)
|
||||
self.app.add_url_rule('/callback/command', 'handle_callback', self.handle_callback_request, methods=['GET', 'POST'])
|
||||
self._message_handlers = {
|
||||
"example":[],
|
||||
}
|
||||
self.bot_user_id = None # 避免机器人回复自己的消息
|
||||
|
||||
async def handle_callback_request(self):
|
||||
try:
|
||||
body = await request.get_data()
|
||||
data = json.loads(body)
|
||||
if 'type' in data:
|
||||
if data['type'] == 'url_verification':
|
||||
return data['challenge']
|
||||
|
||||
bot_user_id = data.get("event",{}).get("bot_id","")
|
||||
|
||||
if self.bot_user_id and bot_user_id == self.bot_user_id:
|
||||
return jsonify({'status': 'ok'})
|
||||
|
||||
|
||||
# 处理私信
|
||||
if data and data.get("event", {}).get("channel_type") in ["im"]:
|
||||
event = SlackEvent.from_payload(data)
|
||||
await self._handle_message(event)
|
||||
return jsonify({'status': 'ok'})
|
||||
|
||||
#处理群聊
|
||||
if data.get("event",{}).get("type") == 'app_mention':
|
||||
data.setdefault("event", {})["channel_type"] = "channel"
|
||||
event = SlackEvent.from_payload(data)
|
||||
await self._handle_message(event)
|
||||
return jsonify({'status':'ok'})
|
||||
|
||||
return jsonify({'status': 'ok'})
|
||||
|
||||
except Exception as e:
|
||||
raise(e)
|
||||
|
||||
from typing import Callable
|
||||
from pkg.platform.types import events as platform_events
|
||||
|
||||
|
||||
async def _handle_message(self, event: SlackEvent):
|
||||
class SlackClient:
|
||||
def __init__(self, bot_token: str, signing_secret: str):
|
||||
self.bot_token = bot_token
|
||||
self.signing_secret = signing_secret
|
||||
self.app = Quart(__name__)
|
||||
self.client = AsyncWebClient(self.bot_token)
|
||||
self.app.add_url_rule(
|
||||
'/callback/command', 'handle_callback', self.handle_callback_request, methods=['GET', 'POST']
|
||||
)
|
||||
self._message_handlers = {
|
||||
'example': [],
|
||||
}
|
||||
self.bot_user_id = None # 避免机器人回复自己的消息
|
||||
|
||||
async def handle_callback_request(self):
|
||||
try:
|
||||
body = await request.get_data()
|
||||
data = json.loads(body)
|
||||
if 'type' in data:
|
||||
if data['type'] == 'url_verification':
|
||||
return data['challenge']
|
||||
|
||||
bot_user_id = data.get('event', {}).get('bot_id', '')
|
||||
|
||||
if self.bot_user_id and bot_user_id == self.bot_user_id:
|
||||
return jsonify({'status': 'ok'})
|
||||
|
||||
# 处理私信
|
||||
if data and data.get('event', {}).get('channel_type') in ['im']:
|
||||
event = SlackEvent.from_payload(data)
|
||||
await self._handle_message(event)
|
||||
return jsonify({'status': 'ok'})
|
||||
|
||||
# 处理群聊
|
||||
if data.get('event', {}).get('type') == 'app_mention':
|
||||
data.setdefault('event', {})['channel_type'] = 'channel'
|
||||
event = SlackEvent.from_payload(data)
|
||||
await self._handle_message(event)
|
||||
return jsonify({'status': 'ok'})
|
||||
|
||||
return jsonify({'status': 'ok'})
|
||||
|
||||
except Exception as e:
|
||||
raise (e)
|
||||
|
||||
async def _handle_message(self, event: SlackEvent):
|
||||
"""
|
||||
处理消息事件。
|
||||
"""
|
||||
@@ -62,50 +60,38 @@ class SlackClient():
|
||||
for handler in self._message_handlers[msg_type]:
|
||||
await handler(event)
|
||||
|
||||
def on_message(self, msg_type: str):
|
||||
def on_message(self, msg_type: str):
|
||||
"""注册消息类型处理器"""
|
||||
|
||||
def decorator(func: Callable[[platform_events.Event], None]):
|
||||
if msg_type not in self._message_handlers:
|
||||
self._message_handlers[msg_type] = []
|
||||
self._message_handlers[msg_type].append(func)
|
||||
return func
|
||||
|
||||
return decorator
|
||||
|
||||
async def send_message_to_channel(self,text:str,channel_id:str):
|
||||
try:
|
||||
response = await self.client.chat_postMessage(
|
||||
channel=channel_id,
|
||||
text=text
|
||||
)
|
||||
if self.bot_user_id is None and response.get("ok"):
|
||||
self.bot_user_id = response["message"]["bot_id"]
|
||||
return
|
||||
except Exception as e:
|
||||
raise e
|
||||
async def send_message_to_channel(self, text: str, channel_id: str):
|
||||
try:
|
||||
response = await self.client.chat_postMessage(channel=channel_id, text=text)
|
||||
if self.bot_user_id is None and response.get('ok'):
|
||||
self.bot_user_id = response['message']['bot_id']
|
||||
return
|
||||
except Exception as e:
|
||||
raise e
|
||||
|
||||
async def send_message_to_one(self,text:str,user_id:str):
|
||||
try:
|
||||
response = await self.client.chat_postMessage(
|
||||
channel = '@'+user_id,
|
||||
text= text
|
||||
)
|
||||
if self.bot_user_id is None and response.get("ok"):
|
||||
self.bot_user_id = response["message"]["bot_id"]
|
||||
|
||||
return
|
||||
except Exception as e:
|
||||
raise e
|
||||
|
||||
async def run_task(self, host: str, port: int, *args, **kwargs):
|
||||
"""
|
||||
启动 Quart 应用。
|
||||
"""
|
||||
await self.app.run_task(host=host, port=port, *args, **kwargs)
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
async def send_message_to_one(self, text: str, user_id: str):
|
||||
try:
|
||||
response = await self.client.chat_postMessage(channel='@' + user_id, text=text)
|
||||
if self.bot_user_id is None and response.get('ok'):
|
||||
self.bot_user_id = response['message']['bot_id']
|
||||
|
||||
return
|
||||
except Exception as e:
|
||||
raise e
|
||||
|
||||
async def run_task(self, host: str, port: int, *args, **kwargs):
|
||||
"""
|
||||
启动 Quart 应用。
|
||||
"""
|
||||
await self.app.run_task(host=host, port=port, *args, **kwargs)
|
||||
|
||||
@@ -1,86 +1,82 @@
|
||||
from typing import Dict, Any, Optional
|
||||
|
||||
|
||||
class SlackEvent(dict):
|
||||
@staticmethod
|
||||
def from_payload(payload: Dict[str, Any]) -> Optional["SlackEvent"]:
|
||||
def from_payload(payload: Dict[str, Any]) -> Optional['SlackEvent']:
|
||||
try:
|
||||
event = SlackEvent(payload)
|
||||
return event
|
||||
except KeyError:
|
||||
return None
|
||||
|
||||
|
||||
@property
|
||||
def text(self) -> str:
|
||||
|
||||
if self.get("event", {}).get("channel_type") == "im":
|
||||
blocks = self.get("event", {}).get("blocks", [])
|
||||
if not blocks:
|
||||
return ""
|
||||
if self.get('event', {}).get('channel_type') == 'im':
|
||||
blocks = self.get('event', {}).get('blocks', [])
|
||||
if not blocks:
|
||||
return ''
|
||||
|
||||
elements = blocks[0].get("elements", [])
|
||||
if not elements:
|
||||
return ""
|
||||
elements = blocks[0].get('elements', [])
|
||||
if not elements:
|
||||
return ''
|
||||
|
||||
elements = elements[0].get("elements", [])
|
||||
text = ""
|
||||
elements = elements[0].get('elements', [])
|
||||
text = ''
|
||||
|
||||
for el in elements:
|
||||
if el.get("type") == "text":
|
||||
text += el.get("text", "")
|
||||
elif el.get("type") == "link":
|
||||
text += el.get("url", "")
|
||||
for el in elements:
|
||||
if el.get('type') == 'text':
|
||||
text += el.get('text', '')
|
||||
elif el.get('type') == 'link':
|
||||
text += el.get('url', '')
|
||||
|
||||
return text
|
||||
return text
|
||||
|
||||
|
||||
if self.get("event",{}).get("channel_type") == 'channel':
|
||||
message_text = ""
|
||||
for block in self.get("event", {}).get("blocks", []):
|
||||
if block.get("type") == "rich_text":
|
||||
for element in block.get("elements", []):
|
||||
if element.get("type") == "rich_text_section":
|
||||
if self.get('event', {}).get('channel_type') == 'channel':
|
||||
message_text = ''
|
||||
for block in self.get('event', {}).get('blocks', []):
|
||||
if block.get('type') == 'rich_text':
|
||||
for element in block.get('elements', []):
|
||||
if element.get('type') == 'rich_text_section':
|
||||
parts = []
|
||||
for el in element.get("elements", []):
|
||||
if el.get("type") == "text":
|
||||
parts.append(el["text"])
|
||||
elif el.get("type") == "link":
|
||||
parts.append(el["url"])
|
||||
message_text = "".join(parts)
|
||||
|
||||
for el in element.get('elements', []):
|
||||
if el.get('type') == 'text':
|
||||
parts.append(el['text'])
|
||||
elif el.get('type') == 'link':
|
||||
parts.append(el['url'])
|
||||
message_text = ''.join(parts)
|
||||
|
||||
return message_text
|
||||
|
||||
|
||||
|
||||
@property
|
||||
def user_id(self) -> Optional[str]:
|
||||
return self.get("event", {}).get("user","")
|
||||
|
||||
return self.get('event', {}).get('user', '')
|
||||
|
||||
@property
|
||||
def channel_id(self) -> Optional[str]:
|
||||
return self.get("event", {}).get("channel","")
|
||||
|
||||
return self.get('event', {}).get('channel', '')
|
||||
|
||||
@property
|
||||
def type(self) -> str:
|
||||
""" message对应私聊,app_mention对应频道at """
|
||||
return self.get("event", {}).get("channel_type", "")
|
||||
|
||||
"""message对应私聊,app_mention对应频道at"""
|
||||
return self.get('event', {}).get('channel_type', '')
|
||||
|
||||
@property
|
||||
def message_id(self) -> str:
|
||||
return self.get("event_id","")
|
||||
|
||||
return self.get('event_id', '')
|
||||
|
||||
@property
|
||||
def pic_url(self) -> str:
|
||||
"""提取 Slack 事件中的图片 URL"""
|
||||
files = self.get("event", {}).get("files", [])
|
||||
files = self.get('event', {}).get('files', [])
|
||||
if files:
|
||||
return files[0].get("url_private", "")
|
||||
return files[0].get('url_private', '')
|
||||
return None
|
||||
|
||||
|
||||
|
||||
@property
|
||||
def sender_name(self) -> str:
|
||||
return self.get("event", {}).get("user","")
|
||||
|
||||
return self.get('event', {}).get('user', '')
|
||||
|
||||
def __getattr__(self, key: str) -> Optional[Any]:
|
||||
return self.get(key)
|
||||
|
||||
@@ -88,4 +84,4 @@ class SlackEvent(dict):
|
||||
self[key] = value
|
||||
|
||||
def __repr__(self) -> str:
|
||||
return f"<SlackEvent {super().__repr__()}>"
|
||||
return f'<SlackEvent {super().__repr__()}>'
|
||||
|
||||
@@ -147,12 +147,7 @@ class Prpcrypt(object):
|
||||
"""
|
||||
# 16位随机字符串添加到明文开头
|
||||
text = text.encode()
|
||||
text = (
|
||||
self.get_random_str()
|
||||
+ struct.pack('I', socket.htonl(len(text)))
|
||||
+ text
|
||||
+ receiveid.encode()
|
||||
)
|
||||
text = self.get_random_str() + struct.pack('I', socket.htonl(len(text))) + text + receiveid.encode()
|
||||
|
||||
# 使用自定义的填充方式对明文进行补位填充
|
||||
pkcs7 = PKCS7Encoder()
|
||||
|
||||
@@ -45,9 +45,7 @@ class WecomClient:
|
||||
return bool(self.access_token and self.access_token.strip())
|
||||
|
||||
async def check_access_token_for_contacts(self):
|
||||
return bool(
|
||||
self.access_token_for_contacts and self.access_token_for_contacts.strip()
|
||||
)
|
||||
return bool(self.access_token_for_contacts and self.access_token_for_contacts.strip())
|
||||
|
||||
async def get_access_token(self, secret):
|
||||
url = f'https://qyapi.weixin.qq.com/cgi-bin/gettoken?corpid={self.corpid}&corpsecret={secret}'
|
||||
@@ -61,15 +59,9 @@ class WecomClient:
|
||||
|
||||
async def get_users(self):
|
||||
if not self.check_access_token_for_contacts():
|
||||
self.access_token_for_contacts = await self.get_access_token(
|
||||
self.secret_for_contacts
|
||||
)
|
||||
self.access_token_for_contacts = await self.get_access_token(self.secret_for_contacts)
|
||||
|
||||
url = (
|
||||
self.base_url
|
||||
+ '/user/list_id?access_token='
|
||||
+ self.access_token_for_contacts
|
||||
)
|
||||
url = self.base_url + '/user/list_id?access_token=' + self.access_token_for_contacts
|
||||
async with httpx.AsyncClient() as client:
|
||||
params = {
|
||||
'cursor': '',
|
||||
@@ -88,15 +80,9 @@ class WecomClient:
|
||||
|
||||
async def send_to_all(self, content: str, agent_id: int):
|
||||
if not self.check_access_token_for_contacts():
|
||||
self.access_token_for_contacts = await self.get_access_token(
|
||||
self.secret_for_contacts
|
||||
)
|
||||
self.access_token_for_contacts = await self.get_access_token(self.secret_for_contacts)
|
||||
|
||||
url = (
|
||||
self.base_url
|
||||
+ '/message/send?access_token='
|
||||
+ self.access_token_for_contacts
|
||||
)
|
||||
url = self.base_url + '/message/send?access_token=' + self.access_token_for_contacts
|
||||
user_ids = await self.get_users()
|
||||
user_ids_string = '|'.join(user_ids)
|
||||
async with httpx.AsyncClient() as client:
|
||||
@@ -187,27 +173,21 @@ class WecomClient:
|
||||
|
||||
if request.method == 'GET':
|
||||
echostr = request.args.get('echostr')
|
||||
ret, reply_echo_str = self.wxcpt.VerifyURL(
|
||||
msg_signature, timestamp, nonce, echostr
|
||||
)
|
||||
ret, reply_echo_str = self.wxcpt.VerifyURL(msg_signature, timestamp, nonce, echostr)
|
||||
if ret != 0:
|
||||
raise Exception(f'验证失败,错误码: {ret}')
|
||||
return reply_echo_str
|
||||
|
||||
elif request.method == 'POST':
|
||||
encrypt_msg = await request.data
|
||||
ret, xml_msg = self.wxcpt.DecryptMsg(
|
||||
encrypt_msg, msg_signature, timestamp, nonce
|
||||
)
|
||||
ret, xml_msg = self.wxcpt.DecryptMsg(encrypt_msg, msg_signature, timestamp, nonce)
|
||||
if ret != 0:
|
||||
raise Exception(f'消息解密失败,错误码: {ret}')
|
||||
|
||||
# 解析消息并处理
|
||||
message_data = await self.get_message(xml_msg)
|
||||
if message_data:
|
||||
event = WecomEvent.from_payload(
|
||||
message_data
|
||||
) # 转换为 WecomEvent 对象
|
||||
event = WecomEvent.from_payload(message_data) # 转换为 WecomEvent 对象
|
||||
if event:
|
||||
await self._handle_message(event)
|
||||
|
||||
@@ -253,23 +233,13 @@ class WecomClient:
|
||||
'FromUserName': root.find('FromUserName').text,
|
||||
'CreateTime': int(root.find('CreateTime').text),
|
||||
'MsgType': root.find('MsgType').text,
|
||||
'Content': root.find('Content').text
|
||||
if root.find('Content') is not None
|
||||
else None,
|
||||
'MsgId': int(root.find('MsgId').text)
|
||||
if root.find('MsgId') is not None
|
||||
else None,
|
||||
'AgentID': int(root.find('AgentID').text)
|
||||
if root.find('AgentID') is not None
|
||||
else None,
|
||||
'Content': root.find('Content').text if root.find('Content') is not None else None,
|
||||
'MsgId': int(root.find('MsgId').text) if root.find('MsgId') is not None else None,
|
||||
'AgentID': int(root.find('AgentID').text) if root.find('AgentID') is not None else None,
|
||||
}
|
||||
if message_data['MsgType'] == 'image':
|
||||
message_data['MediaId'] = (
|
||||
root.find('MediaId').text if root.find('MediaId') is not None else None
|
||||
)
|
||||
message_data['PicUrl'] = (
|
||||
root.find('PicUrl').text if root.find('PicUrl') is not None else None
|
||||
)
|
||||
message_data['MediaId'] = root.find('MediaId').text if root.find('MediaId') is not None else None
|
||||
message_data['PicUrl'] = root.find('PicUrl').text if root.find('PicUrl') is not None else None
|
||||
|
||||
return message_data
|
||||
|
||||
@@ -298,12 +268,7 @@ class WecomClient:
|
||||
if not await self.check_access_token():
|
||||
self.access_token = await self.get_access_token(self.secret)
|
||||
|
||||
url = (
|
||||
self.base_url
|
||||
+ '/media/upload?access_token='
|
||||
+ self.access_token
|
||||
+ '&type=file'
|
||||
)
|
||||
url = self.base_url + '/media/upload?access_token=' + self.access_token + '&type=file'
|
||||
file_bytes = None
|
||||
file_name = 'uploaded_file.txt'
|
||||
|
||||
|
||||
@@ -6,60 +6,61 @@ import httpx
|
||||
import traceback
|
||||
from quart import Quart
|
||||
import xml.etree.ElementTree as ET
|
||||
from typing import Callable, Dict, Any
|
||||
from typing import Callable
|
||||
from .wecomcsevent import WecomCSEvent
|
||||
from pkg.platform.types import events as platform_events, message as platform_message
|
||||
from pkg.platform.types import message as platform_message
|
||||
import aiofiles
|
||||
|
||||
|
||||
class WecomCSClient():
|
||||
def __init__(self,corpid:str,secret:str,token:str,EncodingAESKey:str):
|
||||
class WecomCSClient:
|
||||
def __init__(self, corpid: str, secret: str, token: str, EncodingAESKey: str):
|
||||
self.corpid = corpid
|
||||
self.secret = secret
|
||||
self.access_token_for_contacts =''
|
||||
self.access_token_for_contacts = ''
|
||||
self.token = token
|
||||
self.aes = EncodingAESKey
|
||||
self.base_url = 'https://qyapi.weixin.qq.com/cgi-bin'
|
||||
self.access_token = ''
|
||||
self.app = Quart(__name__)
|
||||
self.wxcpt = WXBizMsgCrypt(self.token, self.aes, self.corpid)
|
||||
self.app.add_url_rule('/callback/command', 'handle_callback', self.handle_callback_request, methods=['GET', 'POST'])
|
||||
self.app.add_url_rule(
|
||||
'/callback/command', 'handle_callback', self.handle_callback_request, methods=['GET', 'POST']
|
||||
)
|
||||
self._message_handlers = {
|
||||
"example":[],
|
||||
'example': [],
|
||||
}
|
||||
|
||||
async def get_pic_url(self, media_id: str):
|
||||
if not await self.check_access_token():
|
||||
self.access_token = await self.get_access_token(self.secret)
|
||||
|
||||
url = f"{self.base_url}/media/get?access_token={self.access_token}&media_id={media_id}"
|
||||
url = f'{self.base_url}/media/get?access_token={self.access_token}&media_id={media_id}'
|
||||
|
||||
async with httpx.AsyncClient() as client:
|
||||
response = await client.get(url)
|
||||
if response.headers.get("Content-Type", "").startswith("application/json"):
|
||||
if response.headers.get('Content-Type', '').startswith('application/json'):
|
||||
data = response.json()
|
||||
if data.get('errcode') in [40014, 42001]:
|
||||
self.access_token = await self.get_access_token(self.secret)
|
||||
return await self.get_pic_url(media_id)
|
||||
else:
|
||||
raise Exception("Failed to get image: " + str(data))
|
||||
raise Exception('Failed to get image: ' + str(data))
|
||||
|
||||
# 否则是图片,转成 base64
|
||||
image_bytes = response.content
|
||||
content_type = response.headers.get("Content-Type", "")
|
||||
base64_str = base64.b64encode(image_bytes).decode("utf-8")
|
||||
base64_str = f"data:{content_type};base64,{base64_str}"
|
||||
content_type = response.headers.get('Content-Type', '')
|
||||
base64_str = base64.b64encode(image_bytes).decode('utf-8')
|
||||
base64_str = f'data:{content_type};base64,{base64_str}'
|
||||
return base64_str
|
||||
|
||||
|
||||
#access——token操作
|
||||
# access——token操作
|
||||
async def check_access_token(self):
|
||||
return bool(self.access_token and self.access_token.strip())
|
||||
|
||||
async def check_access_token_for_contacts(self):
|
||||
return bool(self.access_token_for_contacts and self.access_token_for_contacts.strip())
|
||||
|
||||
async def get_access_token(self,secret):
|
||||
async def get_access_token(self, secret):
|
||||
url = f'https://qyapi.weixin.qq.com/cgi-bin/gettoken?corpid={self.corpid}&corpsecret={secret}'
|
||||
async with httpx.AsyncClient() as client:
|
||||
response = await client.get(url)
|
||||
@@ -67,118 +68,115 @@ class WecomCSClient():
|
||||
if 'access_token' in data:
|
||||
return data['access_token']
|
||||
else:
|
||||
raise Exception(f"未获取access token: {data}")
|
||||
|
||||
async def get_detailed_message_list(self,xml_msg:str):
|
||||
raise Exception(f'未获取access token: {data}')
|
||||
|
||||
async def get_detailed_message_list(self, xml_msg: str):
|
||||
# 在本方法中解析消息,并且获得消息的具体内容
|
||||
if isinstance(xml_msg, bytes):
|
||||
xml_msg = xml_msg.decode('utf-8')
|
||||
root = ET.fromstring(xml_msg)
|
||||
token = root.find("Token").text
|
||||
open_kfid = root.find("OpenKfId").text
|
||||
|
||||
token = root.find('Token').text
|
||||
open_kfid = root.find('OpenKfId').text
|
||||
|
||||
# if open_kfid in self.openkfid_list:
|
||||
# return None
|
||||
# else:
|
||||
# self.openkfid_list.append(open_kfid)
|
||||
|
||||
|
||||
if not await self.check_access_token():
|
||||
self.access_token = await self.get_access_token(self.secret)
|
||||
|
||||
url = self.base_url+'/kf/sync_msg?access_token='+ self.access_token
|
||||
|
||||
url = self.base_url + '/kf/sync_msg?access_token=' + self.access_token
|
||||
async with httpx.AsyncClient() as client:
|
||||
params = {
|
||||
"token": token,
|
||||
"voice_format": 0,
|
||||
"open_kfid": open_kfid,
|
||||
'token': token,
|
||||
'voice_format': 0,
|
||||
'open_kfid': open_kfid,
|
||||
}
|
||||
response = await client.post(url,json=params)
|
||||
response = await client.post(url, json=params)
|
||||
data = response.json()
|
||||
if data['errcode'] == 40014 or data['errcode'] == 42001:
|
||||
self.access_token = await self.get_access_token(self.secret)
|
||||
return await self.get_detailed_message_list(xml_msg)
|
||||
if data['errcode'] != 0:
|
||||
raise Exception("Failed to get message")
|
||||
|
||||
raise Exception('Failed to get message')
|
||||
|
||||
last_msg_data = data['msg_list'][-1]
|
||||
open_kfid = last_msg_data.get("open_kfid")
|
||||
open_kfid = last_msg_data.get('open_kfid')
|
||||
# 进行获取图片操作
|
||||
if last_msg_data.get("msgtype") == "image":
|
||||
media_id = last_msg_data.get("image").get("media_id")
|
||||
if last_msg_data.get('msgtype') == 'image':
|
||||
media_id = last_msg_data.get('image').get('media_id')
|
||||
picurl = await self.get_pic_url(media_id)
|
||||
last_msg_data["picurl"] = picurl
|
||||
last_msg_data['picurl'] = picurl
|
||||
# await self.change_service_status(userid=external_userid,openkfid=open_kfid,servicer=servicer)
|
||||
return last_msg_data
|
||||
|
||||
|
||||
async def change_service_status(self,userid:str,openkfid:str,servicer:str):
|
||||
|
||||
async def change_service_status(self, userid: str, openkfid: str, servicer: str):
|
||||
if not await self.check_access_token():
|
||||
self.access_token = await self.get_access_token(self.secret)
|
||||
url = self.base_url+"/kf/service_state/get?access_token="+self.access_token
|
||||
url = self.base_url + '/kf/service_state/get?access_token=' + self.access_token
|
||||
async with httpx.AsyncClient() as client:
|
||||
params = {
|
||||
"open_kfid" : openkfid,
|
||||
"external_userid" : userid,
|
||||
"service_state" : 1,
|
||||
"servicer_userid" : servicer,
|
||||
'open_kfid': openkfid,
|
||||
'external_userid': userid,
|
||||
'service_state': 1,
|
||||
'servicer_userid': servicer,
|
||||
}
|
||||
response = await client.post(url,json=params)
|
||||
response = await client.post(url, json=params)
|
||||
data = response.json()
|
||||
if data['errcode'] == 40014 or data['errcode'] == 42001:
|
||||
self.access_token = await self.get_access_token(self.secret)
|
||||
return await self.change_service_status(userid,openkfid)
|
||||
return await self.change_service_status(userid, openkfid)
|
||||
if data['errcode'] != 0:
|
||||
raise Exception("Failed to change service status: "+str(data))
|
||||
|
||||
raise Exception('Failed to change service status: ' + str(data))
|
||||
|
||||
async def send_image(self,user_id:str,agent_id:int,media_id:str):
|
||||
async def send_image(self, user_id: str, agent_id: int, media_id: str):
|
||||
if not await self.check_access_token():
|
||||
self.access_token = await self.get_access_token(self.secret)
|
||||
url = self.base_url+'/media/upload?access_token='+self.access_token
|
||||
url = self.base_url + '/media/upload?access_token=' + self.access_token
|
||||
async with httpx.AsyncClient() as client:
|
||||
params = {
|
||||
"touser" : user_id,
|
||||
"toparty" : "",
|
||||
"totag":"",
|
||||
"agentid" : agent_id,
|
||||
"msgtype" : "image",
|
||||
"image" : {
|
||||
"media_id" : media_id,
|
||||
'touser': user_id,
|
||||
'toparty': '',
|
||||
'totag': '',
|
||||
'agentid': agent_id,
|
||||
'msgtype': 'image',
|
||||
'image': {
|
||||
'media_id': media_id,
|
||||
},
|
||||
"safe":0,
|
||||
"enable_id_trans": 0,
|
||||
"enable_duplicate_check": 0,
|
||||
"duplicate_check_interval": 1800
|
||||
'safe': 0,
|
||||
'enable_id_trans': 0,
|
||||
'enable_duplicate_check': 0,
|
||||
'duplicate_check_interval': 1800,
|
||||
}
|
||||
try:
|
||||
response = await client.post(url,json=params)
|
||||
response = await client.post(url, json=params)
|
||||
data = response.json()
|
||||
except Exception as e:
|
||||
raise Exception("Failed to send image: "+str(e))
|
||||
raise Exception('Failed to send image: ' + str(e))
|
||||
|
||||
# 企业微信错误码40014和42001,代表accesstoken问题
|
||||
if data['errcode'] == 40014 or data['errcode'] == 42001:
|
||||
self.access_token = await self.get_access_token(self.secret)
|
||||
return await self.send_image(user_id,agent_id,media_id)
|
||||
return await self.send_image(user_id, agent_id, media_id)
|
||||
|
||||
if data['errcode'] != 0:
|
||||
raise Exception("Failed to send image: "+str(data))
|
||||
|
||||
raise Exception('Failed to send image: ' + str(data))
|
||||
|
||||
async def send_text_msg(self, open_kfid: str, external_userid: str, msgid: str,content:str):
|
||||
async def send_text_msg(self, open_kfid: str, external_userid: str, msgid: str, content: str):
|
||||
if not await self.check_access_token():
|
||||
self.access_token = await self.get_access_token(self.secret)
|
||||
|
||||
url = f"https://qyapi.weixin.qq.com/cgi-bin/kf/send_msg?access_token={self.access_token}"
|
||||
url = f'https://qyapi.weixin.qq.com/cgi-bin/kf/send_msg?access_token={self.access_token}'
|
||||
|
||||
payload = {
|
||||
"touser": external_userid,
|
||||
"open_kfid": open_kfid,
|
||||
"msgid": msgid,
|
||||
"msgtype": "text",
|
||||
"text": {
|
||||
"content": content,
|
||||
}
|
||||
'touser': external_userid,
|
||||
'open_kfid': open_kfid,
|
||||
'msgid': msgid,
|
||||
'msgtype': 'text',
|
||||
'text': {
|
||||
'content': content,
|
||||
},
|
||||
}
|
||||
|
||||
async with httpx.AsyncClient() as client:
|
||||
@@ -187,46 +185,44 @@ class WecomCSClient():
|
||||
data = response.json()
|
||||
if data['errcode'] == 40014 or data['errcode'] == 42001:
|
||||
self.access_token = await self.get_access_token(self.secret)
|
||||
return await self.send_text_msg(open_kfid,external_userid,msgid,content)
|
||||
return await self.send_text_msg(open_kfid, external_userid, msgid, content)
|
||||
if data['errcode'] != 0:
|
||||
raise Exception("Failed to send message")
|
||||
raise Exception('Failed to send message')
|
||||
return data
|
||||
|
||||
|
||||
async def handle_callback_request(self):
|
||||
"""
|
||||
处理回调请求,包括 GET 验证和 POST 消息接收。
|
||||
"""
|
||||
try:
|
||||
msg_signature = request.args.get('msg_signature')
|
||||
timestamp = request.args.get('timestamp')
|
||||
nonce = request.args.get('nonce')
|
||||
|
||||
msg_signature = request.args.get("msg_signature")
|
||||
timestamp = request.args.get("timestamp")
|
||||
nonce = request.args.get("nonce")
|
||||
|
||||
if request.method == "GET":
|
||||
echostr = request.args.get("echostr")
|
||||
if request.method == 'GET':
|
||||
echostr = request.args.get('echostr')
|
||||
ret, reply_echo_str = self.wxcpt.VerifyURL(msg_signature, timestamp, nonce, echostr)
|
||||
if ret != 0:
|
||||
raise Exception(f"验证失败,错误码: {ret}")
|
||||
raise Exception(f'验证失败,错误码: {ret}')
|
||||
return reply_echo_str
|
||||
|
||||
elif request.method == "POST":
|
||||
elif request.method == 'POST':
|
||||
encrypt_msg = await request.data
|
||||
ret, xml_msg = self.wxcpt.DecryptMsg(encrypt_msg, msg_signature, timestamp, nonce)
|
||||
if ret != 0:
|
||||
raise Exception(f"消息解密失败,错误码: {ret}")
|
||||
raise Exception(f'消息解密失败,错误码: {ret}')
|
||||
|
||||
# 解析消息并处理
|
||||
message_data = await self.get_detailed_message_list(xml_msg)
|
||||
if message_data is not None:
|
||||
event = WecomCSEvent.from_payload(message_data)
|
||||
event = WecomCSEvent.from_payload(message_data)
|
||||
if event:
|
||||
await self._handle_message(event)
|
||||
|
||||
return "success"
|
||||
return 'success'
|
||||
except Exception as e:
|
||||
traceback.print_exc()
|
||||
return f"Error processing request: {str(e)}", 400
|
||||
return f'Error processing request: {str(e)}', 400
|
||||
|
||||
async def run_task(self, host: str, port: int, *args, **kwargs):
|
||||
"""
|
||||
@@ -238,11 +234,13 @@ class WecomCSClient():
|
||||
"""
|
||||
注册消息类型处理器。
|
||||
"""
|
||||
|
||||
def decorator(func: Callable[[WecomCSEvent], None]):
|
||||
if msg_type not in self._message_handlers:
|
||||
self._message_handlers[msg_type] = []
|
||||
self._message_handlers[msg_type].append(func)
|
||||
return func
|
||||
|
||||
return decorator
|
||||
|
||||
async def _handle_message(self, event: WecomCSEvent):
|
||||
@@ -254,25 +252,23 @@ class WecomCSClient():
|
||||
for handler in self._message_handlers[msg_type]:
|
||||
await handler(event)
|
||||
|
||||
|
||||
@staticmethod
|
||||
async def get_image_type(image_bytes: bytes) -> str:
|
||||
"""
|
||||
通过图片的magic numbers判断图片类型
|
||||
"""
|
||||
magic_numbers = {
|
||||
b'\xFF\xD8\xFF': 'jpg',
|
||||
b'\x89\x50\x4E\x47': 'png',
|
||||
b'\xff\xd8\xff': 'jpg',
|
||||
b'\x89\x50\x4e\x47': 'png',
|
||||
b'\x47\x49\x46': 'gif',
|
||||
b'\x42\x4D': 'bmp',
|
||||
b'\x00\x00\x01\x00': 'ico'
|
||||
b'\x42\x4d': 'bmp',
|
||||
b'\x00\x00\x01\x00': 'ico',
|
||||
}
|
||||
|
||||
|
||||
for magic, ext in magic_numbers.items():
|
||||
if image_bytes.startswith(magic):
|
||||
return ext
|
||||
return 'jpg' # 默认返回jpg
|
||||
|
||||
|
||||
async def upload_to_work(self, image: platform_message.Image):
|
||||
"""
|
||||
@@ -283,7 +279,7 @@ class WecomCSClient():
|
||||
|
||||
url = self.base_url + '/media/upload?access_token=' + self.access_token + '&type=file'
|
||||
file_bytes = None
|
||||
file_name = "uploaded_file.txt"
|
||||
file_name = 'uploaded_file.txt'
|
||||
|
||||
# 获取文件的二进制数据
|
||||
if image.path:
|
||||
@@ -302,20 +298,22 @@ class WecomCSClient():
|
||||
padded_base64 = base64_data + '=' * padding
|
||||
file_bytes = base64.b64decode(padded_base64)
|
||||
except binascii.Error as e:
|
||||
raise ValueError(f"Invalid base64 string: {str(e)}")
|
||||
raise ValueError(f'Invalid base64 string: {str(e)}')
|
||||
else:
|
||||
raise ValueError("image对象出错")
|
||||
raise ValueError('image对象出错')
|
||||
|
||||
# 设置 multipart/form-data 格式的文件
|
||||
boundary = "-------------------------acebdf13572468"
|
||||
headers = {
|
||||
'Content-Type': f'multipart/form-data; boundary={boundary}'
|
||||
}
|
||||
boundary = '-------------------------acebdf13572468'
|
||||
headers = {'Content-Type': f'multipart/form-data; boundary={boundary}'}
|
||||
body = (
|
||||
f"--{boundary}\r\n"
|
||||
f"Content-Disposition: form-data; name=\"media\"; filename=\"{file_name}\"; filelength={len(file_bytes)}\r\n"
|
||||
f"Content-Type: application/octet-stream\r\n\r\n"
|
||||
).encode('utf-8') + file_bytes + f"\r\n--{boundary}--\r\n".encode('utf-8')
|
||||
(
|
||||
f'--{boundary}\r\n'
|
||||
f'Content-Disposition: form-data; name="media"; filename="{file_name}"; filelength={len(file_bytes)}\r\n'
|
||||
f'Content-Type: application/octet-stream\r\n\r\n'
|
||||
).encode('utf-8')
|
||||
+ file_bytes
|
||||
+ f'\r\n--{boundary}--\r\n'.encode('utf-8')
|
||||
)
|
||||
|
||||
# 上传文件
|
||||
async with httpx.AsyncClient() as client:
|
||||
@@ -325,19 +323,18 @@ class WecomCSClient():
|
||||
self.access_token = await self.get_access_token(self.secret)
|
||||
media_id = await self.upload_to_work(image)
|
||||
if data.get('errcode', 0) != 0:
|
||||
raise Exception("failed to upload file")
|
||||
raise Exception('failed to upload file')
|
||||
|
||||
media_id = data.get('media_id')
|
||||
return media_id
|
||||
|
||||
async def download_image_to_bytes(self,url:str) -> bytes:
|
||||
async def download_image_to_bytes(self, url: str) -> bytes:
|
||||
async with httpx.AsyncClient() as client:
|
||||
response = await client.get(url)
|
||||
response.raise_for_status()
|
||||
return response.content
|
||||
|
||||
#进行media_id的获取
|
||||
# 进行media_id的获取
|
||||
async def get_media_id(self, image: platform_message.Image):
|
||||
|
||||
media_id = await self.upload_to_work(image=image)
|
||||
return media_id
|
||||
|
||||
@@ -9,7 +9,7 @@ class WecomCSEvent(dict):
|
||||
"""
|
||||
|
||||
@staticmethod
|
||||
def from_payload(payload: Dict[str, Any]) -> Optional["WecomCSEvent"]:
|
||||
def from_payload(payload: Dict[str, Any]) -> Optional['WecomCSEvent']:
|
||||
"""
|
||||
从企业微信(客服会话)事件数据构造 `WecomEvent` 对象。
|
||||
|
||||
@@ -21,7 +21,7 @@ class WecomCSEvent(dict):
|
||||
"""
|
||||
try:
|
||||
event = WecomCSEvent(payload)
|
||||
_ = event.type,
|
||||
_ = (event.type,)
|
||||
return event
|
||||
except KeyError:
|
||||
return None
|
||||
@@ -34,8 +34,8 @@ class WecomCSEvent(dict):
|
||||
Returns:
|
||||
str: 事件类型。
|
||||
"""
|
||||
return self.get("msgtype", "")
|
||||
|
||||
return self.get('msgtype', '')
|
||||
|
||||
@property
|
||||
def user_id(self) -> Optional[str]:
|
||||
"""
|
||||
@@ -44,7 +44,7 @@ class WecomCSEvent(dict):
|
||||
Returns:
|
||||
Optional[str]: 用户 ID。
|
||||
"""
|
||||
return self.get("external_userid")
|
||||
return self.get('external_userid')
|
||||
|
||||
@property
|
||||
def receiver_id(self) -> Optional[str]:
|
||||
@@ -54,8 +54,8 @@ class WecomCSEvent(dict):
|
||||
Returns:
|
||||
Optional[str]: 接收者 ID。
|
||||
"""
|
||||
return self.get("open_kfid","")
|
||||
|
||||
return self.get('open_kfid', '')
|
||||
|
||||
@property
|
||||
def picurl(self) -> Optional[str]:
|
||||
"""
|
||||
@@ -65,7 +65,7 @@ class WecomCSEvent(dict):
|
||||
Optional[str]: 图片 URL。
|
||||
"""
|
||||
|
||||
return self.get("picurl","")
|
||||
return self.get('picurl', '')
|
||||
|
||||
@property
|
||||
def message_id(self) -> Optional[str]:
|
||||
@@ -75,7 +75,7 @@ class WecomCSEvent(dict):
|
||||
Returns:
|
||||
Optional[str]: 消息 ID。
|
||||
"""
|
||||
return self.get("msgid")
|
||||
return self.get('msgid')
|
||||
|
||||
@property
|
||||
def message(self) -> Optional[str]:
|
||||
@@ -85,12 +85,11 @@ class WecomCSEvent(dict):
|
||||
Returns:
|
||||
Optional[str]: 消息内容。
|
||||
"""
|
||||
if self.get("msgtype") == 'text':
|
||||
return self.get("text").get("content","")
|
||||
if self.get('msgtype') == 'text':
|
||||
return self.get('text').get('content', '')
|
||||
else:
|
||||
return None
|
||||
|
||||
|
||||
@property
|
||||
def timestamp(self) -> Optional[int]:
|
||||
"""
|
||||
@@ -99,8 +98,7 @@ class WecomCSEvent(dict):
|
||||
Returns:
|
||||
Optional[int]: 时间戳。
|
||||
"""
|
||||
return self.get("send_time")
|
||||
|
||||
return self.get('send_time')
|
||||
|
||||
def __getattr__(self, key: str) -> Optional[Any]:
|
||||
"""
|
||||
@@ -131,4 +129,4 @@ class WecomCSEvent(dict):
|
||||
Returns:
|
||||
str: 字符串表示。
|
||||
"""
|
||||
return f"<WecomEvent {super().__repr__()}>"
|
||||
return f'<WecomEvent {super().__repr__()}>'
|
||||
|
||||
@@ -65,9 +65,7 @@ class RouterGroup(abc.ABC):
|
||||
async def handler_error(*args, **kwargs):
|
||||
if auth_type == AuthType.USER_TOKEN:
|
||||
# 从Authorization头中获取token
|
||||
token = quart.request.headers.get('Authorization', '').replace(
|
||||
'Bearer ', ''
|
||||
)
|
||||
token = quart.request.headers.get('Authorization', '').replace('Bearer ', '')
|
||||
|
||||
if not token:
|
||||
return self.http_status(401, -1, '未提供有效的用户令牌')
|
||||
|
||||
@@ -14,10 +14,8 @@ class LogsRouterGroup(group.RouterGroup):
|
||||
start_page_number = int(quart.request.args.get('start_page_number', 0))
|
||||
start_offset = int(quart.request.args.get('start_offset', 0))
|
||||
|
||||
logs_str, end_page_number, end_offset = (
|
||||
self.ap.log_cache.get_log_by_pointer(
|
||||
start_page_number=start_page_number, start_offset=start_offset
|
||||
)
|
||||
logs_str, end_page_number, end_offset = self.ap.log_cache.get_log_by_pointer(
|
||||
start_page_number=start_page_number, start_offset=start_offset
|
||||
)
|
||||
|
||||
return self.success(
|
||||
|
||||
@@ -11,23 +11,17 @@ class PipelinesRouterGroup(group.RouterGroup):
|
||||
@self.route('', methods=['GET', 'POST'])
|
||||
async def _() -> str:
|
||||
if quart.request.method == 'GET':
|
||||
return self.success(
|
||||
data={'pipelines': await self.ap.pipeline_service.get_pipelines()}
|
||||
)
|
||||
return self.success(data={'pipelines': await self.ap.pipeline_service.get_pipelines()})
|
||||
elif quart.request.method == 'POST':
|
||||
json_data = await quart.request.json
|
||||
|
||||
pipeline_uuid = await self.ap.pipeline_service.create_pipeline(
|
||||
json_data
|
||||
)
|
||||
pipeline_uuid = await self.ap.pipeline_service.create_pipeline(json_data)
|
||||
|
||||
return self.success(data={'uuid': pipeline_uuid})
|
||||
|
||||
@self.route('/_/metadata', methods=['GET'])
|
||||
async def _() -> str:
|
||||
return self.success(
|
||||
data={'configs': await self.ap.pipeline_service.get_pipeline_metadata()}
|
||||
)
|
||||
return self.success(data={'configs': await self.ap.pipeline_service.get_pipeline_metadata()})
|
||||
|
||||
@self.route('/<pipeline_uuid>', methods=['GET', 'PUT', 'DELETE'])
|
||||
async def _(pipeline_uuid: str) -> str:
|
||||
|
||||
@@ -8,30 +8,20 @@ class AdaptersRouterGroup(group.RouterGroup):
|
||||
async def initialize(self) -> None:
|
||||
@self.route('', methods=['GET'])
|
||||
async def _() -> str:
|
||||
return self.success(
|
||||
data={'adapters': self.ap.platform_mgr.get_available_adapters_info()}
|
||||
)
|
||||
return self.success(data={'adapters': self.ap.platform_mgr.get_available_adapters_info()})
|
||||
|
||||
@self.route('/<adapter_name>', methods=['GET'])
|
||||
async def _(adapter_name: str) -> str:
|
||||
adapter_info = self.ap.platform_mgr.get_available_adapter_info_by_name(
|
||||
adapter_name
|
||||
)
|
||||
adapter_info = self.ap.platform_mgr.get_available_adapter_info_by_name(adapter_name)
|
||||
|
||||
if adapter_info is None:
|
||||
return self.http_status(404, -1, 'adapter not found')
|
||||
|
||||
return self.success(data={'adapter': adapter_info})
|
||||
|
||||
@self.route(
|
||||
'/<adapter_name>/icon', methods=['GET'], auth_type=group.AuthType.NONE
|
||||
)
|
||||
@self.route('/<adapter_name>/icon', methods=['GET'], auth_type=group.AuthType.NONE)
|
||||
async def _(adapter_name: str) -> quart.Response:
|
||||
adapter_manifest = (
|
||||
self.ap.platform_mgr.get_available_adapter_manifest_by_name(
|
||||
adapter_name
|
||||
)
|
||||
)
|
||||
adapter_manifest = self.ap.platform_mgr.get_available_adapter_manifest_by_name(adapter_name)
|
||||
|
||||
if adapter_manifest is None:
|
||||
return self.http_status(404, -1, 'adapter not found')
|
||||
|
||||
@@ -92,9 +92,7 @@ class PluginsRouterGroup(group.RouterGroup):
|
||||
await self.ap.plugin_mgr.reorder_plugins(data.get('plugins'))
|
||||
return self.success()
|
||||
|
||||
@self.route(
|
||||
'/install/github', methods=['POST'], auth_type=group.AuthType.USER_TOKEN
|
||||
)
|
||||
@self.route('/install/github', methods=['POST'], auth_type=group.AuthType.USER_TOKEN)
|
||||
async def _() -> str:
|
||||
data = await quart.request.json
|
||||
|
||||
|
||||
@@ -9,9 +9,7 @@ class LLMModelsRouterGroup(group.RouterGroup):
|
||||
@self.route('', methods=['GET', 'POST'])
|
||||
async def _() -> str:
|
||||
if quart.request.method == 'GET':
|
||||
return self.success(
|
||||
data={'models': await self.ap.model_service.get_llm_models()}
|
||||
)
|
||||
return self.success(data={'models': await self.ap.model_service.get_llm_models()})
|
||||
elif quart.request.method == 'POST':
|
||||
json_data = await quart.request.json
|
||||
|
||||
|
||||
@@ -8,30 +8,20 @@ class RequestersRouterGroup(group.RouterGroup):
|
||||
async def initialize(self) -> None:
|
||||
@self.route('', methods=['GET'])
|
||||
async def _() -> quart.Response:
|
||||
return self.success(
|
||||
data={'requesters': self.ap.model_mgr.get_available_requesters_info()}
|
||||
)
|
||||
return self.success(data={'requesters': self.ap.model_mgr.get_available_requesters_info()})
|
||||
|
||||
@self.route('/<requester_name>', methods=['GET'])
|
||||
async def _(requester_name: str) -> quart.Response:
|
||||
requester_info = self.ap.model_mgr.get_available_requester_info_by_name(
|
||||
requester_name
|
||||
)
|
||||
requester_info = self.ap.model_mgr.get_available_requester_info_by_name(requester_name)
|
||||
|
||||
if requester_info is None:
|
||||
return self.http_status(404, -1, 'requester not found')
|
||||
|
||||
return self.success(data={'requester': requester_info})
|
||||
|
||||
@self.route(
|
||||
'/<requester_name>/icon', methods=['GET'], auth_type=group.AuthType.NONE
|
||||
)
|
||||
@self.route('/<requester_name>/icon', methods=['GET'], auth_type=group.AuthType.NONE)
|
||||
async def _(requester_name: str) -> quart.Response:
|
||||
requester_manifest = (
|
||||
self.ap.model_mgr.get_available_requester_manifest_by_name(
|
||||
requester_name
|
||||
)
|
||||
)
|
||||
requester_manifest = self.ap.model_mgr.get_available_requester_manifest_by_name(requester_name)
|
||||
|
||||
if requester_manifest is None:
|
||||
return self.http_status(404, -1, 'requester not found')
|
||||
|
||||
@@ -8,9 +8,7 @@ class StatsRouterGroup(group.RouterGroup):
|
||||
async def _() -> str:
|
||||
conv_count = 0
|
||||
for session in self.ap.sess_mgr.session_list:
|
||||
conv_count += len(
|
||||
session.conversations if session.conversations is not None else []
|
||||
)
|
||||
conv_count += len(session.conversations if session.conversations is not None else [])
|
||||
|
||||
return self.success(
|
||||
data={
|
||||
|
||||
@@ -13,9 +13,7 @@ class SystemRouterGroup(group.RouterGroup):
|
||||
data={
|
||||
'version': constants.semantic_version,
|
||||
'debug': constants.debug_mode,
|
||||
'enabled_platform_count': len(
|
||||
self.ap.platform_mgr.get_running_adapters()
|
||||
),
|
||||
'enabled_platform_count': len(self.ap.platform_mgr.get_running_adapters()),
|
||||
}
|
||||
)
|
||||
|
||||
@@ -28,9 +26,7 @@ class SystemRouterGroup(group.RouterGroup):
|
||||
|
||||
return self.success(data=self.ap.task_mgr.get_tasks_dict(task_type))
|
||||
|
||||
@self.route(
|
||||
'/tasks/<task_id>', methods=['GET'], auth_type=group.AuthType.USER_TOKEN
|
||||
)
|
||||
@self.route('/tasks/<task_id>', methods=['GET'], auth_type=group.AuthType.USER_TOKEN)
|
||||
async def _(task_id: str) -> str:
|
||||
task = self.ap.task_mgr.get_task_by_id(int(task_id))
|
||||
|
||||
@@ -48,9 +44,7 @@ class SystemRouterGroup(group.RouterGroup):
|
||||
await self.ap.reload(scope=scope)
|
||||
return self.success()
|
||||
|
||||
@self.route(
|
||||
'/_debug/exec', methods=['POST'], auth_type=group.AuthType.USER_TOKEN
|
||||
)
|
||||
@self.route('/_debug/exec', methods=['POST'], auth_type=group.AuthType.USER_TOKEN)
|
||||
async def _() -> str:
|
||||
if not constants.debug_mode:
|
||||
return self.http_status(403, 403, 'Forbidden')
|
||||
|
||||
@@ -10,9 +10,7 @@ class UserRouterGroup(group.RouterGroup):
|
||||
@self.route('/init', methods=['GET', 'POST'], auth_type=group.AuthType.NONE)
|
||||
async def _() -> str:
|
||||
if quart.request.method == 'GET':
|
||||
return self.success(
|
||||
data={'initialized': await self.ap.user_service.is_initialized()}
|
||||
)
|
||||
return self.success(data={'initialized': await self.ap.user_service.is_initialized()})
|
||||
|
||||
if await self.ap.user_service.is_initialized():
|
||||
return self.fail(1, '系统已初始化')
|
||||
@@ -31,17 +29,13 @@ class UserRouterGroup(group.RouterGroup):
|
||||
json_data = await quart.request.json
|
||||
|
||||
try:
|
||||
token = await self.ap.user_service.authenticate(
|
||||
json_data['user'], json_data['password']
|
||||
)
|
||||
token = await self.ap.user_service.authenticate(json_data['user'], json_data['password'])
|
||||
except argon2.exceptions.VerifyMismatchError:
|
||||
return self.fail(1, '用户名或密码错误')
|
||||
|
||||
return self.success(data={'token': token})
|
||||
|
||||
@self.route(
|
||||
'/check-token', methods=['GET'], auth_type=group.AuthType.USER_TOKEN
|
||||
)
|
||||
@self.route('/check-token', methods=['GET'], auth_type=group.AuthType.USER_TOKEN)
|
||||
async def _(user_email: str) -> str:
|
||||
token = await self.ap.user_service.generate_jwt_token(user_email)
|
||||
|
||||
|
||||
@@ -70,15 +70,12 @@ class HTTPController:
|
||||
|
||||
@self.quart_app.route('/')
|
||||
async def index():
|
||||
return await quart.send_from_directory(
|
||||
frontend_path, 'index.html', mimetype='text/html'
|
||||
)
|
||||
return await quart.send_from_directory(frontend_path, 'index.html', mimetype='text/html')
|
||||
|
||||
@self.quart_app.route('/<path:path>')
|
||||
async def static_file(path: str):
|
||||
if not (
|
||||
os.path.exists(os.path.join(frontend_path, path))
|
||||
and os.path.isfile(os.path.join(frontend_path, path))
|
||||
os.path.exists(os.path.join(frontend_path, path)) and os.path.isfile(os.path.join(frontend_path, path))
|
||||
):
|
||||
if os.path.exists(os.path.join(frontend_path, path + '.html')):
|
||||
path += '.html'
|
||||
@@ -110,6 +107,4 @@ class HTTPController:
|
||||
elif path.endswith('.txt'):
|
||||
mimetype = 'text/plain'
|
||||
|
||||
return await quart.send_from_directory(
|
||||
frontend_path, path, mimetype=mimetype
|
||||
)
|
||||
return await quart.send_from_directory(frontend_path, path, mimetype=mimetype)
|
||||
|
||||
@@ -18,23 +18,16 @@ class BotService:
|
||||
|
||||
async def get_bots(self) -> list[dict]:
|
||||
"""获取所有机器人"""
|
||||
result = await self.ap.persistence_mgr.execute_async(
|
||||
sqlalchemy.select(persistence_bot.Bot)
|
||||
)
|
||||
result = await self.ap.persistence_mgr.execute_async(sqlalchemy.select(persistence_bot.Bot))
|
||||
|
||||
bots = result.all()
|
||||
|
||||
return [
|
||||
self.ap.persistence_mgr.serialize_model(persistence_bot.Bot, bot)
|
||||
for bot in bots
|
||||
]
|
||||
return [self.ap.persistence_mgr.serialize_model(persistence_bot.Bot, bot) for bot in bots]
|
||||
|
||||
async def get_bot(self, bot_uuid: str) -> dict | None:
|
||||
"""获取机器人"""
|
||||
result = await self.ap.persistence_mgr.execute_async(
|
||||
sqlalchemy.select(persistence_bot.Bot).where(
|
||||
persistence_bot.Bot.uuid == bot_uuid
|
||||
)
|
||||
sqlalchemy.select(persistence_bot.Bot).where(persistence_bot.Bot.uuid == bot_uuid)
|
||||
)
|
||||
|
||||
bot = result.first()
|
||||
@@ -60,9 +53,7 @@ class BotService:
|
||||
bot_data['use_pipeline_uuid'] = pipeline.uuid
|
||||
bot_data['use_pipeline_name'] = pipeline.name
|
||||
|
||||
await self.ap.persistence_mgr.execute_async(
|
||||
sqlalchemy.insert(persistence_bot.Bot).values(bot_data)
|
||||
)
|
||||
await self.ap.persistence_mgr.execute_async(sqlalchemy.insert(persistence_bot.Bot).values(bot_data))
|
||||
|
||||
bot = await self.get_bot(bot_data['uuid'])
|
||||
|
||||
@@ -79,8 +70,7 @@ class BotService:
|
||||
if 'use_pipeline_uuid' in bot_data:
|
||||
result = await self.ap.persistence_mgr.execute_async(
|
||||
sqlalchemy.select(persistence_pipeline.LegacyPipeline).where(
|
||||
persistence_pipeline.LegacyPipeline.uuid
|
||||
== bot_data['use_pipeline_uuid']
|
||||
persistence_pipeline.LegacyPipeline.uuid == bot_data['use_pipeline_uuid']
|
||||
)
|
||||
)
|
||||
pipeline = result.first()
|
||||
@@ -90,9 +80,7 @@ class BotService:
|
||||
raise Exception('Pipeline not found')
|
||||
|
||||
await self.ap.persistence_mgr.execute_async(
|
||||
sqlalchemy.update(persistence_bot.Bot)
|
||||
.values(bot_data)
|
||||
.where(persistence_bot.Bot.uuid == bot_uuid)
|
||||
sqlalchemy.update(persistence_bot.Bot).values(bot_data).where(persistence_bot.Bot.uuid == bot_uuid)
|
||||
)
|
||||
await self.ap.platform_mgr.remove_bot(bot_uuid)
|
||||
|
||||
@@ -108,7 +96,5 @@ class BotService:
|
||||
"""删除机器人"""
|
||||
await self.ap.platform_mgr.remove_bot(bot_uuid)
|
||||
await self.ap.persistence_mgr.execute_async(
|
||||
sqlalchemy.delete(persistence_bot.Bot).where(
|
||||
persistence_bot.Bot.uuid == bot_uuid
|
||||
)
|
||||
sqlalchemy.delete(persistence_bot.Bot).where(persistence_bot.Bot.uuid == bot_uuid)
|
||||
)
|
||||
|
||||
@@ -15,22 +15,15 @@ class ModelsService:
|
||||
self.ap = ap
|
||||
|
||||
async def get_llm_models(self) -> list[dict]:
|
||||
result = await self.ap.persistence_mgr.execute_async(
|
||||
sqlalchemy.select(persistence_model.LLMModel)
|
||||
)
|
||||
result = await self.ap.persistence_mgr.execute_async(sqlalchemy.select(persistence_model.LLMModel))
|
||||
|
||||
models = result.all()
|
||||
return [
|
||||
self.ap.persistence_mgr.serialize_model(persistence_model.LLMModel, model)
|
||||
for model in models
|
||||
]
|
||||
return [self.ap.persistence_mgr.serialize_model(persistence_model.LLMModel, model) for model in models]
|
||||
|
||||
async def create_llm_model(self, model_data: dict) -> str:
|
||||
model_data['uuid'] = str(uuid.uuid4())
|
||||
|
||||
await self.ap.persistence_mgr.execute_async(
|
||||
sqlalchemy.insert(persistence_model.LLMModel).values(**model_data)
|
||||
)
|
||||
await self.ap.persistence_mgr.execute_async(sqlalchemy.insert(persistence_model.LLMModel).values(**model_data))
|
||||
|
||||
llm_model = await self.get_llm_model(model_data['uuid'])
|
||||
|
||||
@@ -53,9 +46,7 @@ class ModelsService:
|
||||
|
||||
async def get_llm_model(self, model_uuid: str) -> dict | None:
|
||||
result = await self.ap.persistence_mgr.execute_async(
|
||||
sqlalchemy.select(persistence_model.LLMModel).where(
|
||||
persistence_model.LLMModel.uuid == model_uuid
|
||||
)
|
||||
sqlalchemy.select(persistence_model.LLMModel).where(persistence_model.LLMModel.uuid == model_uuid)
|
||||
)
|
||||
|
||||
model = result.first()
|
||||
@@ -63,9 +54,7 @@ class ModelsService:
|
||||
if model is None:
|
||||
return None
|
||||
|
||||
return self.ap.persistence_mgr.serialize_model(
|
||||
persistence_model.LLMModel, model
|
||||
)
|
||||
return self.ap.persistence_mgr.serialize_model(persistence_model.LLMModel, model)
|
||||
|
||||
async def update_llm_model(self, model_uuid: str, model_data: dict) -> None:
|
||||
if 'uuid' in model_data:
|
||||
@@ -85,9 +74,7 @@ class ModelsService:
|
||||
|
||||
async def delete_llm_model(self, model_uuid: str) -> None:
|
||||
await self.ap.persistence_mgr.execute_async(
|
||||
sqlalchemy.delete(persistence_model.LLMModel).where(
|
||||
persistence_model.LLMModel.uuid == model_uuid
|
||||
)
|
||||
sqlalchemy.delete(persistence_model.LLMModel).where(persistence_model.LLMModel.uuid == model_uuid)
|
||||
)
|
||||
|
||||
await self.ap.model_mgr.remove_llm_model(model_uuid)
|
||||
|
||||
@@ -39,15 +39,11 @@ class PipelineService:
|
||||
]
|
||||
|
||||
async def get_pipelines(self) -> list[dict]:
|
||||
result = await self.ap.persistence_mgr.execute_async(
|
||||
sqlalchemy.select(persistence_pipeline.LegacyPipeline)
|
||||
)
|
||||
result = await self.ap.persistence_mgr.execute_async(sqlalchemy.select(persistence_pipeline.LegacyPipeline))
|
||||
|
||||
pipelines = result.all()
|
||||
return [
|
||||
self.ap.persistence_mgr.serialize_model(
|
||||
persistence_pipeline.LegacyPipeline, pipeline
|
||||
)
|
||||
self.ap.persistence_mgr.serialize_model(persistence_pipeline.LegacyPipeline, pipeline)
|
||||
for pipeline in pipelines
|
||||
]
|
||||
|
||||
@@ -63,23 +59,17 @@ class PipelineService:
|
||||
if pipeline is None:
|
||||
return None
|
||||
|
||||
return self.ap.persistence_mgr.serialize_model(
|
||||
persistence_pipeline.LegacyPipeline, pipeline
|
||||
)
|
||||
return self.ap.persistence_mgr.serialize_model(persistence_pipeline.LegacyPipeline, pipeline)
|
||||
|
||||
async def create_pipeline(self, pipeline_data: dict, default: bool = False) -> str:
|
||||
pipeline_data['uuid'] = str(uuid.uuid4())
|
||||
pipeline_data['for_version'] = self.ap.ver_mgr.get_current_version()
|
||||
pipeline_data['stages'] = default_stage_order.copy()
|
||||
pipeline_data['is_default'] = default
|
||||
pipeline_data['config'] = json.load(
|
||||
open('templates/default-pipeline-config.json', 'r', encoding='utf-8')
|
||||
)
|
||||
pipeline_data['config'] = json.load(open('templates/default-pipeline-config.json', 'r', encoding='utf-8'))
|
||||
|
||||
await self.ap.persistence_mgr.execute_async(
|
||||
sqlalchemy.insert(persistence_pipeline.LegacyPipeline).values(
|
||||
**pipeline_data
|
||||
)
|
||||
sqlalchemy.insert(persistence_pipeline.LegacyPipeline).values(**pipeline_data)
|
||||
)
|
||||
|
||||
pipeline = await self.get_pipeline(pipeline_data['uuid'])
|
||||
|
||||
@@ -17,9 +17,7 @@ class UserService:
|
||||
self.ap = ap
|
||||
|
||||
async def is_initialized(self) -> bool:
|
||||
result = await self.ap.persistence_mgr.execute_async(
|
||||
sqlalchemy.select(user.User).limit(1)
|
||||
)
|
||||
result = await self.ap.persistence_mgr.execute_async(sqlalchemy.select(user.User).limit(1))
|
||||
|
||||
result_list = result.all()
|
||||
return result_list is not None and len(result_list) > 0
|
||||
@@ -30,9 +28,7 @@ class UserService:
|
||||
hashed_password = ph.hash(password)
|
||||
|
||||
await self.ap.persistence_mgr.execute_async(
|
||||
sqlalchemy.insert(user.User).values(
|
||||
user=user_email, password=hashed_password
|
||||
)
|
||||
sqlalchemy.insert(user.User).values(user=user_email, password=hashed_password)
|
||||
)
|
||||
|
||||
async def get_user_by_email(self, user_email: str) -> user.User | None:
|
||||
@@ -41,9 +37,7 @@ class UserService:
|
||||
)
|
||||
|
||||
result_list = result.all()
|
||||
return (
|
||||
result_list[0] if result_list is not None and len(result_list) > 0 else None
|
||||
)
|
||||
return result_list[0] if result_list is not None and len(result_list) > 0 else None
|
||||
|
||||
async def authenticate(self, user_email: str, password: str) -> str | None:
|
||||
result = await self.ap.persistence_mgr.execute_async(
|
||||
|
||||
@@ -40,18 +40,14 @@ class CommandManager:
|
||||
# 应用命令权限配置
|
||||
for cls in operator.preregistered_operators:
|
||||
if cls.path in self.ap.instance_config.data['command']['privilege']:
|
||||
cls.lowest_privilege = self.ap.instance_config.data['command'][
|
||||
'privilege'
|
||||
][cls.path]
|
||||
cls.lowest_privilege = self.ap.instance_config.data['command']['privilege'][cls.path]
|
||||
|
||||
# 实例化所有类
|
||||
self.cmd_list = [cls(self.ap) for cls in operator.preregistered_operators]
|
||||
|
||||
# 设置所有类的子节点
|
||||
for cmd in self.cmd_list:
|
||||
cmd.children = [
|
||||
child for child in self.cmd_list if child.parent_class == cmd.__class__
|
||||
]
|
||||
cmd.children = [child for child in self.cmd_list if child.parent_class == cmd.__class__]
|
||||
|
||||
# 初始化所有类
|
||||
for cmd in self.cmd_list:
|
||||
@@ -68,10 +64,7 @@ class CommandManager:
|
||||
found = False
|
||||
if len(context.crt_params) > 0: # 查找下一个参数是否对应此节点的某个子节点名
|
||||
for oper in operator_list:
|
||||
if (
|
||||
context.crt_params[0] == oper.name
|
||||
or context.crt_params[0] in oper.alias
|
||||
) and (
|
||||
if (context.crt_params[0] == oper.name or context.crt_params[0] in oper.alias) and (
|
||||
oper.parent_class is None or oper.parent_class == operator.__class__
|
||||
):
|
||||
found = True
|
||||
@@ -85,14 +78,10 @@ class CommandManager:
|
||||
|
||||
if not found: # 如果下一个参数未在此节点的子节点中找到,则执行此节点或者报错
|
||||
if operator is None:
|
||||
yield entities.CommandReturn(
|
||||
error=errors.CommandNotFoundError(context.crt_params[0])
|
||||
)
|
||||
yield entities.CommandReturn(error=errors.CommandNotFoundError(context.crt_params[0]))
|
||||
else:
|
||||
if operator.lowest_privilege > context.privilege:
|
||||
yield entities.CommandReturn(
|
||||
error=errors.CommandPrivilegeError(operator.name)
|
||||
)
|
||||
yield entities.CommandReturn(error=errors.CommandPrivilegeError(operator.name))
|
||||
else:
|
||||
async for ret in operator.execute(context):
|
||||
yield ret
|
||||
@@ -107,10 +96,7 @@ class CommandManager:
|
||||
|
||||
privilege = 1
|
||||
|
||||
if (
|
||||
f'{query.launcher_type.value}_{query.launcher_id}'
|
||||
in self.ap.instance_config.data['admins']
|
||||
):
|
||||
if f'{query.launcher_type.value}_{query.launcher_id}' in self.ap.instance_config.data['admins']:
|
||||
privilege = 2
|
||||
|
||||
ctx = entities.ExecuteContext(
|
||||
|
||||
@@ -95,9 +95,7 @@ class CommandOperator(metaclass=abc.ABCMeta):
|
||||
pass
|
||||
|
||||
@abc.abstractmethod
|
||||
async def execute(
|
||||
self, context: entities.ExecuteContext
|
||||
) -> typing.AsyncGenerator[entities.CommandReturn, None]:
|
||||
async def execute(self, context: entities.ExecuteContext) -> typing.AsyncGenerator[entities.CommandReturn, None]:
|
||||
"""实现此方法以执行命令
|
||||
|
||||
支持多次yield以返回多个结果。
|
||||
|
||||
@@ -9,9 +9,7 @@ from .. import operator, entities, errors
|
||||
class CmdOperator(operator.CommandOperator):
|
||||
"""命令列表"""
|
||||
|
||||
async def execute(
|
||||
self, context: entities.ExecuteContext
|
||||
) -> typing.AsyncGenerator[entities.CommandReturn, None]:
|
||||
async def execute(self, context: entities.ExecuteContext) -> typing.AsyncGenerator[entities.CommandReturn, None]:
|
||||
"""执行"""
|
||||
if len(context.crt_params) == 0:
|
||||
reply_str = '当前所有命令: \n\n'
|
||||
@@ -30,16 +28,12 @@ class CmdOperator(operator.CommandOperator):
|
||||
cmd = None
|
||||
|
||||
for _cmd in self.ap.cmd_mgr.cmd_list:
|
||||
if (cmd_name == _cmd.name or cmd_name in _cmd.alias) and (
|
||||
_cmd.parent_class is None
|
||||
):
|
||||
if (cmd_name == _cmd.name or cmd_name in _cmd.alias) and (_cmd.parent_class is None):
|
||||
cmd = _cmd
|
||||
break
|
||||
|
||||
if cmd is None:
|
||||
yield entities.CommandReturn(
|
||||
error=errors.CommandNotFoundError(cmd_name)
|
||||
)
|
||||
yield entities.CommandReturn(error=errors.CommandNotFoundError(cmd_name))
|
||||
else:
|
||||
reply_str = f'{cmd.name}: {cmd.help}\n\n'
|
||||
reply_str += f'使用方法: \n{cmd.usage}'
|
||||
|
||||
@@ -5,55 +5,38 @@ import typing
|
||||
from .. import operator, entities, errors
|
||||
|
||||
|
||||
@operator.operator_class(
|
||||
name='del', help='删除当前会话的历史记录', usage='!del <序号>\n!del all'
|
||||
)
|
||||
@operator.operator_class(name='del', help='删除当前会话的历史记录', usage='!del <序号>\n!del all')
|
||||
class DelOperator(operator.CommandOperator):
|
||||
async def execute(
|
||||
self, context: entities.ExecuteContext
|
||||
) -> typing.AsyncGenerator[entities.CommandReturn, None]:
|
||||
async def execute(self, context: entities.ExecuteContext) -> typing.AsyncGenerator[entities.CommandReturn, None]:
|
||||
if context.session.conversations:
|
||||
delete_index = 0
|
||||
if len(context.crt_params) > 0:
|
||||
try:
|
||||
delete_index = int(context.crt_params[0])
|
||||
except Exception:
|
||||
yield entities.CommandReturn(
|
||||
error=errors.CommandOperationError('索引必须是整数')
|
||||
)
|
||||
yield entities.CommandReturn(error=errors.CommandOperationError('索引必须是整数'))
|
||||
return
|
||||
|
||||
if delete_index < 0 or delete_index >= len(context.session.conversations):
|
||||
yield entities.CommandReturn(
|
||||
error=errors.CommandOperationError('索引超出范围')
|
||||
)
|
||||
yield entities.CommandReturn(error=errors.CommandOperationError('索引超出范围'))
|
||||
return
|
||||
|
||||
# 倒序
|
||||
to_delete_index = len(context.session.conversations) - 1 - delete_index
|
||||
|
||||
if (
|
||||
context.session.conversations[to_delete_index]
|
||||
== context.session.using_conversation
|
||||
):
|
||||
if context.session.conversations[to_delete_index] == context.session.using_conversation:
|
||||
context.session.using_conversation = None
|
||||
|
||||
del context.session.conversations[to_delete_index]
|
||||
|
||||
yield entities.CommandReturn(text=f'已删除对话: {delete_index}')
|
||||
else:
|
||||
yield entities.CommandReturn(
|
||||
error=errors.CommandOperationError('当前没有对话')
|
||||
)
|
||||
yield entities.CommandReturn(error=errors.CommandOperationError('当前没有对话'))
|
||||
|
||||
|
||||
@operator.operator_class(
|
||||
name='all', help='删除此会话的所有历史记录', parent_class=DelOperator
|
||||
)
|
||||
@operator.operator_class(name='all', help='删除此会话的所有历史记录', parent_class=DelOperator)
|
||||
class DelAllOperator(operator.CommandOperator):
|
||||
async def execute(
|
||||
self, context: entities.ExecuteContext
|
||||
) -> typing.AsyncGenerator[entities.CommandReturn, None]:
|
||||
async def execute(self, context: entities.ExecuteContext) -> typing.AsyncGenerator[entities.CommandReturn, None]:
|
||||
context.session.conversations = []
|
||||
context.session.using_conversation = None
|
||||
|
||||
|
||||
@@ -6,9 +6,7 @@ from .. import operator, entities
|
||||
|
||||
@operator.operator_class(name='func', help='查看所有已注册的内容函数', usage='!func')
|
||||
class FuncOperator(operator.CommandOperator):
|
||||
async def execute(
|
||||
self, context: entities.ExecuteContext
|
||||
) -> AsyncGenerator[entities.CommandReturn, None]:
|
||||
async def execute(self, context: entities.ExecuteContext) -> AsyncGenerator[entities.CommandReturn, None]:
|
||||
reply_str = '当前已启用的内容函数: \n\n'
|
||||
|
||||
index = 1
|
||||
|
||||
@@ -7,9 +7,7 @@ from .. import operator, entities
|
||||
|
||||
@operator.operator_class(name='help', help='显示帮助', usage='!help\n!help <命令名称>')
|
||||
class HelpOperator(operator.CommandOperator):
|
||||
async def execute(
|
||||
self, context: entities.ExecuteContext
|
||||
) -> typing.AsyncGenerator[entities.CommandReturn, None]:
|
||||
async def execute(self, context: entities.ExecuteContext) -> typing.AsyncGenerator[entities.CommandReturn, None]:
|
||||
help = 'LangBot - 大语言模型原生即时通信机器人平台\n链接:https://langbot.app'
|
||||
|
||||
help += '\n发送命令 !cmd 可查看命令列表'
|
||||
|
||||
@@ -8,36 +8,21 @@ from .. import operator, entities, errors
|
||||
|
||||
@operator.operator_class(name='last', help='切换到前一个对话', usage='!last')
|
||||
class LastOperator(operator.CommandOperator):
|
||||
async def execute(
|
||||
self, context: entities.ExecuteContext
|
||||
) -> typing.AsyncGenerator[entities.CommandReturn, None]:
|
||||
async def execute(self, context: entities.ExecuteContext) -> typing.AsyncGenerator[entities.CommandReturn, None]:
|
||||
if context.session.conversations:
|
||||
# 找到当前会话的上一个会话
|
||||
for index in range(len(context.session.conversations) - 1, -1, -1):
|
||||
if (
|
||||
context.session.conversations[index]
|
||||
== context.session.using_conversation
|
||||
):
|
||||
if context.session.conversations[index] == context.session.using_conversation:
|
||||
if index == 0:
|
||||
yield entities.CommandReturn(
|
||||
error=errors.CommandOperationError('已经是第一个对话了')
|
||||
)
|
||||
yield entities.CommandReturn(error=errors.CommandOperationError('已经是第一个对话了'))
|
||||
return
|
||||
else:
|
||||
context.session.using_conversation = (
|
||||
context.session.conversations[index - 1]
|
||||
)
|
||||
time_str = (
|
||||
context.session.using_conversation.create_time.strftime(
|
||||
'%Y-%m-%d %H:%M:%S'
|
||||
)
|
||||
)
|
||||
context.session.using_conversation = context.session.conversations[index - 1]
|
||||
time_str = context.session.using_conversation.create_time.strftime('%Y-%m-%d %H:%M:%S')
|
||||
|
||||
yield entities.CommandReturn(
|
||||
text=f'已切换到上一个对话: {index} {time_str}: {context.session.using_conversation.messages[0].readable_str()}'
|
||||
)
|
||||
return
|
||||
else:
|
||||
yield entities.CommandReturn(
|
||||
error=errors.CommandOperationError('当前没有对话')
|
||||
)
|
||||
yield entities.CommandReturn(error=errors.CommandOperationError('当前没有对话'))
|
||||
|
||||
@@ -5,22 +5,16 @@ import typing
|
||||
from .. import operator, entities, errors
|
||||
|
||||
|
||||
@operator.operator_class(
|
||||
name='list', help='列出此会话中的所有历史对话', usage='!list\n!list <页码>'
|
||||
)
|
||||
@operator.operator_class(name='list', help='列出此会话中的所有历史对话', usage='!list\n!list <页码>')
|
||||
class ListOperator(operator.CommandOperator):
|
||||
async def execute(
|
||||
self, context: entities.ExecuteContext
|
||||
) -> typing.AsyncGenerator[entities.CommandReturn, None]:
|
||||
async def execute(self, context: entities.ExecuteContext) -> typing.AsyncGenerator[entities.CommandReturn, None]:
|
||||
page = 0
|
||||
|
||||
if len(context.crt_params) > 0:
|
||||
try:
|
||||
page = int(context.crt_params[0] - 1)
|
||||
except Exception:
|
||||
yield entities.CommandReturn(
|
||||
error=errors.CommandOperationError('页码应为整数')
|
||||
)
|
||||
yield entities.CommandReturn(error=errors.CommandOperationError('页码应为整数'))
|
||||
return
|
||||
|
||||
record_per_page = 10
|
||||
@@ -38,7 +32,9 @@ class ListOperator(operator.CommandOperator):
|
||||
using_conv_index = index
|
||||
|
||||
if index >= page * record_per_page and index < (page + 1) * record_per_page:
|
||||
content += f'{index} {time_str}: {conv.messages[0].readable_str() if len(conv.messages) > 0 else "无内容"}\n'
|
||||
content += (
|
||||
f'{index} {time_str}: {conv.messages[0].readable_str() if len(conv.messages) > 0 else "无内容"}\n'
|
||||
)
|
||||
index += 1
|
||||
|
||||
if content == '':
|
||||
|
||||
@@ -14,9 +14,7 @@ from .. import operator, entities, errors
|
||||
class ModelOperator(operator.CommandOperator):
|
||||
"""Model命令"""
|
||||
|
||||
async def execute(
|
||||
self, context: entities.ExecuteContext
|
||||
) -> typing.AsyncGenerator[entities.CommandReturn, None]:
|
||||
async def execute(self, context: entities.ExecuteContext) -> typing.AsyncGenerator[entities.CommandReturn, None]:
|
||||
content = '模型列表:\n'
|
||||
|
||||
model_list = self.ap.model_mgr.model_list
|
||||
@@ -31,15 +29,11 @@ class ModelOperator(operator.CommandOperator):
|
||||
yield entities.CommandReturn(text=content.strip())
|
||||
|
||||
|
||||
@operator.operator_class(
|
||||
name='show', help='显示模型详情', privilege=2, parent_class=ModelOperator
|
||||
)
|
||||
@operator.operator_class(name='show', help='显示模型详情', privilege=2, parent_class=ModelOperator)
|
||||
class ModelShowOperator(operator.CommandOperator):
|
||||
"""Model Show命令"""
|
||||
|
||||
async def execute(
|
||||
self, context: entities.ExecuteContext
|
||||
) -> typing.AsyncGenerator[entities.CommandReturn, None]:
|
||||
async def execute(self, context: entities.ExecuteContext) -> typing.AsyncGenerator[entities.CommandReturn, None]:
|
||||
model_name = context.crt_params[0]
|
||||
|
||||
model = None
|
||||
@@ -49,9 +43,7 @@ class ModelShowOperator(operator.CommandOperator):
|
||||
break
|
||||
|
||||
if model is None:
|
||||
yield entities.CommandReturn(
|
||||
error=errors.CommandError(f'未找到模型 {model_name}')
|
||||
)
|
||||
yield entities.CommandReturn(error=errors.CommandError(f'未找到模型 {model_name}'))
|
||||
else:
|
||||
content = '模型详情\n'
|
||||
content += f'名称: {model.name}\n'
|
||||
@@ -65,15 +57,11 @@ class ModelShowOperator(operator.CommandOperator):
|
||||
yield entities.CommandReturn(text=content.strip())
|
||||
|
||||
|
||||
@operator.operator_class(
|
||||
name='set', help='设置默认使用模型', privilege=2, parent_class=ModelOperator
|
||||
)
|
||||
@operator.operator_class(name='set', help='设置默认使用模型', privilege=2, parent_class=ModelOperator)
|
||||
class ModelSetOperator(operator.CommandOperator):
|
||||
"""Model Set命令"""
|
||||
|
||||
async def execute(
|
||||
self, context: entities.ExecuteContext
|
||||
) -> typing.AsyncGenerator[entities.CommandReturn, None]:
|
||||
async def execute(self, context: entities.ExecuteContext) -> typing.AsyncGenerator[entities.CommandReturn, None]:
|
||||
model_name = context.crt_params[0]
|
||||
|
||||
model = None
|
||||
@@ -83,12 +71,8 @@ class ModelSetOperator(operator.CommandOperator):
|
||||
break
|
||||
|
||||
if model is None:
|
||||
yield entities.CommandReturn(
|
||||
error=errors.CommandError(f'未找到模型 {model_name}')
|
||||
)
|
||||
yield entities.CommandReturn(error=errors.CommandError(f'未找到模型 {model_name}'))
|
||||
else:
|
||||
self.ap.provider_cfg.data['model'] = model_name
|
||||
await self.ap.provider_cfg.dump_config()
|
||||
yield entities.CommandReturn(
|
||||
text=f'已设置当前使用模型为 {model_name},重置会话以生效'
|
||||
)
|
||||
yield entities.CommandReturn(text=f'已设置当前使用模型为 {model_name},重置会话以生效')
|
||||
|
||||
@@ -7,36 +7,21 @@ from .. import operator, entities, errors
|
||||
|
||||
@operator.operator_class(name='next', help='切换到后一个对话', usage='!next')
|
||||
class NextOperator(operator.CommandOperator):
|
||||
async def execute(
|
||||
self, context: entities.ExecuteContext
|
||||
) -> typing.AsyncGenerator[entities.CommandReturn, None]:
|
||||
async def execute(self, context: entities.ExecuteContext) -> typing.AsyncGenerator[entities.CommandReturn, None]:
|
||||
if context.session.conversations:
|
||||
# 找到当前会话的下一个会话
|
||||
for index in range(len(context.session.conversations)):
|
||||
if (
|
||||
context.session.conversations[index]
|
||||
== context.session.using_conversation
|
||||
):
|
||||
if context.session.conversations[index] == context.session.using_conversation:
|
||||
if index == len(context.session.conversations) - 1:
|
||||
yield entities.CommandReturn(
|
||||
error=errors.CommandOperationError('已经是最后一个对话了')
|
||||
)
|
||||
yield entities.CommandReturn(error=errors.CommandOperationError('已经是最后一个对话了'))
|
||||
return
|
||||
else:
|
||||
context.session.using_conversation = (
|
||||
context.session.conversations[index + 1]
|
||||
)
|
||||
time_str = (
|
||||
context.session.using_conversation.create_time.strftime(
|
||||
'%Y-%m-%d %H:%M:%S'
|
||||
)
|
||||
)
|
||||
context.session.using_conversation = context.session.conversations[index + 1]
|
||||
time_str = context.session.using_conversation.create_time.strftime('%Y-%m-%d %H:%M:%S')
|
||||
|
||||
yield entities.CommandReturn(
|
||||
text=f'已切换到后一个对话: {index} {time_str}: {context.session.using_conversation.messages[0].content}'
|
||||
)
|
||||
return
|
||||
else:
|
||||
yield entities.CommandReturn(
|
||||
error=errors.CommandOperationError('当前没有对话')
|
||||
)
|
||||
yield entities.CommandReturn(error=errors.CommandOperationError('当前没有对话'))
|
||||
|
||||
@@ -13,9 +13,7 @@ from .. import operator, entities, errors
|
||||
usage='!ollama\n!ollama show <模型名>\n!ollama pull <模型名>\n!ollama del <模型名>',
|
||||
)
|
||||
class OllamaOperator(operator.CommandOperator):
|
||||
async def execute(
|
||||
self, context: entities.ExecuteContext
|
||||
) -> typing.AsyncGenerator[entities.CommandReturn, None]:
|
||||
async def execute(self, context: entities.ExecuteContext) -> typing.AsyncGenerator[entities.CommandReturn, None]:
|
||||
try:
|
||||
content: str = '模型列表:\n'
|
||||
model_list: list = ollama.list().get('models', [])
|
||||
@@ -25,9 +23,7 @@ class OllamaOperator(operator.CommandOperator):
|
||||
content += f'大小: {bytes_to_mb(model["size"])}MB\n\n'
|
||||
yield entities.CommandReturn(text=f'{content.strip()}')
|
||||
except ollama.ResponseError:
|
||||
yield entities.CommandReturn(
|
||||
error=errors.CommandError('无法获取模型列表,请确认 Ollama 服务正常')
|
||||
)
|
||||
yield entities.CommandReturn(error=errors.CommandError('无法获取模型列表,请确认 Ollama 服务正常'))
|
||||
|
||||
|
||||
def bytes_to_mb(num_bytes):
|
||||
@@ -35,13 +31,9 @@ def bytes_to_mb(num_bytes):
|
||||
return format(mb, '.2f')
|
||||
|
||||
|
||||
@operator.operator_class(
|
||||
name='show', help='ollama模型详情', privilege=2, parent_class=OllamaOperator
|
||||
)
|
||||
@operator.operator_class(name='show', help='ollama模型详情', privilege=2, parent_class=OllamaOperator)
|
||||
class OllamaShowOperator(operator.CommandOperator):
|
||||
async def execute(
|
||||
self, context: entities.ExecuteContext
|
||||
) -> typing.AsyncGenerator[entities.CommandReturn, None]:
|
||||
async def execute(self, context: entities.ExecuteContext) -> typing.AsyncGenerator[entities.CommandReturn, None]:
|
||||
content: str = '模型详情:\n'
|
||||
try:
|
||||
show: dict = ollama.show(model=context.crt_params[0])
|
||||
@@ -60,27 +52,19 @@ class OllamaShowOperator(operator.CommandOperator):
|
||||
content += json.dumps(show, indent=4)
|
||||
yield entities.CommandReturn(text=content.strip())
|
||||
except ollama.ResponseError:
|
||||
yield entities.CommandReturn(
|
||||
error=errors.CommandError('无法获取模型详情,请确认 Ollama 服务正常')
|
||||
)
|
||||
yield entities.CommandReturn(error=errors.CommandError('无法获取模型详情,请确认 Ollama 服务正常'))
|
||||
|
||||
|
||||
@operator.operator_class(
|
||||
name='pull', help='ollama模型拉取', privilege=2, parent_class=OllamaOperator
|
||||
)
|
||||
@operator.operator_class(name='pull', help='ollama模型拉取', privilege=2, parent_class=OllamaOperator)
|
||||
class OllamaPullOperator(operator.CommandOperator):
|
||||
async def execute(
|
||||
self, context: entities.ExecuteContext
|
||||
) -> typing.AsyncGenerator[entities.CommandReturn, None]:
|
||||
async def execute(self, context: entities.ExecuteContext) -> typing.AsyncGenerator[entities.CommandReturn, None]:
|
||||
try:
|
||||
model_list: list = ollama.list().get('models', [])
|
||||
if context.crt_params[0] in [model['name'] for model in model_list]:
|
||||
yield entities.CommandReturn(text='模型已存在')
|
||||
return
|
||||
except ollama.ResponseError:
|
||||
yield entities.CommandReturn(
|
||||
error=errors.CommandError('无法获取模型列表,请确认 Ollama 服务正常')
|
||||
)
|
||||
yield entities.CommandReturn(error=errors.CommandError('无法获取模型列表,请确认 Ollama 服务正常'))
|
||||
return
|
||||
|
||||
on_progress: bool = False
|
||||
@@ -108,13 +92,9 @@ class OllamaPullOperator(operator.CommandOperator):
|
||||
yield entities.CommandReturn(text=f'拉取失败: {e.error}')
|
||||
|
||||
|
||||
@operator.operator_class(
|
||||
name='del', help='ollama模型删除', privilege=2, parent_class=OllamaOperator
|
||||
)
|
||||
@operator.operator_class(name='del', help='ollama模型删除', privilege=2, parent_class=OllamaOperator)
|
||||
class OllamaDelOperator(operator.CommandOperator):
|
||||
async def execute(
|
||||
self, context: entities.ExecuteContext
|
||||
) -> typing.AsyncGenerator[entities.CommandReturn, None]:
|
||||
async def execute(self, context: entities.ExecuteContext) -> typing.AsyncGenerator[entities.CommandReturn, None]:
|
||||
try:
|
||||
ret: str = ollama.delete(model=context.crt_params[0])['status']
|
||||
except ollama.ResponseError as e:
|
||||
|
||||
@@ -11,9 +11,7 @@ from .. import operator, entities, errors
|
||||
usage='!plugin\n!plugin get <插件仓库地址>\n!plugin update\n!plugin del <插件名>\n!plugin on <插件名>\n!plugin off <插件名>',
|
||||
)
|
||||
class PluginOperator(operator.CommandOperator):
|
||||
async def execute(
|
||||
self, context: entities.ExecuteContext
|
||||
) -> typing.AsyncGenerator[entities.CommandReturn, None]:
|
||||
async def execute(self, context: entities.ExecuteContext) -> typing.AsyncGenerator[entities.CommandReturn, None]:
|
||||
plugin_list = self.ap.plugin_mgr.plugins()
|
||||
reply_str = '所有插件({}):\n'.format(len(plugin_list))
|
||||
idx = 0
|
||||
@@ -32,17 +30,11 @@ class PluginOperator(operator.CommandOperator):
|
||||
yield entities.CommandReturn(text=reply_str)
|
||||
|
||||
|
||||
@operator.operator_class(
|
||||
name='get', help='安装插件', privilege=2, parent_class=PluginOperator
|
||||
)
|
||||
@operator.operator_class(name='get', help='安装插件', privilege=2, parent_class=PluginOperator)
|
||||
class PluginGetOperator(operator.CommandOperator):
|
||||
async def execute(
|
||||
self, context: entities.ExecuteContext
|
||||
) -> typing.AsyncGenerator[entities.CommandReturn, None]:
|
||||
async def execute(self, context: entities.ExecuteContext) -> typing.AsyncGenerator[entities.CommandReturn, None]:
|
||||
if len(context.crt_params) == 0:
|
||||
yield entities.CommandReturn(
|
||||
error=errors.ParamNotEnoughError('请提供插件仓库地址')
|
||||
)
|
||||
yield entities.CommandReturn(error=errors.ParamNotEnoughError('请提供插件仓库地址'))
|
||||
else:
|
||||
repo = context.crt_params[0]
|
||||
|
||||
@@ -53,22 +45,14 @@ class PluginGetOperator(operator.CommandOperator):
|
||||
yield entities.CommandReturn(text='插件安装成功,请重启程序以加载插件')
|
||||
except Exception as e:
|
||||
traceback.print_exc()
|
||||
yield entities.CommandReturn(
|
||||
error=errors.CommandError('插件安装失败: ' + str(e))
|
||||
)
|
||||
yield entities.CommandReturn(error=errors.CommandError('插件安装失败: ' + str(e)))
|
||||
|
||||
|
||||
@operator.operator_class(
|
||||
name='update', help='更新插件', privilege=2, parent_class=PluginOperator
|
||||
)
|
||||
@operator.operator_class(name='update', help='更新插件', privilege=2, parent_class=PluginOperator)
|
||||
class PluginUpdateOperator(operator.CommandOperator):
|
||||
async def execute(
|
||||
self, context: entities.ExecuteContext
|
||||
) -> typing.AsyncGenerator[entities.CommandReturn, None]:
|
||||
async def execute(self, context: entities.ExecuteContext) -> typing.AsyncGenerator[entities.CommandReturn, None]:
|
||||
if len(context.crt_params) == 0:
|
||||
yield entities.CommandReturn(
|
||||
error=errors.ParamNotEnoughError('请提供插件名称')
|
||||
)
|
||||
yield entities.CommandReturn(error=errors.ParamNotEnoughError('请提供插件名称'))
|
||||
else:
|
||||
plugin_name = context.crt_params[0]
|
||||
|
||||
@@ -78,27 +62,17 @@ class PluginUpdateOperator(operator.CommandOperator):
|
||||
if plugin_container is not None:
|
||||
yield entities.CommandReturn(text='正在更新插件...')
|
||||
await self.ap.plugin_mgr.update_plugin(plugin_name)
|
||||
yield entities.CommandReturn(
|
||||
text='插件更新成功,请重启程序以加载插件'
|
||||
)
|
||||
yield entities.CommandReturn(text='插件更新成功,请重启程序以加载插件')
|
||||
else:
|
||||
yield entities.CommandReturn(
|
||||
error=errors.CommandError('插件更新失败: 未找到插件')
|
||||
)
|
||||
yield entities.CommandReturn(error=errors.CommandError('插件更新失败: 未找到插件'))
|
||||
except Exception as e:
|
||||
traceback.print_exc()
|
||||
yield entities.CommandReturn(
|
||||
error=errors.CommandError('插件更新失败: ' + str(e))
|
||||
)
|
||||
yield entities.CommandReturn(error=errors.CommandError('插件更新失败: ' + str(e)))
|
||||
|
||||
|
||||
@operator.operator_class(
|
||||
name='all', help='更新所有插件', privilege=2, parent_class=PluginUpdateOperator
|
||||
)
|
||||
@operator.operator_class(name='all', help='更新所有插件', privilege=2, parent_class=PluginUpdateOperator)
|
||||
class PluginUpdateAllOperator(operator.CommandOperator):
|
||||
async def execute(
|
||||
self, context: entities.ExecuteContext
|
||||
) -> typing.AsyncGenerator[entities.CommandReturn, None]:
|
||||
async def execute(self, context: entities.ExecuteContext) -> typing.AsyncGenerator[entities.CommandReturn, None]:
|
||||
try:
|
||||
plugins = [p.plugin_name for p in self.ap.plugin_mgr.plugins()]
|
||||
|
||||
@@ -111,32 +85,20 @@ class PluginUpdateAllOperator(operator.CommandOperator):
|
||||
updated.append(plugin_name)
|
||||
except Exception as e:
|
||||
traceback.print_exc()
|
||||
yield entities.CommandReturn(
|
||||
error=errors.CommandError('插件更新失败: ' + str(e))
|
||||
)
|
||||
yield entities.CommandReturn(
|
||||
text='已更新插件: {}'.format(', '.join(updated))
|
||||
)
|
||||
yield entities.CommandReturn(error=errors.CommandError('插件更新失败: ' + str(e)))
|
||||
yield entities.CommandReturn(text='已更新插件: {}'.format(', '.join(updated)))
|
||||
else:
|
||||
yield entities.CommandReturn(text='没有可更新的插件')
|
||||
except Exception as e:
|
||||
traceback.print_exc()
|
||||
yield entities.CommandReturn(
|
||||
error=errors.CommandError('插件更新失败: ' + str(e))
|
||||
)
|
||||
yield entities.CommandReturn(error=errors.CommandError('插件更新失败: ' + str(e)))
|
||||
|
||||
|
||||
@operator.operator_class(
|
||||
name='del', help='删除插件', privilege=2, parent_class=PluginOperator
|
||||
)
|
||||
@operator.operator_class(name='del', help='删除插件', privilege=2, parent_class=PluginOperator)
|
||||
class PluginDelOperator(operator.CommandOperator):
|
||||
async def execute(
|
||||
self, context: entities.ExecuteContext
|
||||
) -> typing.AsyncGenerator[entities.CommandReturn, None]:
|
||||
async def execute(self, context: entities.ExecuteContext) -> typing.AsyncGenerator[entities.CommandReturn, None]:
|
||||
if len(context.crt_params) == 0:
|
||||
yield entities.CommandReturn(
|
||||
error=errors.ParamNotEnoughError('请提供插件名称')
|
||||
)
|
||||
yield entities.CommandReturn(error=errors.ParamNotEnoughError('请提供插件名称'))
|
||||
else:
|
||||
plugin_name = context.crt_params[0]
|
||||
|
||||
@@ -146,79 +108,49 @@ class PluginDelOperator(operator.CommandOperator):
|
||||
if plugin_container is not None:
|
||||
yield entities.CommandReturn(text='正在删除插件...')
|
||||
await self.ap.plugin_mgr.uninstall_plugin(plugin_name)
|
||||
yield entities.CommandReturn(
|
||||
text='插件删除成功,请重启程序以加载插件'
|
||||
)
|
||||
yield entities.CommandReturn(text='插件删除成功,请重启程序以加载插件')
|
||||
else:
|
||||
yield entities.CommandReturn(
|
||||
error=errors.CommandError('插件删除失败: 未找到插件')
|
||||
)
|
||||
yield entities.CommandReturn(error=errors.CommandError('插件删除失败: 未找到插件'))
|
||||
except Exception as e:
|
||||
traceback.print_exc()
|
||||
yield entities.CommandReturn(
|
||||
error=errors.CommandError('插件删除失败: ' + str(e))
|
||||
)
|
||||
yield entities.CommandReturn(error=errors.CommandError('插件删除失败: ' + str(e)))
|
||||
|
||||
|
||||
@operator.operator_class(
|
||||
name='on', help='启用插件', privilege=2, parent_class=PluginOperator
|
||||
)
|
||||
@operator.operator_class(name='on', help='启用插件', privilege=2, parent_class=PluginOperator)
|
||||
class PluginEnableOperator(operator.CommandOperator):
|
||||
async def execute(
|
||||
self, context: entities.ExecuteContext
|
||||
) -> typing.AsyncGenerator[entities.CommandReturn, None]:
|
||||
async def execute(self, context: entities.ExecuteContext) -> typing.AsyncGenerator[entities.CommandReturn, None]:
|
||||
if len(context.crt_params) == 0:
|
||||
yield entities.CommandReturn(
|
||||
error=errors.ParamNotEnoughError('请提供插件名称')
|
||||
)
|
||||
yield entities.CommandReturn(error=errors.ParamNotEnoughError('请提供插件名称'))
|
||||
else:
|
||||
plugin_name = context.crt_params[0]
|
||||
|
||||
try:
|
||||
if await self.ap.plugin_mgr.update_plugin_switch(plugin_name, True):
|
||||
yield entities.CommandReturn(
|
||||
text='已启用插件: {}'.format(plugin_name)
|
||||
)
|
||||
yield entities.CommandReturn(text='已启用插件: {}'.format(plugin_name))
|
||||
else:
|
||||
yield entities.CommandReturn(
|
||||
error=errors.CommandError(
|
||||
'插件状态修改失败: 未找到插件 {}'.format(plugin_name)
|
||||
)
|
||||
error=errors.CommandError('插件状态修改失败: 未找到插件 {}'.format(plugin_name))
|
||||
)
|
||||
except Exception as e:
|
||||
traceback.print_exc()
|
||||
yield entities.CommandReturn(
|
||||
error=errors.CommandError('插件状态修改失败: ' + str(e))
|
||||
)
|
||||
yield entities.CommandReturn(error=errors.CommandError('插件状态修改失败: ' + str(e)))
|
||||
|
||||
|
||||
@operator.operator_class(
|
||||
name='off', help='禁用插件', privilege=2, parent_class=PluginOperator
|
||||
)
|
||||
@operator.operator_class(name='off', help='禁用插件', privilege=2, parent_class=PluginOperator)
|
||||
class PluginDisableOperator(operator.CommandOperator):
|
||||
async def execute(
|
||||
self, context: entities.ExecuteContext
|
||||
) -> typing.AsyncGenerator[entities.CommandReturn, None]:
|
||||
async def execute(self, context: entities.ExecuteContext) -> typing.AsyncGenerator[entities.CommandReturn, None]:
|
||||
if len(context.crt_params) == 0:
|
||||
yield entities.CommandReturn(
|
||||
error=errors.ParamNotEnoughError('请提供插件名称')
|
||||
)
|
||||
yield entities.CommandReturn(error=errors.ParamNotEnoughError('请提供插件名称'))
|
||||
else:
|
||||
plugin_name = context.crt_params[0]
|
||||
|
||||
try:
|
||||
if await self.ap.plugin_mgr.update_plugin_switch(plugin_name, False):
|
||||
yield entities.CommandReturn(
|
||||
text='已禁用插件: {}'.format(plugin_name)
|
||||
)
|
||||
yield entities.CommandReturn(text='已禁用插件: {}'.format(plugin_name))
|
||||
else:
|
||||
yield entities.CommandReturn(
|
||||
error=errors.CommandError(
|
||||
'插件状态修改失败: 未找到插件 {}'.format(plugin_name)
|
||||
)
|
||||
error=errors.CommandError('插件状态修改失败: 未找到插件 {}'.format(plugin_name))
|
||||
)
|
||||
except Exception as e:
|
||||
traceback.print_exc()
|
||||
yield entities.CommandReturn(
|
||||
error=errors.CommandError('插件状态修改失败: ' + str(e))
|
||||
)
|
||||
yield entities.CommandReturn(error=errors.CommandError('插件状态修改失败: ' + str(e)))
|
||||
|
||||
@@ -7,14 +7,10 @@ from .. import operator, entities, errors
|
||||
|
||||
@operator.operator_class(name='prompt', help='查看当前对话的前文', usage='!prompt')
|
||||
class PromptOperator(operator.CommandOperator):
|
||||
async def execute(
|
||||
self, context: entities.ExecuteContext
|
||||
) -> typing.AsyncGenerator[entities.CommandReturn, None]:
|
||||
async def execute(self, context: entities.ExecuteContext) -> typing.AsyncGenerator[entities.CommandReturn, None]:
|
||||
"""执行"""
|
||||
if context.session.using_conversation is None:
|
||||
yield entities.CommandReturn(
|
||||
error=errors.CommandOperationError('当前没有对话')
|
||||
)
|
||||
yield entities.CommandReturn(error=errors.CommandOperationError('当前没有对话'))
|
||||
else:
|
||||
reply_str = '当前对话所有内容:\n\n'
|
||||
|
||||
|
||||
@@ -5,13 +5,9 @@ import typing
|
||||
from .. import operator, entities, errors
|
||||
|
||||
|
||||
@operator.operator_class(
|
||||
name='resend', help='重发当前会话的最后一条消息', usage='!resend'
|
||||
)
|
||||
@operator.operator_class(name='resend', help='重发当前会话的最后一条消息', usage='!resend')
|
||||
class ResendOperator(operator.CommandOperator):
|
||||
async def execute(
|
||||
self, context: entities.ExecuteContext
|
||||
) -> typing.AsyncGenerator[entities.CommandReturn, None]:
|
||||
async def execute(self, context: entities.ExecuteContext) -> typing.AsyncGenerator[entities.CommandReturn, None]:
|
||||
# 回滚到最后一条用户message前
|
||||
if context.session.using_conversation is None:
|
||||
yield entities.CommandReturn(error=errors.CommandError('当前没有对话'))
|
||||
|
||||
@@ -7,9 +7,7 @@ from .. import operator, entities
|
||||
|
||||
@operator.operator_class(name='reset', help='重置当前会话', usage='!reset')
|
||||
class ResetOperator(operator.CommandOperator):
|
||||
async def execute(
|
||||
self, context: entities.ExecuteContext
|
||||
) -> typing.AsyncGenerator[entities.CommandReturn, None]:
|
||||
async def execute(self, context: entities.ExecuteContext) -> typing.AsyncGenerator[entities.CommandReturn, None]:
|
||||
"""执行"""
|
||||
context.session.using_conversation = None
|
||||
|
||||
|
||||
@@ -8,9 +8,7 @@ from .. import operator, entities, errors
|
||||
|
||||
@operator.operator_class(name='update', help='更新程序', usage='!update', privilege=2)
|
||||
class UpdateCommand(operator.CommandOperator):
|
||||
async def execute(
|
||||
self, context: entities.ExecuteContext
|
||||
) -> typing.AsyncGenerator[entities.CommandReturn, None]:
|
||||
async def execute(self, context: entities.ExecuteContext) -> typing.AsyncGenerator[entities.CommandReturn, None]:
|
||||
try:
|
||||
yield entities.CommandReturn(text='正在进行更新...')
|
||||
if await self.ap.ver_mgr.update_all():
|
||||
@@ -19,6 +17,4 @@ class UpdateCommand(operator.CommandOperator):
|
||||
yield entities.CommandReturn(text='当前已是最新版本')
|
||||
except Exception as e:
|
||||
traceback.print_exc()
|
||||
yield entities.CommandReturn(
|
||||
error=errors.CommandError('更新失败: ' + str(e))
|
||||
)
|
||||
yield entities.CommandReturn(error=errors.CommandError('更新失败: ' + str(e)))
|
||||
|
||||
@@ -7,9 +7,7 @@ from .. import operator, entities
|
||||
|
||||
@operator.operator_class(name='version', help='显示版本信息', usage='!version')
|
||||
class VersionCommand(operator.CommandOperator):
|
||||
async def execute(
|
||||
self, context: entities.ExecuteContext
|
||||
) -> typing.AsyncGenerator[entities.CommandReturn, None]:
|
||||
async def execute(self, context: entities.ExecuteContext) -> typing.AsyncGenerator[entities.CommandReturn, None]:
|
||||
reply_str = f'当前版本: \n{self.ap.ver_mgr.get_current_version()}'
|
||||
|
||||
try:
|
||||
|
||||
@@ -41,9 +41,7 @@ class ConfigManager:
|
||||
self.file.save_sync(self.data)
|
||||
|
||||
|
||||
async def load_python_module_config(
|
||||
config_name: str, template_name: str, completion: bool = True
|
||||
) -> ConfigManager:
|
||||
async def load_python_module_config(config_name: str, template_name: str, completion: bool = True) -> ConfigManager:
|
||||
"""加载Python模块配置文件
|
||||
|
||||
Args:
|
||||
|
||||
@@ -160,9 +160,7 @@ class Application:
|
||||
"""打印访问 webui 的提示"""
|
||||
|
||||
if not os.path.exists(os.path.join('.', 'web/out')):
|
||||
self.logger.warning(
|
||||
'WebUI 文件缺失,请根据文档获取:https://docs.langbot.app/webui/intro.html'
|
||||
)
|
||||
self.logger.warning('WebUI 文件缺失,请根据文档获取:https://docs.langbot.app/webui/intro.html')
|
||||
return
|
||||
|
||||
host_ip = '127.0.0.1'
|
||||
|
||||
@@ -26,9 +26,7 @@ async def init_logging(extra_handlers: list[logging.Handler] = None) -> logging.
|
||||
if constants.debug_mode:
|
||||
level = logging.DEBUG
|
||||
|
||||
log_file_name = 'data/logs/langbot-%s.log' % time.strftime(
|
||||
'%Y-%m-%d', time.localtime()
|
||||
)
|
||||
log_file_name = 'data/logs/langbot-%s.log' % time.strftime('%Y-%m-%d', time.localtime())
|
||||
|
||||
qcg_logger = logging.getLogger('langbot')
|
||||
|
||||
@@ -43,9 +41,7 @@ async def init_logging(extra_handlers: list[logging.Handler] = None) -> logging.
|
||||
stream_handler = logging.StreamHandler(sys.stdout)
|
||||
# stream_handler.setLevel(level)
|
||||
# stream_handler.setFormatter(color_formatter)
|
||||
stream_handler.stream = open(
|
||||
sys.stdout.fileno(), mode='w', encoding='utf-8', buffering=1
|
||||
)
|
||||
stream_handler.stream = open(sys.stdout.fileno(), mode='w', encoding='utf-8', buffering=1)
|
||||
|
||||
log_handlers: list[logging.Handler] = [
|
||||
stream_handler,
|
||||
|
||||
@@ -87,8 +87,7 @@ class Query(pydantic.BaseModel):
|
||||
"""使用的函数,由前置处理器阶段设置"""
|
||||
|
||||
resp_messages: (
|
||||
typing.Optional[list[llm_entities.Message]]
|
||||
| typing.Optional[list[platform_message.MessageChain]]
|
||||
typing.Optional[list[llm_entities.Message]] | typing.Optional[list[platform_message.MessageChain]]
|
||||
) = []
|
||||
"""由Process阶段生成的回复消息对象列表"""
|
||||
|
||||
@@ -130,13 +129,9 @@ class Conversation(pydantic.BaseModel):
|
||||
|
||||
messages: list[llm_entities.Message]
|
||||
|
||||
create_time: typing.Optional[datetime.datetime] = pydantic.Field(
|
||||
default_factory=datetime.datetime.now
|
||||
)
|
||||
create_time: typing.Optional[datetime.datetime] = pydantic.Field(default_factory=datetime.datetime.now)
|
||||
|
||||
update_time: typing.Optional[datetime.datetime] = pydantic.Field(
|
||||
default_factory=datetime.datetime.now
|
||||
)
|
||||
update_time: typing.Optional[datetime.datetime] = pydantic.Field(default_factory=datetime.datetime.now)
|
||||
|
||||
use_llm_model: requester.RuntimeLLMModel
|
||||
|
||||
@@ -162,17 +157,11 @@ class Session(pydantic.BaseModel):
|
||||
|
||||
using_conversation: typing.Optional[Conversation] = None
|
||||
|
||||
conversations: typing.Optional[list[Conversation]] = pydantic.Field(
|
||||
default_factory=list
|
||||
)
|
||||
conversations: typing.Optional[list[Conversation]] = pydantic.Field(default_factory=list)
|
||||
|
||||
create_time: typing.Optional[datetime.datetime] = pydantic.Field(
|
||||
default_factory=datetime.datetime.now
|
||||
)
|
||||
create_time: typing.Optional[datetime.datetime] = pydantic.Field(default_factory=datetime.datetime.now)
|
||||
|
||||
update_time: typing.Optional[datetime.datetime] = pydantic.Field(
|
||||
default_factory=datetime.datetime.now
|
||||
)
|
||||
update_time: typing.Optional[datetime.datetime] = pydantic.Field(default_factory=datetime.datetime.now)
|
||||
|
||||
semaphore: typing.Optional[asyncio.Semaphore] = None
|
||||
"""当前会话的信号量,用于限制并发"""
|
||||
|
||||
@@ -11,16 +11,14 @@ class SensitiveWordMigration(migration.Migration):
|
||||
|
||||
async def need_migrate(self) -> bool:
|
||||
"""判断当前环境是否需要运行此迁移"""
|
||||
return os.path.exists(
|
||||
'data/config/sensitive-words.json'
|
||||
) and not os.path.exists('data/metadata/sensitive-words.json')
|
||||
return os.path.exists('data/config/sensitive-words.json') and not os.path.exists(
|
||||
'data/metadata/sensitive-words.json'
|
||||
)
|
||||
|
||||
async def run(self):
|
||||
"""执行迁移"""
|
||||
# 移动文件
|
||||
os.rename(
|
||||
'data/config/sensitive-words.json', 'data/metadata/sensitive-words.json'
|
||||
)
|
||||
os.rename('data/config/sensitive-words.json', 'data/metadata/sensitive-words.json')
|
||||
|
||||
# 重新加载配置
|
||||
await self.ap.sensitive_meta.load_config()
|
||||
|
||||
@@ -23,9 +23,7 @@ class OpenAIConfigMigration(migration.Migration):
|
||||
|
||||
self.ap.provider_cfg.data['keys']['openai'] = old_openai_config['api-keys']
|
||||
|
||||
self.ap.provider_cfg.data['model'] = old_openai_config[
|
||||
'chat-completions-params'
|
||||
]['model']
|
||||
self.ap.provider_cfg.data['model'] = old_openai_config['chat-completions-params']['model']
|
||||
|
||||
del old_openai_config['chat-completions-params']['model']
|
||||
|
||||
|
||||
@@ -15,8 +15,6 @@ class QCGCenterURLConfigMigration(migration.Migration):
|
||||
"""执行迁移"""
|
||||
|
||||
if 'qcg-center-url' not in self.ap.system_cfg.data:
|
||||
self.ap.system_cfg.data['qcg-center-url'] = (
|
||||
'https://api.qchatgpt.rockchin.top/api/v2'
|
||||
)
|
||||
self.ap.system_cfg.data['qcg-center-url'] = 'https://api.qchatgpt.rockchin.top/api/v2'
|
||||
|
||||
await self.ap.system_cfg.dump_config()
|
||||
|
||||
@@ -9,9 +9,7 @@ class AdFixwinConfigMigration(migration.Migration):
|
||||
|
||||
async def need_migrate(self) -> bool:
|
||||
"""判断当前环境是否需要运行此迁移"""
|
||||
return isinstance(
|
||||
self.ap.pipeline_cfg.data['rate-limit']['fixwin']['default'], int
|
||||
)
|
||||
return isinstance(self.ap.pipeline_cfg.data['rate-limit']['fixwin']['default'], int)
|
||||
|
||||
async def run(self):
|
||||
"""执行迁移"""
|
||||
@@ -19,9 +17,7 @@ class AdFixwinConfigMigration(migration.Migration):
|
||||
for session_name in self.ap.pipeline_cfg.data['rate-limit']['fixwin']:
|
||||
temp_dict = {
|
||||
'window-size': 60,
|
||||
'limit': self.ap.pipeline_cfg.data['rate-limit']['fixwin'][
|
||||
session_name
|
||||
],
|
||||
'limit': self.ap.pipeline_cfg.data['rate-limit']['fixwin'][session_name],
|
||||
}
|
||||
|
||||
self.ap.pipeline_cfg.data['rate-limit']['fixwin'][session_name] = temp_dict
|
||||
|
||||
@@ -9,10 +9,7 @@ class HttpApiConfigMigration(migration.Migration):
|
||||
|
||||
async def need_migrate(self) -> bool:
|
||||
"""判断当前环境是否需要运行此迁移"""
|
||||
return (
|
||||
'http-api' not in self.ap.system_cfg.data
|
||||
or 'persistence' not in self.ap.system_cfg.data
|
||||
)
|
||||
return 'http-api' not in self.ap.system_cfg.data or 'persistence' not in self.ap.system_cfg.data
|
||||
|
||||
async def run(self):
|
||||
"""执行迁移"""
|
||||
|
||||
@@ -11,8 +11,7 @@ class DifyAPITimeoutParamsMigration(migration.Migration):
|
||||
"""判断当前环境是否需要运行此迁移"""
|
||||
return (
|
||||
'timeout' not in self.ap.provider_cfg.data['dify-service-api']['chat']
|
||||
or 'timeout'
|
||||
not in self.ap.provider_cfg.data['dify-service-api']['workflow']
|
||||
or 'timeout' not in self.ap.provider_cfg.data['dify-service-api']['workflow']
|
||||
or 'agent' not in self.ap.provider_cfg.data['dify-service-api']
|
||||
)
|
||||
|
||||
|
||||
@@ -10,9 +10,7 @@ class SiliconFlowConfigMigration(migration.Migration):
|
||||
async def need_migrate(self) -> bool:
|
||||
"""判断当前环境是否需要运行此迁移"""
|
||||
|
||||
return (
|
||||
'siliconflow-chat-completions' not in self.ap.provider_cfg.data['requester']
|
||||
)
|
||||
return 'siliconflow-chat-completions' not in self.ap.provider_cfg.data['requester']
|
||||
|
||||
async def run(self):
|
||||
"""执行迁移"""
|
||||
|
||||
@@ -13,17 +13,12 @@ class DifyThinkingConfigMigration(migration.Migration):
|
||||
if 'options' not in self.ap.provider_cfg.data['dify-service-api']:
|
||||
return True
|
||||
|
||||
if (
|
||||
'convert-thinking-tips'
|
||||
not in self.ap.provider_cfg.data['dify-service-api']['options']
|
||||
):
|
||||
if 'convert-thinking-tips' not in self.ap.provider_cfg.data['dify-service-api']['options']:
|
||||
return True
|
||||
|
||||
return False
|
||||
|
||||
async def run(self):
|
||||
"""执行迁移"""
|
||||
self.ap.provider_cfg.data['dify-service-api']['options'] = {
|
||||
'convert-thinking-tips': 'plain'
|
||||
}
|
||||
self.ap.provider_cfg.data['dify-service-api']['options'] = {'convert-thinking-tips': 'plain'}
|
||||
await self.ap.provider_cfg.dump_config()
|
||||
|
||||
@@ -24,8 +24,6 @@ class GewechatFileUrlConfigMigration(migration.Migration):
|
||||
if adapter['adapter'] == 'gewechat':
|
||||
if 'gewechat_file_url' not in adapter:
|
||||
parsed_url = urlparse(adapter['gewechat_url'])
|
||||
adapter['gewechat_file_url'] = (
|
||||
f'{parsed_url.scheme}://{parsed_url.hostname}:2532'
|
||||
)
|
||||
adapter['gewechat_file_url'] = f'{parsed_url.scheme}://{parsed_url.hostname}:2532'
|
||||
|
||||
await self.ap.platform_cfg.dump_config()
|
||||
|
||||
@@ -3,24 +3,23 @@ from __future__ import annotations
|
||||
from .. import migration
|
||||
|
||||
|
||||
@migration.migration_class("tg-dingtalk-markdown", 38)
|
||||
@migration.migration_class('tg-dingtalk-markdown', 38)
|
||||
class TgDingtalkMarkdownMigration(migration.Migration):
|
||||
"""迁移"""
|
||||
|
||||
async def need_migrate(self) -> bool:
|
||||
"""判断当前环境是否需要运行此迁移"""
|
||||
|
||||
|
||||
for adapter in self.ap.platform_cfg.data['platform-adapters']:
|
||||
if adapter['adapter'] in ['dingtalk','telegram']:
|
||||
if adapter['adapter'] in ['dingtalk', 'telegram']:
|
||||
if 'markdown_card' not in adapter:
|
||||
return True
|
||||
return False
|
||||
|
||||
|
||||
async def run(self):
|
||||
"""执行迁移"""
|
||||
for adapter in self.ap.platform_cfg.data['platform-adapters']:
|
||||
if adapter['adapter'] in ['dingtalk','telegram']:
|
||||
if adapter['adapter'] in ['dingtalk', 'telegram']:
|
||||
if 'markdown_card' not in adapter:
|
||||
adapter['markdown_card'] = False
|
||||
await self.ap.platform_cfg.dump_config()
|
||||
|
||||
@@ -3,20 +3,19 @@ from __future__ import annotations
|
||||
from .. import migration
|
||||
|
||||
|
||||
@migration.migration_class("modelscope-config-completion", 39)
|
||||
@migration.migration_class('modelscope-config-completion', 39)
|
||||
class ModelScopeConfigCompletionMigration(migration.Migration):
|
||||
"""ModelScope配置迁移
|
||||
"""
|
||||
"""ModelScope配置迁移"""
|
||||
|
||||
async def need_migrate(self) -> bool:
|
||||
"""判断当前环境是否需要运行此迁移
|
||||
"""
|
||||
return 'modelscope-chat-completions' not in self.ap.provider_cfg.data['requester'] \
|
||||
"""判断当前环境是否需要运行此迁移"""
|
||||
return (
|
||||
'modelscope-chat-completions' not in self.ap.provider_cfg.data['requester']
|
||||
or 'modelscope' not in self.ap.provider_cfg.data['keys']
|
||||
)
|
||||
|
||||
async def run(self):
|
||||
"""执行迁移
|
||||
"""
|
||||
"""执行迁移"""
|
||||
if 'modelscope-chat-completions' not in self.ap.provider_cfg.data['requester']:
|
||||
self.ap.provider_cfg.data['requester']['modelscope-chat-completions'] = {
|
||||
'base-url': 'https://api-inference.modelscope.cn/v1',
|
||||
|
||||
@@ -3,20 +3,19 @@ from __future__ import annotations
|
||||
from .. import migration
|
||||
|
||||
|
||||
@migration.migration_class("ppio-config", 40)
|
||||
@migration.migration_class('ppio-config', 40)
|
||||
class PPIOConfigMigration(migration.Migration):
|
||||
"""PPIO配置迁移
|
||||
"""
|
||||
"""PPIO配置迁移"""
|
||||
|
||||
async def need_migrate(self) -> bool:
|
||||
"""判断当前环境是否需要运行此迁移
|
||||
"""
|
||||
return 'ppio-chat-completions' not in self.ap.provider_cfg.data['requester'] \
|
||||
"""判断当前环境是否需要运行此迁移"""
|
||||
return (
|
||||
'ppio-chat-completions' not in self.ap.provider_cfg.data['requester']
|
||||
or 'ppio' not in self.ap.provider_cfg.data['keys']
|
||||
)
|
||||
|
||||
async def run(self):
|
||||
"""执行迁移
|
||||
"""
|
||||
"""执行迁移"""
|
||||
if 'ppio-chat-completions' not in self.ap.provider_cfg.data['requester']:
|
||||
self.ap.provider_cfg.data['requester']['ppio-chat-completions'] = {
|
||||
'base-url': 'https://api.ppinfra.com/v3/openai',
|
||||
|
||||
@@ -35,9 +35,7 @@ class TaskContext:
|
||||
if action is not None:
|
||||
self.set_current_action(action)
|
||||
|
||||
self._log(
|
||||
f'{datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S")} | {self.current_action} | {msg}'
|
||||
)
|
||||
self._log(f'{datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S")} | {self.current_action} | {msg}')
|
||||
|
||||
def to_dict(self) -> dict:
|
||||
return {'current_action': self.current_action, 'log': self.log}
|
||||
@@ -104,9 +102,7 @@ class TaskWrapper:
|
||||
name: str = '',
|
||||
label: str = '',
|
||||
context: TaskContext = None,
|
||||
scopes: list[core_entities.LifecycleControlScope] = [
|
||||
core_entities.LifecycleControlScope.APPLICATION
|
||||
],
|
||||
scopes: list[core_entities.LifecycleControlScope] = [core_entities.LifecycleControlScope.APPLICATION],
|
||||
):
|
||||
self.id = TaskWrapper._id_index
|
||||
TaskWrapper._id_index += 1
|
||||
@@ -141,7 +137,9 @@ class TaskWrapper:
|
||||
exception_traceback = 'Traceback (most recent call last):\n'
|
||||
|
||||
for frame in self.task_stack:
|
||||
exception_traceback += f' File "{frame.f_code.co_filename}", line {frame.f_lineno}, in {frame.f_code.co_name}\n'
|
||||
exception_traceback += (
|
||||
f' File "{frame.f_code.co_filename}", line {frame.f_lineno}, in {frame.f_code.co_name}\n'
|
||||
)
|
||||
|
||||
exception_traceback += f' {self.assume_exception().__str__()}\n'
|
||||
|
||||
@@ -156,13 +154,9 @@ class TaskWrapper:
|
||||
'runtime': {
|
||||
'done': self.task.done(),
|
||||
'state': self.task._state,
|
||||
'exception': self.assume_exception().__str__()
|
||||
if self.assume_exception() is not None
|
||||
else None,
|
||||
'exception': self.assume_exception().__str__() if self.assume_exception() is not None else None,
|
||||
'exception_traceback': exception_traceback,
|
||||
'result': self.assume_result().__str__()
|
||||
if self.assume_result() is not None
|
||||
else None,
|
||||
'result': self.assume_result().__str__() if self.assume_result() is not None else None,
|
||||
},
|
||||
}
|
||||
|
||||
@@ -191,13 +185,9 @@ class AsyncTaskManager:
|
||||
name: str = '',
|
||||
label: str = '',
|
||||
context: TaskContext = None,
|
||||
scopes: list[core_entities.LifecycleControlScope] = [
|
||||
core_entities.LifecycleControlScope.APPLICATION
|
||||
],
|
||||
scopes: list[core_entities.LifecycleControlScope] = [core_entities.LifecycleControlScope.APPLICATION],
|
||||
) -> TaskWrapper:
|
||||
wrapper = TaskWrapper(
|
||||
self.ap, coro, task_type, kind, name, label, context, scopes
|
||||
)
|
||||
wrapper = TaskWrapper(self.ap, coro, task_type, kind, name, label, context, scopes)
|
||||
self.tasks.append(wrapper)
|
||||
return wrapper
|
||||
|
||||
@@ -208,9 +198,7 @@ class AsyncTaskManager:
|
||||
name: str = '',
|
||||
label: str = '',
|
||||
context: TaskContext = None,
|
||||
scopes: list[core_entities.LifecycleControlScope] = [
|
||||
core_entities.LifecycleControlScope.APPLICATION
|
||||
],
|
||||
scopes: list[core_entities.LifecycleControlScope] = [core_entities.LifecycleControlScope.APPLICATION],
|
||||
) -> TaskWrapper:
|
||||
return self.create_task(coro, 'user', kind, name, label, context, scopes)
|
||||
|
||||
@@ -225,9 +213,7 @@ class AsyncTaskManager:
|
||||
type: str = None,
|
||||
) -> dict:
|
||||
return {
|
||||
'tasks': [
|
||||
t.to_dict() for t in self.tasks if type is None or t.task_type == type
|
||||
],
|
||||
'tasks': [t.to_dict() for t in self.tasks if type is None or t.task_type == type],
|
||||
'id_index': TaskWrapper._id_index,
|
||||
}
|
||||
|
||||
|
||||
@@ -114,9 +114,7 @@ class Component(pydantic.BaseModel):
|
||||
_execution: Execution
|
||||
"""组件执行"""
|
||||
|
||||
def __init__(
|
||||
self, owner: str, manifest: typing.Dict[str, typing.Any], rel_path: str
|
||||
):
|
||||
def __init__(self, owner: str, manifest: typing.Dict[str, typing.Any], rel_path: str):
|
||||
super().__init__(
|
||||
owner=owner,
|
||||
manifest=manifest,
|
||||
@@ -125,19 +123,12 @@ class Component(pydantic.BaseModel):
|
||||
)
|
||||
self._metadata = Metadata(**manifest['metadata'])
|
||||
self._spec = manifest['spec']
|
||||
self._execution = (
|
||||
Execution(**manifest['execution']) if 'execution' in manifest else None
|
||||
)
|
||||
self._execution = Execution(**manifest['execution']) if 'execution' in manifest else None
|
||||
|
||||
@classmethod
|
||||
def is_component_manifest(cls, manifest: typing.Dict[str, typing.Any]) -> bool:
|
||||
"""判断是否为组件清单"""
|
||||
return (
|
||||
'apiVersion' in manifest
|
||||
and 'kind' in manifest
|
||||
and 'metadata' in manifest
|
||||
and 'spec' in manifest
|
||||
)
|
||||
return 'apiVersion' in manifest and 'kind' in manifest and 'metadata' in manifest and 'spec' in manifest
|
||||
|
||||
@property
|
||||
def kind(self) -> str:
|
||||
@@ -200,9 +191,7 @@ class ComponentDiscoveryEngine:
|
||||
def __init__(self, ap: app.Application):
|
||||
self.ap = ap
|
||||
|
||||
def load_component_manifest(
|
||||
self, path: str, owner: str = 'builtin', no_save: bool = False
|
||||
) -> Component | None:
|
||||
def load_component_manifest(self, path: str, owner: str = 'builtin', no_save: bool = False) -> Component | None:
|
||||
"""加载组件清单"""
|
||||
with open(path, 'r', encoding='utf-8') as f:
|
||||
manifest = yaml.safe_load(f)
|
||||
@@ -229,18 +218,12 @@ class ComponentDiscoveryEngine:
|
||||
if depth > max_depth:
|
||||
return
|
||||
for file in os.listdir(path):
|
||||
if (not os.path.isdir(os.path.join(path, file))) and (
|
||||
file.endswith('.yaml') or file.endswith('.yml')
|
||||
):
|
||||
comp = self.load_component_manifest(
|
||||
os.path.join(path, file), owner, no_save
|
||||
)
|
||||
if (not os.path.isdir(os.path.join(path, file))) and (file.endswith('.yaml') or file.endswith('.yml')):
|
||||
comp = self.load_component_manifest(os.path.join(path, file), owner, no_save)
|
||||
if comp is not None:
|
||||
components.append(comp)
|
||||
elif os.path.isdir(os.path.join(path, file)):
|
||||
recursive_load_component_manifests_in_dir(
|
||||
os.path.join(path, file), depth + 1
|
||||
)
|
||||
recursive_load_component_manifests_in_dir(os.path.join(path, file), depth + 1)
|
||||
|
||||
recursive_load_component_manifests_in_dir(path)
|
||||
return components
|
||||
@@ -259,18 +242,12 @@ class ComponentDiscoveryEngine:
|
||||
for dir in group['fromDirs']:
|
||||
path = dir['path']
|
||||
max_depth = dir['maxDepth'] if 'maxDepth' in dir else 1
|
||||
components.extend(
|
||||
self.load_component_manifests_in_dir(
|
||||
path, owner, no_save, max_depth
|
||||
)
|
||||
)
|
||||
components.extend(self.load_component_manifests_in_dir(path, owner, no_save, max_depth))
|
||||
return components
|
||||
|
||||
def discover_blueprint(self, blueprint_manifest_path: str, owner: str = 'builtin'):
|
||||
"""发现蓝图"""
|
||||
blueprint_manifest = self.load_component_manifest(
|
||||
blueprint_manifest_path, owner, no_save=True
|
||||
)
|
||||
blueprint_manifest = self.load_component_manifest(blueprint_manifest_path, owner, no_save=True)
|
||||
if blueprint_manifest is None:
|
||||
raise ValueError(f'Invalid blueprint manifest: {blueprint_manifest_path}')
|
||||
assert blueprint_manifest.kind == 'Blueprint', '`Kind` must be `Blueprint`'
|
||||
@@ -297,9 +274,7 @@ class ComponentDiscoveryEngine:
|
||||
return []
|
||||
return self.components[kind]
|
||||
|
||||
def find_components(
|
||||
self, kind: str, component_list: typing.List[Component]
|
||||
) -> typing.List[Component]:
|
||||
def find_components(self, kind: str, component_list: typing.List[Component]) -> typing.List[Component]:
|
||||
"""查找组件"""
|
||||
result: typing.List[Component] = []
|
||||
for component in component_list:
|
||||
|
||||
@@ -16,9 +16,7 @@ class Bot(Base):
|
||||
enable = sqlalchemy.Column(sqlalchemy.Boolean, nullable=False, default=False)
|
||||
use_pipeline_name = sqlalchemy.Column(sqlalchemy.String(255), nullable=True)
|
||||
use_pipeline_uuid = sqlalchemy.Column(sqlalchemy.String(255), nullable=True)
|
||||
created_at = sqlalchemy.Column(
|
||||
sqlalchemy.DateTime, nullable=False, server_default=sqlalchemy.func.now()
|
||||
)
|
||||
created_at = sqlalchemy.Column(sqlalchemy.DateTime, nullable=False, server_default=sqlalchemy.func.now())
|
||||
updated_at = sqlalchemy.Column(
|
||||
sqlalchemy.DateTime,
|
||||
nullable=False,
|
||||
|
||||
@@ -16,9 +16,7 @@ class LLMModel(Base):
|
||||
api_keys = sqlalchemy.Column(sqlalchemy.JSON, nullable=False)
|
||||
abilities = sqlalchemy.Column(sqlalchemy.JSON, nullable=False, default=[])
|
||||
extra_args = sqlalchemy.Column(sqlalchemy.JSON, nullable=False, default={})
|
||||
created_at = sqlalchemy.Column(
|
||||
sqlalchemy.DateTime, nullable=False, server_default=sqlalchemy.func.now()
|
||||
)
|
||||
created_at = sqlalchemy.Column(sqlalchemy.DateTime, nullable=False, server_default=sqlalchemy.func.now())
|
||||
updated_at = sqlalchemy.Column(
|
||||
sqlalchemy.DateTime,
|
||||
nullable=False,
|
||||
|
||||
@@ -11,9 +11,7 @@ class LegacyPipeline(Base):
|
||||
uuid = sqlalchemy.Column(sqlalchemy.String(255), primary_key=True, unique=True)
|
||||
name = sqlalchemy.Column(sqlalchemy.String(255), nullable=False)
|
||||
description = sqlalchemy.Column(sqlalchemy.String(255), nullable=False)
|
||||
created_at = sqlalchemy.Column(
|
||||
sqlalchemy.DateTime, nullable=False, server_default=sqlalchemy.func.now()
|
||||
)
|
||||
created_at = sqlalchemy.Column(sqlalchemy.DateTime, nullable=False, server_default=sqlalchemy.func.now())
|
||||
updated_at = sqlalchemy.Column(
|
||||
sqlalchemy.DateTime,
|
||||
nullable=False,
|
||||
@@ -35,9 +33,7 @@ class PipelineRunRecord(Base):
|
||||
uuid = sqlalchemy.Column(sqlalchemy.String(255), primary_key=True, unique=True)
|
||||
pipeline_uuid = sqlalchemy.Column(sqlalchemy.String(255), nullable=False)
|
||||
status = sqlalchemy.Column(sqlalchemy.String(255), nullable=False)
|
||||
created_at = sqlalchemy.Column(
|
||||
sqlalchemy.DateTime, nullable=False, server_default=sqlalchemy.func.now()
|
||||
)
|
||||
created_at = sqlalchemy.Column(sqlalchemy.DateTime, nullable=False, server_default=sqlalchemy.func.now())
|
||||
updated_at = sqlalchemy.Column(
|
||||
sqlalchemy.DateTime,
|
||||
nullable=False,
|
||||
|
||||
@@ -13,9 +13,7 @@ class PluginSetting(Base):
|
||||
enabled = sqlalchemy.Column(sqlalchemy.Boolean, nullable=False, default=True)
|
||||
priority = sqlalchemy.Column(sqlalchemy.Integer, nullable=False, default=0)
|
||||
config = sqlalchemy.Column(sqlalchemy.JSON, nullable=False, default=dict)
|
||||
created_at = sqlalchemy.Column(
|
||||
sqlalchemy.DateTime, nullable=False, server_default=sqlalchemy.func.now()
|
||||
)
|
||||
created_at = sqlalchemy.Column(sqlalchemy.DateTime, nullable=False, server_default=sqlalchemy.func.now())
|
||||
updated_at = sqlalchemy.Column(
|
||||
sqlalchemy.DateTime,
|
||||
nullable=False,
|
||||
|
||||
@@ -9,9 +9,7 @@ class User(Base):
|
||||
id = sqlalchemy.Column(sqlalchemy.Integer, primary_key=True)
|
||||
user = sqlalchemy.Column(sqlalchemy.String(255), nullable=False)
|
||||
password = sqlalchemy.Column(sqlalchemy.String(255), nullable=False)
|
||||
created_at = sqlalchemy.Column(
|
||||
sqlalchemy.DateTime, nullable=False, server_default=sqlalchemy.func.now()
|
||||
)
|
||||
created_at = sqlalchemy.Column(sqlalchemy.DateTime, nullable=False, server_default=sqlalchemy.func.now())
|
||||
updated_at = sqlalchemy.Column(
|
||||
sqlalchemy.DateTime,
|
||||
nullable=False,
|
||||
|
||||
@@ -11,6 +11,4 @@ class SQLiteDatabaseManager(database.BaseDatabaseManager):
|
||||
|
||||
async def initialize(self) -> None:
|
||||
sqlite_path = 'data/langbot.db'
|
||||
self.engine = sqlalchemy_asyncio.create_async_engine(
|
||||
f'sqlite+aiosqlite:///{sqlite_path}'
|
||||
)
|
||||
self.engine = sqlalchemy_asyncio.create_async_engine(f'sqlite+aiosqlite:///{sqlite_path}')
|
||||
|
||||
@@ -58,24 +58,18 @@ class PersistenceManager:
|
||||
for item in metadata.initial_metadata:
|
||||
# check if the item exists
|
||||
result = await self.execute_async(
|
||||
sqlalchemy.select(metadata.Metadata).where(
|
||||
metadata.Metadata.key == item['key']
|
||||
)
|
||||
sqlalchemy.select(metadata.Metadata).where(metadata.Metadata.key == item['key'])
|
||||
)
|
||||
row = result.first()
|
||||
if row is None:
|
||||
await self.execute_async(
|
||||
sqlalchemy.insert(metadata.Metadata).values(item)
|
||||
)
|
||||
await self.execute_async(sqlalchemy.insert(metadata.Metadata).values(item))
|
||||
|
||||
# write default pipeline
|
||||
result = await self.execute_async(sqlalchemy.select(pipeline.LegacyPipeline))
|
||||
if result.first() is None:
|
||||
self.ap.logger.info('Creating default pipeline...')
|
||||
|
||||
pipeline_config = json.load(
|
||||
open('templates/default-pipeline-config.json', 'r', encoding='utf-8')
|
||||
)
|
||||
pipeline_config = json.load(open('templates/default-pipeline-config.json', 'r', encoding='utf-8'))
|
||||
|
||||
pipeline_data = {
|
||||
'uuid': str(uuid.uuid4()),
|
||||
@@ -87,16 +81,12 @@ class PersistenceManager:
|
||||
'config': pipeline_config,
|
||||
}
|
||||
|
||||
await self.execute_async(
|
||||
sqlalchemy.insert(pipeline.LegacyPipeline).values(pipeline_data)
|
||||
)
|
||||
await self.execute_async(sqlalchemy.insert(pipeline.LegacyPipeline).values(pipeline_data))
|
||||
# =================================
|
||||
|
||||
# run migrations
|
||||
database_version = await self.execute_async(
|
||||
sqlalchemy.select(metadata.Metadata).where(
|
||||
metadata.Metadata.key == 'database_version'
|
||||
)
|
||||
sqlalchemy.select(metadata.Metadata).where(metadata.Metadata.key == 'database_version')
|
||||
)
|
||||
|
||||
database_version = int(database_version.fetchone()[1])
|
||||
@@ -122,17 +112,11 @@ class PersistenceManager:
|
||||
.values({'value': str(migration_instance.number)})
|
||||
)
|
||||
last_migration_number = migration_instance.number
|
||||
self.ap.logger.info(
|
||||
f'Migration {migration_instance.number} completed.'
|
||||
)
|
||||
self.ap.logger.info(f'Migration {migration_instance.number} completed.')
|
||||
|
||||
self.ap.logger.info(
|
||||
f'Successfully upgraded database to version {last_migration_number}.'
|
||||
)
|
||||
self.ap.logger.info(f'Successfully upgraded database to version {last_migration_number}.')
|
||||
|
||||
async def execute_async(
|
||||
self, *args, **kwargs
|
||||
) -> sqlalchemy.engine.cursor.CursorResult:
|
||||
async def execute_async(self, *args, **kwargs) -> sqlalchemy.engine.cursor.CursorResult:
|
||||
async with self.get_db_engine().connect() as conn:
|
||||
result = await conn.execute(*args, **kwargs)
|
||||
await conn.commit()
|
||||
@@ -141,9 +125,7 @@ class PersistenceManager:
|
||||
def get_db_engine(self) -> sqlalchemy_asyncio.AsyncEngine:
|
||||
return self.db.get_engine()
|
||||
|
||||
def serialize_model(
|
||||
self, model: typing.Type[sqlalchemy.Base], data: sqlalchemy.Base
|
||||
) -> dict:
|
||||
def serialize_model(self, model: typing.Type[sqlalchemy.Base], data: sqlalchemy.Base) -> dict:
|
||||
return {
|
||||
column.name: getattr(data, column.name)
|
||||
if not isinstance(getattr(data, column.name), (datetime.datetime))
|
||||
|
||||
@@ -14,9 +14,7 @@ class BanSessionCheckStage(stage.PipelineStage):
|
||||
async def initialize(self, pipeline_config: dict):
|
||||
pass
|
||||
|
||||
async def process(
|
||||
self, query: core_entities.Query, stage_inst_name: str
|
||||
) -> entities.StageProcessResult:
|
||||
async def process(self, query: core_entities.Query, stage_inst_name: str) -> entities.StageProcessResult:
|
||||
found = False
|
||||
|
||||
mode = query.pipeline_config['trigger']['access-control']['mode']
|
||||
@@ -41,11 +39,7 @@ class BanSessionCheckStage(stage.PipelineStage):
|
||||
ctn = not found
|
||||
|
||||
return entities.StageProcessResult(
|
||||
result_type=entities.ResultType.CONTINUE
|
||||
if ctn
|
||||
else entities.ResultType.INTERRUPT,
|
||||
result_type=entities.ResultType.CONTINUE if ctn else entities.ResultType.INTERRUPT,
|
||||
new_query=query,
|
||||
console_notice=f'根据访问控制忽略消息: {query.launcher_type.value}_{query.launcher_id}'
|
||||
if not ctn
|
||||
else '',
|
||||
console_notice=f'根据访问控制忽略消息: {query.launcher_type.value}_{query.launcher_id}' if not ctn else '',
|
||||
)
|
||||
|
||||
@@ -65,9 +65,7 @@ class ContentFilterStage(stage.PipelineStage):
|
||||
"""
|
||||
|
||||
if query.pipeline_config['safety']['content-filter']['scope'] == 'output-msg':
|
||||
return entities.StageProcessResult(
|
||||
result_type=entities.ResultType.CONTINUE, new_query=query
|
||||
)
|
||||
return entities.StageProcessResult(result_type=entities.ResultType.CONTINUE, new_query=query)
|
||||
else:
|
||||
for filter in self.filter_chain:
|
||||
if filter_entities.EnableStage.PRE in filter.enable_stages:
|
||||
@@ -86,13 +84,9 @@ class ContentFilterStage(stage.PipelineStage):
|
||||
elif result.level == filter_entities.ResultLevel.PASS: # 传到下一个
|
||||
message = result.replacement
|
||||
|
||||
query.message_chain = platform_message.MessageChain(
|
||||
platform_message.Plain(message)
|
||||
)
|
||||
query.message_chain = platform_message.MessageChain(platform_message.Plain(message))
|
||||
|
||||
return entities.StageProcessResult(
|
||||
result_type=entities.ResultType.CONTINUE, new_query=query
|
||||
)
|
||||
return entities.StageProcessResult(result_type=entities.ResultType.CONTINUE, new_query=query)
|
||||
|
||||
async def _post_process(
|
||||
self,
|
||||
@@ -103,9 +97,7 @@ class ContentFilterStage(stage.PipelineStage):
|
||||
只要是 PASS 或者 MASKED 的就通过此 filter,将其 replacement 设置为message,进入下一个 filter
|
||||
"""
|
||||
if query.pipeline_config['safety']['content-filter']['scope'] == 'income-msg':
|
||||
return entities.StageProcessResult(
|
||||
result_type=entities.ResultType.CONTINUE, new_query=query
|
||||
)
|
||||
return entities.StageProcessResult(result_type=entities.ResultType.CONTINUE, new_query=query)
|
||||
else:
|
||||
message = message.strip()
|
||||
for filter in self.filter_chain:
|
||||
@@ -127,13 +119,9 @@ class ContentFilterStage(stage.PipelineStage):
|
||||
|
||||
query.resp_messages[-1].content = message
|
||||
|
||||
return entities.StageProcessResult(
|
||||
result_type=entities.ResultType.CONTINUE, new_query=query
|
||||
)
|
||||
return entities.StageProcessResult(result_type=entities.ResultType.CONTINUE, new_query=query)
|
||||
|
||||
async def process(
|
||||
self, query: core_entities.Query, stage_inst_name: str
|
||||
) -> entities.StageProcessResult:
|
||||
async def process(self, query: core_entities.Query, stage_inst_name: str) -> entities.StageProcessResult:
|
||||
"""处理"""
|
||||
if stage_inst_name == 'PreContentFilterStage':
|
||||
contain_non_text = False
|
||||
@@ -147,9 +135,7 @@ class ContentFilterStage(stage.PipelineStage):
|
||||
|
||||
if contain_non_text:
|
||||
self.ap.logger.debug('消息中包含非文本消息,跳过内容过滤器检查。')
|
||||
return entities.StageProcessResult(
|
||||
result_type=entities.ResultType.CONTINUE, new_query=query
|
||||
)
|
||||
return entities.StageProcessResult(result_type=entities.ResultType.CONTINUE, new_query=query)
|
||||
|
||||
return await self._pre_process(str(query.message_chain).strip(), query)
|
||||
elif stage_inst_name == 'PostContentFilterStage':
|
||||
@@ -162,8 +148,6 @@ class ContentFilterStage(stage.PipelineStage):
|
||||
self.ap.logger.debug(
|
||||
'resp_messages[-1] 不是 Message 类型或 query.resp_messages[-1].content 不是 str 类型,跳过内容过滤器检查。'
|
||||
)
|
||||
return entities.StageProcessResult(
|
||||
result_type=entities.ResultType.CONTINUE, new_query=query
|
||||
)
|
||||
return entities.StageProcessResult(result_type=entities.ResultType.CONTINUE, new_query=query)
|
||||
else:
|
||||
raise ValueError(f'未知的 stage_inst_name: {stage_inst_name}')
|
||||
|
||||
@@ -60,9 +60,7 @@ class ContentFilter(metaclass=abc.ABCMeta):
|
||||
pass
|
||||
|
||||
@abc.abstractmethod
|
||||
async def process(
|
||||
self, query: core_entities.Query, message: str = None, image_url=None
|
||||
) -> entities.FilterResult:
|
||||
async def process(self, query: core_entities.Query, message: str = None, image_url=None) -> entities.FilterResult:
|
||||
"""处理消息
|
||||
|
||||
分为前后阶段,具体取决于 enable_stages 的值。
|
||||
|
||||
@@ -21,19 +21,13 @@ class BaiduCloudExamine(filter_model.ContentFilter):
|
||||
BAIDU_EXAMINE_TOKEN_URL,
|
||||
params={
|
||||
'grant_type': 'client_credentials',
|
||||
'client_id': self.ap.pipeline_cfg.data['baidu-cloud-examine'][
|
||||
'api-key'
|
||||
],
|
||||
'client_secret': self.ap.pipeline_cfg.data['baidu-cloud-examine'][
|
||||
'api-secret'
|
||||
],
|
||||
'client_id': self.ap.pipeline_cfg.data['baidu-cloud-examine']['api-key'],
|
||||
'client_secret': self.ap.pipeline_cfg.data['baidu-cloud-examine']['api-secret'],
|
||||
},
|
||||
) as resp:
|
||||
return (await resp.json())['access_token']
|
||||
|
||||
async def process(
|
||||
self, query: core_entities.Query, message: str
|
||||
) -> entities.FilterResult:
|
||||
async def process(self, query: core_entities.Query, message: str) -> entities.FilterResult:
|
||||
async with aiohttp.ClientSession() as session:
|
||||
async with session.post(
|
||||
BAIDU_EXAMINE_URL.format(await self._get_token()),
|
||||
|
||||
@@ -13,9 +13,7 @@ class BanWordFilter(filter_model.ContentFilter):
|
||||
async def initialize(self):
|
||||
pass
|
||||
|
||||
async def process(
|
||||
self, query: core_entities.Query, message: str
|
||||
) -> entities.FilterResult:
|
||||
async def process(self, query: core_entities.Query, message: str) -> entities.FilterResult:
|
||||
found = False
|
||||
|
||||
for word in self.ap.sensitive_meta.data['words']:
|
||||
@@ -31,9 +29,7 @@ class BanWordFilter(filter_model.ContentFilter):
|
||||
self.ap.sensitive_meta.data['mask'] * len(match[i]),
|
||||
)
|
||||
else:
|
||||
message = message.replace(
|
||||
match[i], self.ap.sensitive_meta.data['mask_word']
|
||||
)
|
||||
message = message.replace(match[i], self.ap.sensitive_meta.data['mask_word'])
|
||||
|
||||
return entities.FilterResult(
|
||||
level=entities.ResultLevel.MASKED if found else entities.ResultLevel.PASS,
|
||||
|
||||
@@ -16,9 +16,7 @@ class ContentIgnore(filter_model.ContentFilter):
|
||||
entities.EnableStage.PRE,
|
||||
]
|
||||
|
||||
async def process(
|
||||
self, query: core_entities.Query, message: str
|
||||
) -> entities.FilterResult:
|
||||
async def process(self, query: core_entities.Query, message: str) -> entities.FilterResult:
|
||||
if 'prefix' in query.pipeline_config['trigger']['ignore-rules']:
|
||||
for rule in query.pipeline_config['trigger']['ignore-rules']['prefix']:
|
||||
if message.startswith(rule):
|
||||
|
||||
@@ -16,9 +16,7 @@ class Controller:
|
||||
|
||||
def __init__(self, ap: app.Application):
|
||||
self.ap = ap
|
||||
self.semaphore = asyncio.Semaphore(
|
||||
self.ap.instance_config.data['concurrency']['pipeline']
|
||||
)
|
||||
self.semaphore = asyncio.Semaphore(self.ap.instance_config.data['concurrency']['pipeline'])
|
||||
|
||||
async def consumer(self):
|
||||
"""事件处理循环"""
|
||||
@@ -32,9 +30,7 @@ class Controller:
|
||||
|
||||
for query in queries:
|
||||
session = await self.ap.sess_mgr.get_session(query)
|
||||
self.ap.logger.debug(
|
||||
f'Checking query {query} session {session}'
|
||||
)
|
||||
self.ap.logger.debug(f'Checking query {query} session {session}')
|
||||
|
||||
if not session.semaphore.locked():
|
||||
selected_query = query
|
||||
@@ -55,22 +51,16 @@ class Controller:
|
||||
# find pipeline
|
||||
# Here firstly find the bot, then find the pipeline, in case the bot adapter's config is not the latest one.
|
||||
# Like aiocqhttp, once a client is connected, even the adapter was updated and restarted, the existing client connection will not be affected.
|
||||
bot = await self.ap.platform_mgr.get_bot_by_uuid(
|
||||
selected_query.bot_uuid
|
||||
)
|
||||
bot = await self.ap.platform_mgr.get_bot_by_uuid(selected_query.bot_uuid)
|
||||
if bot:
|
||||
pipeline = (
|
||||
await self.ap.pipeline_mgr.get_pipeline_by_uuid(
|
||||
bot.bot_entity.use_pipeline_uuid
|
||||
)
|
||||
pipeline = await self.ap.pipeline_mgr.get_pipeline_by_uuid(
|
||||
bot.bot_entity.use_pipeline_uuid
|
||||
)
|
||||
if pipeline:
|
||||
await pipeline.run(selected_query)
|
||||
|
||||
async with self.ap.query_pool:
|
||||
(
|
||||
await self.ap.sess_mgr.get_session(selected_query)
|
||||
).semaphore.release()
|
||||
(await self.ap.sess_mgr.get_session(selected_query)).semaphore.release()
|
||||
# 通知其他协程,有新的请求可以处理了
|
||||
self.ap.query_pool.condition.notify_all()
|
||||
|
||||
|
||||
@@ -47,9 +47,7 @@ class LongTextProcessStage(stage.PipelineStage):
|
||||
'未找到字体文件,且无法使用系统自带字体,更换为转发消息组件以发送长消息,您可以在配置文件中调整相关设置。'
|
||||
)
|
||||
|
||||
pipeline_config['output']['long-text-processing'][
|
||||
'strategy'
|
||||
] = 'forward'
|
||||
pipeline_config['output']['long-text-processing']['strategy'] = 'forward'
|
||||
except Exception:
|
||||
traceback.print_exc()
|
||||
self.ap.logger.error(
|
||||
@@ -58,9 +56,7 @@ class LongTextProcessStage(stage.PipelineStage):
|
||||
)
|
||||
)
|
||||
|
||||
pipeline_config['output']['long-text-processing']['strategy'] = (
|
||||
'forward'
|
||||
)
|
||||
pipeline_config['output']['long-text-processing']['strategy'] = 'forward'
|
||||
|
||||
for strategy_cls in strategy.preregistered_strategies:
|
||||
if strategy_cls.name == config['strategy']:
|
||||
@@ -71,9 +67,7 @@ class LongTextProcessStage(stage.PipelineStage):
|
||||
|
||||
await self.strategy_impl.initialize()
|
||||
|
||||
async def process(
|
||||
self, query: core_entities.Query, stage_inst_name: str
|
||||
) -> entities.StageProcessResult:
|
||||
async def process(self, query: core_entities.Query, stage_inst_name: str) -> entities.StageProcessResult:
|
||||
# 检查是否包含非 Plain 组件
|
||||
contains_non_plain = False
|
||||
|
||||
@@ -89,11 +83,7 @@ class LongTextProcessStage(stage.PipelineStage):
|
||||
> query.pipeline_config['output']['long-text-processing']['threshold']
|
||||
):
|
||||
query.resp_message_chain[-1] = platform_message.MessageChain(
|
||||
await self.strategy_impl.process(
|
||||
str(query.resp_message_chain[-1]), query
|
||||
)
|
||||
await self.strategy_impl.process(str(query.resp_message_chain[-1]), query)
|
||||
)
|
||||
|
||||
return entities.StageProcessResult(
|
||||
result_type=entities.ResultType.CONTINUE, new_query=query
|
||||
)
|
||||
return entities.StageProcessResult(result_type=entities.ResultType.CONTINUE, new_query=query)
|
||||
|
||||
@@ -13,9 +13,7 @@ Forward = platform_message.Forward
|
||||
|
||||
@strategy_model.strategy_class('forward')
|
||||
class ForwardComponentStrategy(strategy_model.LongTextStrategy):
|
||||
async def process(
|
||||
self, message: str, query: core_entities.Query
|
||||
) -> list[platform_message.MessageComponent]:
|
||||
async def process(self, message: str, query: core_entities.Query) -> list[platform_message.MessageComponent]:
|
||||
display = ForwardMessageDiaplay(
|
||||
title='群聊的聊天记录',
|
||||
brief='[聊天记录]',
|
||||
|
||||
@@ -27,18 +27,14 @@ class Text2ImageStrategy(strategy_model.LongTextStrategy):
|
||||
encoding='utf-8',
|
||||
)
|
||||
|
||||
async def process(
|
||||
self, message: str, query: core_entities.Query
|
||||
) -> list[platform_message.MessageComponent]:
|
||||
async def process(self, message: str, query: core_entities.Query) -> list[platform_message.MessageComponent]:
|
||||
img_path = self.text_to_image(
|
||||
text_str=message,
|
||||
save_as='temp/{}.png'.format(int(time.time())),
|
||||
query=query,
|
||||
)
|
||||
|
||||
compressed_path, size = self.compress_image(
|
||||
img_path, outfile='temp/{}_compressed.png'.format(int(time.time()))
|
||||
)
|
||||
compressed_path, size = self.compress_image(img_path, outfile='temp/{}_compressed.png'.format(int(time.time())))
|
||||
|
||||
with open(compressed_path, 'rb') as f:
|
||||
img = f.read()
|
||||
@@ -165,10 +161,7 @@ class Text2ImageStrategy(strategy_model.LongTextStrategy):
|
||||
numbers = self.indexNumber(rest_text)
|
||||
|
||||
for number in numbers:
|
||||
if (
|
||||
number[1] < point < number[1] + len(number[0])
|
||||
and number[1] != 0
|
||||
):
|
||||
if number[1] < point < number[1] + len(number[0]) and number[1] != 0:
|
||||
point = number[1]
|
||||
break
|
||||
|
||||
@@ -181,9 +174,7 @@ class Text2ImageStrategy(strategy_model.LongTextStrategy):
|
||||
else:
|
||||
continue
|
||||
# 准备画布
|
||||
img = Image.new(
|
||||
'RGBA', (width, max(280, len(final_lines) * 35 + 65)), (255, 255, 255, 255)
|
||||
)
|
||||
img = Image.new('RGBA', (width, max(280, len(final_lines) * 35 + 65)), (255, 255, 255, 255))
|
||||
draw = ImageDraw.Draw(img, mode='RGBA')
|
||||
|
||||
self.ap.logger.debug('正在绘制图片...')
|
||||
|
||||
@@ -49,9 +49,7 @@ class LongTextStrategy(metaclass=abc.ABCMeta):
|
||||
pass
|
||||
|
||||
@abc.abstractmethod
|
||||
async def process(
|
||||
self, message: str, query: core_entities.Query
|
||||
) -> list[platform_message.MessageComponent]:
|
||||
async def process(self, message: str, query: core_entities.Query) -> list[platform_message.MessageComponent]:
|
||||
"""处理长文本
|
||||
|
||||
在 platform.json 中配置 long-text-process 字段,只要 文本长度超过了 threshold 就会调用此方法
|
||||
|
||||
@@ -29,12 +29,8 @@ class ConversationMessageTruncator(stage.PipelineStage):
|
||||
else:
|
||||
raise ValueError(f'未知的截断器: {use_method}')
|
||||
|
||||
async def process(
|
||||
self, query: core_entities.Query, stage_inst_name: str
|
||||
) -> entities.StageProcessResult:
|
||||
async def process(self, query: core_entities.Query, stage_inst_name: str) -> entities.StageProcessResult:
|
||||
"""处理"""
|
||||
query = await self.trun.truncate(query)
|
||||
|
||||
return entities.StageProcessResult(
|
||||
result_type=entities.ResultType.CONTINUE, new_query=query
|
||||
)
|
||||
return entities.StageProcessResult(result_type=entities.ResultType.CONTINUE, new_query=query)
|
||||
|
||||
@@ -79,26 +79,20 @@ class RuntimePipeline:
|
||||
query.pipeline_config = self.pipeline_entity.config
|
||||
await self.process_query(query)
|
||||
|
||||
async def _check_output(
|
||||
self, query: entities.Query, result: pipeline_entities.StageProcessResult
|
||||
):
|
||||
async def _check_output(self, query: entities.Query, result: pipeline_entities.StageProcessResult):
|
||||
"""检查输出"""
|
||||
if result.user_notice:
|
||||
# 处理str类型
|
||||
|
||||
if isinstance(result.user_notice, str):
|
||||
result.user_notice = platform_message.MessageChain(
|
||||
platform_message.Plain(result.user_notice)
|
||||
)
|
||||
result.user_notice = platform_message.MessageChain(platform_message.Plain(result.user_notice))
|
||||
elif isinstance(result.user_notice, list):
|
||||
result.user_notice = platform_message.MessageChain(*result.user_notice)
|
||||
|
||||
if query.pipeline_config['output']['misc']['at-sender'] and isinstance(
|
||||
query.message_event, platform_events.GroupMessage
|
||||
):
|
||||
result.user_notice.insert(
|
||||
0, platform_message.At(query.message_event.sender.id)
|
||||
)
|
||||
result.user_notice.insert(0, platform_message.At(query.message_event.sender.id))
|
||||
|
||||
await query.adapter.reply_message(
|
||||
message_source=query.message_event,
|
||||
@@ -150,37 +144,25 @@ class RuntimePipeline:
|
||||
result = await result
|
||||
|
||||
if isinstance(result, pipeline_entities.StageProcessResult): # 直接返回结果
|
||||
self.ap.logger.debug(
|
||||
f'Stage {stage_container.inst_name} processed query {query} res {result}'
|
||||
)
|
||||
self.ap.logger.debug(f'Stage {stage_container.inst_name} processed query {query} res {result}')
|
||||
await self._check_output(query, result)
|
||||
|
||||
if result.result_type == pipeline_entities.ResultType.INTERRUPT:
|
||||
self.ap.logger.debug(
|
||||
f'Stage {stage_container.inst_name} interrupted query {query}'
|
||||
)
|
||||
self.ap.logger.debug(f'Stage {stage_container.inst_name} interrupted query {query}')
|
||||
break
|
||||
elif result.result_type == pipeline_entities.ResultType.CONTINUE:
|
||||
query = result.new_query
|
||||
elif isinstance(result, typing.AsyncGenerator): # 生成器
|
||||
self.ap.logger.debug(
|
||||
f'Stage {stage_container.inst_name} processed query {query} gen'
|
||||
)
|
||||
self.ap.logger.debug(f'Stage {stage_container.inst_name} processed query {query} gen')
|
||||
|
||||
async for sub_result in result:
|
||||
self.ap.logger.debug(
|
||||
f'Stage {stage_container.inst_name} processed query {query} res {sub_result}'
|
||||
)
|
||||
self.ap.logger.debug(f'Stage {stage_container.inst_name} processed query {query} res {sub_result}')
|
||||
await self._check_output(query, sub_result)
|
||||
|
||||
if sub_result.result_type == pipeline_entities.ResultType.INTERRUPT:
|
||||
self.ap.logger.debug(
|
||||
f'Stage {stage_container.inst_name} interrupted query {query}'
|
||||
)
|
||||
self.ap.logger.debug(f'Stage {stage_container.inst_name} interrupted query {query}')
|
||||
break
|
||||
elif (
|
||||
sub_result.result_type == pipeline_entities.ResultType.CONTINUE
|
||||
):
|
||||
elif sub_result.result_type == pipeline_entities.ResultType.CONTINUE:
|
||||
query = sub_result.new_query
|
||||
await self._execute_from_stage(i + 1, query)
|
||||
break
|
||||
@@ -214,12 +196,8 @@ class RuntimePipeline:
|
||||
|
||||
await self._execute_from_stage(0, query)
|
||||
except Exception as e:
|
||||
inst_name = (
|
||||
query.current_stage.inst_name if query.current_stage else 'unknown'
|
||||
)
|
||||
self.ap.logger.error(
|
||||
f'处理请求时出错 query_id={query.query_id} stage={inst_name} : {e}'
|
||||
)
|
||||
inst_name = query.current_stage.inst_name if query.current_stage else 'unknown'
|
||||
self.ap.logger.error(f'处理请求时出错 query_id={query.query_id} stage={inst_name} : {e}')
|
||||
self.ap.logger.debug(f'Traceback: {traceback.format_exc()}')
|
||||
finally:
|
||||
self.ap.logger.debug(f'Query {query} processed')
|
||||
@@ -241,18 +219,14 @@ class PipelineManager:
|
||||
self.pipelines = []
|
||||
|
||||
async def initialize(self):
|
||||
self.stage_dict = {
|
||||
name: cls for name, cls in stage.preregistered_stages.items()
|
||||
}
|
||||
self.stage_dict = {name: cls for name, cls in stage.preregistered_stages.items()}
|
||||
|
||||
await self.load_pipelines_from_db()
|
||||
|
||||
async def load_pipelines_from_db(self):
|
||||
self.ap.logger.info('Loading pipelines from db...')
|
||||
|
||||
result = await self.ap.persistence_mgr.execute_async(
|
||||
sqlalchemy.select(persistence_pipeline.LegacyPipeline)
|
||||
)
|
||||
result = await self.ap.persistence_mgr.execute_async(sqlalchemy.select(persistence_pipeline.LegacyPipeline))
|
||||
|
||||
pipelines = result.all()
|
||||
|
||||
@@ -267,20 +241,14 @@ class PipelineManager:
|
||||
| dict,
|
||||
):
|
||||
if isinstance(pipeline_entity, sqlalchemy.Row):
|
||||
pipeline_entity = persistence_pipeline.LegacyPipeline(
|
||||
**pipeline_entity._mapping
|
||||
)
|
||||
pipeline_entity = persistence_pipeline.LegacyPipeline(**pipeline_entity._mapping)
|
||||
elif isinstance(pipeline_entity, dict):
|
||||
pipeline_entity = persistence_pipeline.LegacyPipeline(**pipeline_entity)
|
||||
|
||||
# initialize stage containers according to pipeline_entity.stages
|
||||
stage_containers: list[StageInstContainer] = []
|
||||
for stage_name in pipeline_entity.stages:
|
||||
stage_containers.append(
|
||||
StageInstContainer(
|
||||
inst_name=stage_name, inst=self.stage_dict[stage_name](self.ap)
|
||||
)
|
||||
)
|
||||
stage_containers.append(StageInstContainer(inst_name=stage_name, inst=self.stage_dict[stage_name](self.ap)))
|
||||
|
||||
for stage_container in stage_containers:
|
||||
await stage_container.inst.initialize(pipeline_entity.config)
|
||||
|
||||
@@ -44,9 +44,7 @@ class PreProcessor(stage.PipelineStage):
|
||||
query.use_llm_model = conversation.use_llm_model
|
||||
|
||||
query.use_funcs = (
|
||||
conversation.use_funcs
|
||||
if query.use_llm_model.model_entity.abilities.__contains__('tool_call')
|
||||
else None
|
||||
conversation.use_funcs if query.use_llm_model.model_entity.abilities.__contains__('tool_call') else None
|
||||
)
|
||||
|
||||
query.variables = {
|
||||
@@ -59,10 +57,9 @@ class PreProcessor(stage.PipelineStage):
|
||||
|
||||
# Check if this model supports vision, if not, remove all images
|
||||
# TODO this checking should be performed in runner, and in this stage, the image should be reserved
|
||||
if (
|
||||
query.pipeline_config['ai']['runner']['runner'] == 'local-agent'
|
||||
and not query.use_llm_model.model_entity.abilities.__contains__('vision')
|
||||
):
|
||||
if query.pipeline_config['ai']['runner'][
|
||||
'runner'
|
||||
] == 'local-agent' and not query.use_llm_model.model_entity.abilities.__contains__('vision'):
|
||||
for msg in query.messages:
|
||||
if isinstance(msg.content, list):
|
||||
for me in msg.content:
|
||||
@@ -78,14 +75,11 @@ class PreProcessor(stage.PipelineStage):
|
||||
content_list.append(llm_entities.ContentElement.from_text(me.text))
|
||||
plain_text += me.text
|
||||
elif isinstance(me, platform_message.Image):
|
||||
if (
|
||||
query.pipeline_config['ai']['runner']['runner'] != 'local-agent'
|
||||
or query.use_llm_model.model_entity.abilities.__contains__('vision')
|
||||
):
|
||||
if query.pipeline_config['ai']['runner'][
|
||||
'runner'
|
||||
] != 'local-agent' or query.use_llm_model.model_entity.abilities.__contains__('vision'):
|
||||
if me.base64 is not None:
|
||||
content_list.append(
|
||||
llm_entities.ContentElement.from_image_base64(me.base64)
|
||||
)
|
||||
content_list.append(llm_entities.ContentElement.from_image_base64(me.base64))
|
||||
|
||||
query.variables['user_message_text'] = plain_text
|
||||
|
||||
@@ -104,6 +98,4 @@ class PreProcessor(stage.PipelineStage):
|
||||
query.prompt.messages = event_ctx.event.default_prompt
|
||||
query.messages = event_ctx.event.prompt
|
||||
|
||||
return entities.StageProcessResult(
|
||||
result_type=entities.ResultType.CONTINUE, new_query=query
|
||||
)
|
||||
return entities.StageProcessResult(result_type=entities.ResultType.CONTINUE, new_query=query)
|
||||
|
||||
@@ -49,13 +49,9 @@ class ChatMessageHandler(handler.MessageHandler):
|
||||
|
||||
query.resp_messages.append(mc)
|
||||
|
||||
yield entities.StageProcessResult(
|
||||
result_type=entities.ResultType.CONTINUE, new_query=query
|
||||
)
|
||||
yield entities.StageProcessResult(result_type=entities.ResultType.CONTINUE, new_query=query)
|
||||
else:
|
||||
yield entities.StageProcessResult(
|
||||
result_type=entities.ResultType.INTERRUPT, new_query=query
|
||||
)
|
||||
yield entities.StageProcessResult(result_type=entities.ResultType.INTERRUPT, new_query=query)
|
||||
else:
|
||||
if event_ctx.event.alter is not None:
|
||||
# if isinstance(event_ctx.event, str): # 现在暂时不考虑多模态alter
|
||||
@@ -69,34 +65,24 @@ class ChatMessageHandler(handler.MessageHandler):
|
||||
runner = r(self.ap, query.pipeline_config)
|
||||
break
|
||||
else:
|
||||
raise ValueError(
|
||||
f'未找到请求运行器: {query.pipeline_config["ai"]["runner"]["runner"]}'
|
||||
)
|
||||
raise ValueError(f'未找到请求运行器: {query.pipeline_config["ai"]["runner"]["runner"]}')
|
||||
|
||||
async for result in runner.run(query):
|
||||
query.resp_messages.append(result)
|
||||
|
||||
self.ap.logger.info(
|
||||
f'对话({query.query_id})响应: {self.cut_str(result.readable_str())}'
|
||||
)
|
||||
self.ap.logger.info(f'对话({query.query_id})响应: {self.cut_str(result.readable_str())}')
|
||||
|
||||
if result.content is not None:
|
||||
text_length += len(result.content)
|
||||
|
||||
yield entities.StageProcessResult(
|
||||
result_type=entities.ResultType.CONTINUE, new_query=query
|
||||
)
|
||||
yield entities.StageProcessResult(result_type=entities.ResultType.CONTINUE, new_query=query)
|
||||
|
||||
query.session.using_conversation.messages.append(query.user_message)
|
||||
query.session.using_conversation.messages.extend(query.resp_messages)
|
||||
except Exception as e:
|
||||
self.ap.logger.error(
|
||||
f'对话({query.query_id})请求失败: {type(e).__name__} {str(e)}'
|
||||
)
|
||||
self.ap.logger.error(f'对话({query.query_id})请求失败: {type(e).__name__} {str(e)}')
|
||||
|
||||
hide_exception_info = query.pipeline_config['output']['misc'][
|
||||
'hide-exception'
|
||||
]
|
||||
hide_exception_info = query.pipeline_config['output']['misc']['hide-exception']
|
||||
|
||||
yield entities.StageProcessResult(
|
||||
result_type=entities.ResultType.INTERRUPT,
|
||||
|
||||
@@ -21,10 +21,7 @@ class CommandHandler(handler.MessageHandler):
|
||||
|
||||
privilege = 1
|
||||
|
||||
if (
|
||||
f'{query.launcher_type.value}_{query.launcher_id}'
|
||||
in self.ap.instance_config.data['admins']
|
||||
):
|
||||
if f'{query.launcher_type.value}_{query.launcher_id}' in self.ap.instance_config.data['admins']:
|
||||
privilege = 2
|
||||
|
||||
spt = command_text.split(' ')
|
||||
@@ -54,25 +51,17 @@ class CommandHandler(handler.MessageHandler):
|
||||
|
||||
query.resp_messages.append(mc)
|
||||
|
||||
yield entities.StageProcessResult(
|
||||
result_type=entities.ResultType.CONTINUE, new_query=query
|
||||
)
|
||||
yield entities.StageProcessResult(result_type=entities.ResultType.CONTINUE, new_query=query)
|
||||
else:
|
||||
yield entities.StageProcessResult(
|
||||
result_type=entities.ResultType.INTERRUPT, new_query=query
|
||||
)
|
||||
yield entities.StageProcessResult(result_type=entities.ResultType.INTERRUPT, new_query=query)
|
||||
|
||||
else:
|
||||
if event_ctx.event.alter is not None:
|
||||
query.message_chain = platform_message.MessageChain(
|
||||
[platform_message.Plain(event_ctx.event.alter)]
|
||||
)
|
||||
query.message_chain = platform_message.MessageChain([platform_message.Plain(event_ctx.event.alter)])
|
||||
|
||||
session = await self.ap.sess_mgr.get_session(query)
|
||||
|
||||
async for ret in self.ap.cmd_mgr.execute(
|
||||
command_text=command_text, query=query, session=session
|
||||
):
|
||||
async for ret in self.ap.cmd_mgr.execute(command_text=command_text, query=query, session=session):
|
||||
if ret.error is not None:
|
||||
query.resp_messages.append(
|
||||
llm_entities.Message(
|
||||
@@ -81,13 +70,9 @@ class CommandHandler(handler.MessageHandler):
|
||||
)
|
||||
)
|
||||
|
||||
self.ap.logger.info(
|
||||
f'命令({query.query_id})报错: {self.cut_str(str(ret.error))}'
|
||||
)
|
||||
self.ap.logger.info(f'命令({query.query_id})报错: {self.cut_str(str(ret.error))}')
|
||||
|
||||
yield entities.StageProcessResult(
|
||||
result_type=entities.ResultType.CONTINUE, new_query=query
|
||||
)
|
||||
yield entities.StageProcessResult(result_type=entities.ResultType.CONTINUE, new_query=query)
|
||||
elif ret.text is not None or ret.image_url is not None:
|
||||
content: list[llm_entities.ContentElement] = []
|
||||
|
||||
@@ -95,9 +80,7 @@ class CommandHandler(handler.MessageHandler):
|
||||
content.append(llm_entities.ContentElement.from_text(ret.text))
|
||||
|
||||
if ret.image_url is not None:
|
||||
content.append(
|
||||
llm_entities.ContentElement.from_image_url(ret.image_url)
|
||||
)
|
||||
content.append(llm_entities.ContentElement.from_image_url(ret.image_url))
|
||||
|
||||
query.resp_messages.append(
|
||||
llm_entities.Message(
|
||||
@@ -108,10 +91,6 @@ class CommandHandler(handler.MessageHandler):
|
||||
|
||||
self.ap.logger.info(f'命令返回: {self.cut_str(str(content[0]))}')
|
||||
|
||||
yield entities.StageProcessResult(
|
||||
result_type=entities.ResultType.CONTINUE, new_query=query
|
||||
)
|
||||
yield entities.StageProcessResult(result_type=entities.ResultType.CONTINUE, new_query=query)
|
||||
else:
|
||||
yield entities.StageProcessResult(
|
||||
result_type=entities.ResultType.INTERRUPT, new_query=query
|
||||
)
|
||||
yield entities.StageProcessResult(result_type=entities.ResultType.INTERRUPT, new_query=query)
|
||||
|
||||
@@ -72,9 +72,7 @@ class FixedWindowAlgo(algo.ReteLimitAlgo):
|
||||
if count >= limitation:
|
||||
if query.pipeline_config['safety']['rate-limit']['strategy'] == 'drop':
|
||||
return False
|
||||
elif (
|
||||
query.pipeline_config['safety']['rate-limit']['strategy'] == 'wait'
|
||||
):
|
||||
elif query.pipeline_config['safety']['rate-limit']['strategy'] == 'wait':
|
||||
# 等待下一窗口
|
||||
await asyncio.sleep(window_size - time.time() % window_size)
|
||||
|
||||
|
||||
@@ -15,9 +15,7 @@ from ...core import entities as core_entities
|
||||
class SendResponseBackStage(stage.PipelineStage):
|
||||
"""发送响应消息"""
|
||||
|
||||
async def process(
|
||||
self, query: core_entities.Query, stage_inst_name: str
|
||||
) -> entities.StageProcessResult:
|
||||
async def process(self, query: core_entities.Query, stage_inst_name: str) -> entities.StageProcessResult:
|
||||
"""处理"""
|
||||
|
||||
random_range = (
|
||||
@@ -34,9 +32,7 @@ class SendResponseBackStage(stage.PipelineStage):
|
||||
if query.pipeline_config['output']['misc']['at-sender'] and isinstance(
|
||||
query.message_event, platform_events.GroupMessage
|
||||
):
|
||||
query.resp_message_chain[-1].insert(
|
||||
0, platform_message.At(query.message_event.sender.id)
|
||||
)
|
||||
query.resp_message_chain[-1].insert(0, platform_message.At(query.message_event.sender.id))
|
||||
|
||||
quote_origin = query.pipeline_config['output']['misc']['quote-origin']
|
||||
|
||||
@@ -46,6 +42,4 @@ class SendResponseBackStage(stage.PipelineStage):
|
||||
quote_origin=quote_origin,
|
||||
)
|
||||
|
||||
return entities.StageProcessResult(
|
||||
result_type=entities.ResultType.CONTINUE, new_query=query
|
||||
)
|
||||
return entities.StageProcessResult(result_type=entities.ResultType.CONTINUE, new_query=query)
|
||||
|
||||
@@ -32,13 +32,9 @@ class GroupRespondRuleCheckStage(stage.PipelineStage):
|
||||
await rule_inst.initialize()
|
||||
self.rule_matchers.append(rule_inst)
|
||||
|
||||
async def process(
|
||||
self, query: core_entities.Query, stage_inst_name: str
|
||||
) -> entities.StageProcessResult:
|
||||
async def process(self, query: core_entities.Query, stage_inst_name: str) -> entities.StageProcessResult:
|
||||
if query.launcher_type.value != 'group': # 只处理群消息
|
||||
return entities.StageProcessResult(
|
||||
result_type=entities.ResultType.CONTINUE, new_query=query
|
||||
)
|
||||
return entities.StageProcessResult(result_type=entities.ResultType.CONTINUE, new_query=query)
|
||||
|
||||
rules = query.pipeline_config['trigger']['group-respond-rules']
|
||||
|
||||
@@ -49,9 +45,7 @@ class GroupRespondRuleCheckStage(stage.PipelineStage):
|
||||
# use_rule = rules[str(query.launcher_id)]
|
||||
|
||||
for rule_matcher in self.rule_matchers: # 任意一个匹配就放行
|
||||
res = await rule_matcher.match(
|
||||
str(query.message_chain), query.message_chain, use_rule, query
|
||||
)
|
||||
res = await rule_matcher.match(str(query.message_chain), query.message_chain, use_rule, query)
|
||||
if res.matching:
|
||||
query.message_chain = res.replacement
|
||||
|
||||
@@ -60,6 +54,4 @@ class GroupRespondRuleCheckStage(stage.PipelineStage):
|
||||
new_query=query,
|
||||
)
|
||||
|
||||
return entities.StageProcessResult(
|
||||
result_type=entities.ResultType.INTERRUPT, new_query=query
|
||||
)
|
||||
return entities.StageProcessResult(result_type=entities.ResultType.INTERRUPT, new_query=query)
|
||||
|
||||
@@ -16,10 +16,7 @@ class AtBotRule(rule_model.GroupRespondRule):
|
||||
rule_dict: dict,
|
||||
query: core_entities.Query,
|
||||
) -> entities.RuleJudgeResult:
|
||||
if (
|
||||
message_chain.has(platform_message.At(query.adapter.bot_account_id))
|
||||
and rule_dict['at']
|
||||
):
|
||||
if message_chain.has(platform_message.At(query.adapter.bot_account_id)) and rule_dict['at']:
|
||||
message_chain.remove(platform_message.At(query.adapter.bot_account_id))
|
||||
|
||||
if message_chain.has(
|
||||
|
||||
@@ -18,6 +18,4 @@ class RandomRespRule(rule_model.GroupRespondRule):
|
||||
) -> entities.RuleJudgeResult:
|
||||
random_rate = rule_dict['random']
|
||||
|
||||
return entities.RuleJudgeResult(
|
||||
matching=random.random() < random_rate, replacement=message_chain
|
||||
)
|
||||
return entities.RuleJudgeResult(matching=random.random() < random_rate, replacement=message_chain)
|
||||
|
||||
@@ -34,29 +34,19 @@ class ResponseWrapper(stage.PipelineStage):
|
||||
if isinstance(query.resp_messages[-1], platform_message.MessageChain):
|
||||
query.resp_message_chain.append(query.resp_messages[-1])
|
||||
|
||||
yield entities.StageProcessResult(
|
||||
result_type=entities.ResultType.CONTINUE, new_query=query
|
||||
)
|
||||
yield entities.StageProcessResult(result_type=entities.ResultType.CONTINUE, new_query=query)
|
||||
|
||||
else:
|
||||
if query.resp_messages[-1].role == 'command':
|
||||
query.resp_message_chain.append(
|
||||
query.resp_messages[-1].get_content_platform_message_chain(
|
||||
prefix_text='[bot] '
|
||||
)
|
||||
query.resp_messages[-1].get_content_platform_message_chain(prefix_text='[bot] ')
|
||||
)
|
||||
|
||||
yield entities.StageProcessResult(
|
||||
result_type=entities.ResultType.CONTINUE, new_query=query
|
||||
)
|
||||
yield entities.StageProcessResult(result_type=entities.ResultType.CONTINUE, new_query=query)
|
||||
elif query.resp_messages[-1].role == 'plugin':
|
||||
query.resp_message_chain.append(
|
||||
query.resp_messages[-1].get_content_platform_message_chain()
|
||||
)
|
||||
query.resp_message_chain.append(query.resp_messages[-1].get_content_platform_message_chain())
|
||||
|
||||
yield entities.StageProcessResult(
|
||||
result_type=entities.ResultType.CONTINUE, new_query=query
|
||||
)
|
||||
yield entities.StageProcessResult(result_type=entities.ResultType.CONTINUE, new_query=query)
|
||||
else:
|
||||
if query.resp_messages[-1].role == 'assistant':
|
||||
result = query.resp_messages[-1]
|
||||
@@ -77,9 +67,7 @@ class ResponseWrapper(stage.PipelineStage):
|
||||
prefix='',
|
||||
response_text=reply_text,
|
||||
finish_reason='stop',
|
||||
funcs_called=[
|
||||
fc.function.name for fc in result.tool_calls
|
||||
]
|
||||
funcs_called=[fc.function.name for fc in result.tool_calls]
|
||||
if result.tool_calls is not None
|
||||
else [],
|
||||
query=query,
|
||||
@@ -92,36 +80,26 @@ class ResponseWrapper(stage.PipelineStage):
|
||||
)
|
||||
else:
|
||||
if event_ctx.event.reply is not None:
|
||||
query.resp_message_chain.append(
|
||||
platform_message.MessageChain(event_ctx.event.reply)
|
||||
)
|
||||
query.resp_message_chain.append(platform_message.MessageChain(event_ctx.event.reply))
|
||||
|
||||
else:
|
||||
query.resp_message_chain.append(
|
||||
result.get_content_platform_message_chain()
|
||||
)
|
||||
query.resp_message_chain.append(result.get_content_platform_message_chain())
|
||||
|
||||
yield entities.StageProcessResult(
|
||||
result_type=entities.ResultType.CONTINUE,
|
||||
new_query=query,
|
||||
)
|
||||
|
||||
if (
|
||||
result.tool_calls is not None and len(result.tool_calls) > 0
|
||||
): # 有函数调用
|
||||
if result.tool_calls is not None and len(result.tool_calls) > 0: # 有函数调用
|
||||
function_names = [tc.function.name for tc in result.tool_calls]
|
||||
|
||||
reply_text = f'调用函数 {".".join(function_names)}...'
|
||||
|
||||
query.resp_message_chain.append(
|
||||
platform_message.MessageChain(
|
||||
[platform_message.Plain(reply_text)]
|
||||
)
|
||||
platform_message.MessageChain([platform_message.Plain(reply_text)])
|
||||
)
|
||||
|
||||
if query.pipeline_config['output']['misc'][
|
||||
'track-function-calls'
|
||||
]:
|
||||
if query.pipeline_config['output']['misc']['track-function-calls']:
|
||||
event_ctx = await self.ap.plugin_mgr.emit_event(
|
||||
event=events.NormalMessageResponded(
|
||||
launcher_type=query.launcher_type.value,
|
||||
@@ -131,9 +109,7 @@ class ResponseWrapper(stage.PipelineStage):
|
||||
prefix='',
|
||||
response_text=reply_text,
|
||||
finish_reason='stop',
|
||||
funcs_called=[
|
||||
fc.function.name for fc in result.tool_calls
|
||||
]
|
||||
funcs_called=[fc.function.name for fc in result.tool_calls]
|
||||
if result.tool_calls is not None
|
||||
else [],
|
||||
query=query,
|
||||
@@ -148,16 +124,12 @@ class ResponseWrapper(stage.PipelineStage):
|
||||
else:
|
||||
if event_ctx.event.reply is not None:
|
||||
query.resp_message_chain.append(
|
||||
platform_message.MessageChain(
|
||||
event_ctx.event.reply
|
||||
)
|
||||
platform_message.MessageChain(event_ctx.event.reply)
|
||||
)
|
||||
|
||||
else:
|
||||
query.resp_message_chain.append(
|
||||
platform_message.MessageChain(
|
||||
[platform_message.Plain(reply_text)]
|
||||
)
|
||||
platform_message.MessageChain([platform_message.Plain(reply_text)])
|
||||
)
|
||||
|
||||
yield entities.StageProcessResult(
|
||||
|
||||
@@ -32,9 +32,7 @@ class MessagePlatformAdapter(metaclass=abc.ABCMeta):
|
||||
self.config = config
|
||||
self.ap = ap
|
||||
|
||||
async def send_message(
|
||||
self, target_type: str, target_id: str, message: platform_message.MessageChain
|
||||
):
|
||||
async def send_message(self, target_type: str, target_id: str, message: platform_message.MessageChain):
|
||||
"""主动发送消息
|
||||
|
||||
Args:
|
||||
@@ -66,9 +64,7 @@ class MessagePlatformAdapter(metaclass=abc.ABCMeta):
|
||||
def register_listener(
|
||||
self,
|
||||
event_type: typing.Type[platform_message.Event],
|
||||
callback: typing.Callable[
|
||||
[platform_message.Event, MessagePlatformAdapter], None
|
||||
],
|
||||
callback: typing.Callable[[platform_message.Event, MessagePlatformAdapter], None],
|
||||
):
|
||||
"""注册事件监听器
|
||||
|
||||
@@ -81,9 +77,7 @@ class MessagePlatformAdapter(metaclass=abc.ABCMeta):
|
||||
def unregister_listener(
|
||||
self,
|
||||
event_type: typing.Type[platform_message.Event],
|
||||
callback: typing.Callable[
|
||||
[platform_message.Event, MessagePlatformAdapter], None
|
||||
],
|
||||
callback: typing.Callable[[platform_message.Event, MessagePlatformAdapter], None],
|
||||
):
|
||||
"""注销事件监听器
|
||||
|
||||
|
||||
@@ -132,14 +132,10 @@ class PlatformManager:
|
||||
self.adapter_dict = {}
|
||||
|
||||
async def initialize(self):
|
||||
self.adapter_components = self.ap.discover.get_components_by_kind(
|
||||
'MessagePlatformAdapter'
|
||||
)
|
||||
self.adapter_components = self.ap.discover.get_components_by_kind('MessagePlatformAdapter')
|
||||
adapter_dict: dict[str, type[msadapter.MessagePlatformAdapter]] = {}
|
||||
for component in self.adapter_components:
|
||||
adapter_dict[component.metadata.name] = (
|
||||
component.get_python_component_class()
|
||||
)
|
||||
adapter_dict[component.metadata.name] = component.get_python_component_class()
|
||||
self.adapter_dict = adapter_dict
|
||||
|
||||
await self.load_bots_from_db()
|
||||
@@ -152,9 +148,7 @@ class PlatformManager:
|
||||
|
||||
self.bots = []
|
||||
|
||||
result = await self.ap.persistence_mgr.execute_async(
|
||||
sqlalchemy.select(persistence_bot.Bot)
|
||||
)
|
||||
result = await self.ap.persistence_mgr.execute_async(sqlalchemy.select(persistence_bot.Bot))
|
||||
|
||||
bots = result.all()
|
||||
|
||||
@@ -172,13 +166,9 @@ class PlatformManager:
|
||||
elif isinstance(bot_entity, dict):
|
||||
bot_entity = persistence_bot.Bot(**bot_entity)
|
||||
|
||||
adapter_inst = self.adapter_dict[bot_entity.adapter](
|
||||
bot_entity.adapter_config, self.ap
|
||||
)
|
||||
adapter_inst = self.adapter_dict[bot_entity.adapter](bot_entity.adapter_config, self.ap)
|
||||
|
||||
runtime_bot = RuntimeBot(
|
||||
ap=self.ap, bot_entity=bot_entity, adapter=adapter_inst
|
||||
)
|
||||
runtime_bot = RuntimeBot(ap=self.ap, bot_entity=bot_entity, adapter=adapter_inst)
|
||||
|
||||
await runtime_bot.initialize()
|
||||
|
||||
@@ -209,9 +199,7 @@ class PlatformManager:
|
||||
return component.to_plain_dict()
|
||||
return None
|
||||
|
||||
def get_available_adapter_manifest_by_name(
|
||||
self, name: str
|
||||
) -> engine.Component | None:
|
||||
def get_available_adapter_manifest_by_name(self, name: str) -> engine.Component | None:
|
||||
for component in self.adapter_components:
|
||||
if component.metadata.name == name:
|
||||
return component
|
||||
|
||||
@@ -58,13 +58,7 @@ class AiocqhttpMessageConverter(adapter.MessageConverter):
|
||||
msg_list.append(aiocqhttp.MessageSegment.record(msg.path))
|
||||
elif type(msg) is platform_message.Forward:
|
||||
for node in msg.node_list:
|
||||
msg_list.extend(
|
||||
(
|
||||
await AiocqhttpMessageConverter.yiri2target(
|
||||
node.message_chain
|
||||
)
|
||||
)[0]
|
||||
)
|
||||
msg_list.extend((await AiocqhttpMessageConverter.yiri2target(node.message_chain))[0])
|
||||
|
||||
else:
|
||||
msg_list.append(aiocqhttp.MessageSegment.text(str(msg)))
|
||||
@@ -77,9 +71,7 @@ class AiocqhttpMessageConverter(adapter.MessageConverter):
|
||||
|
||||
yiri_msg_list = []
|
||||
|
||||
yiri_msg_list.append(
|
||||
platform_message.Source(id=message_id, time=datetime.datetime.now())
|
||||
)
|
||||
yiri_msg_list.append(platform_message.Source(id=message_id, time=datetime.datetime.now()))
|
||||
|
||||
for msg in message:
|
||||
if msg.type == 'at':
|
||||
@@ -94,14 +86,8 @@ class AiocqhttpMessageConverter(adapter.MessageConverter):
|
||||
elif msg.type == 'text':
|
||||
yiri_msg_list.append(platform_message.Plain(text=msg.data['text']))
|
||||
elif msg.type == 'image':
|
||||
image_base64, image_format = await image.qq_image_url_to_base64(
|
||||
msg.data['url']
|
||||
)
|
||||
yiri_msg_list.append(
|
||||
platform_message.Image(
|
||||
base64=f'data:image/{image_format};base64,{image_base64}'
|
||||
)
|
||||
)
|
||||
image_base64, image_format = await image.qq_image_url_to_base64(msg.data['url'])
|
||||
yiri_msg_list.append(platform_message.Image(base64=f'data:image/{image_format};base64,{image_base64}'))
|
||||
|
||||
chain = platform_message.MessageChain(yiri_msg_list)
|
||||
|
||||
@@ -115,9 +101,7 @@ class AiocqhttpEventConverter(adapter.EventConverter):
|
||||
|
||||
@staticmethod
|
||||
async def target2yiri(event: aiocqhttp.Event):
|
||||
yiri_chain = await AiocqhttpMessageConverter.target2yiri(
|
||||
event.message, event.message_id
|
||||
)
|
||||
yiri_chain = await AiocqhttpMessageConverter.target2yiri(event.message, event.message_id)
|
||||
|
||||
if event.message_type == 'group':
|
||||
permission = 'MEMBER'
|
||||
@@ -137,9 +121,7 @@ class AiocqhttpEventConverter(adapter.EventConverter):
|
||||
name=event.sender['nickname'],
|
||||
permission=platform_entities.Permission.Member,
|
||||
),
|
||||
special_title=event.sender['title']
|
||||
if 'title' in event.sender
|
||||
else '',
|
||||
special_title=event.sender['title'] if 'title' in event.sender else '',
|
||||
join_timestamp=0,
|
||||
last_speak_timestamp=0,
|
||||
mute_time_remaining=0,
|
||||
@@ -191,9 +173,7 @@ class AiocqhttpAdapter(adapter.MessagePlatformAdapter):
|
||||
else:
|
||||
self.bot = aiocqhttp.CQHttp()
|
||||
|
||||
async def send_message(
|
||||
self, target_type: str, target_id: str, message: platform_message.MessageChain
|
||||
):
|
||||
async def send_message(self, target_type: str, target_id: str, message: platform_message.MessageChain):
|
||||
aiocq_msg = (await AiocqhttpMessageConverter.yiri2target(message))[0]
|
||||
|
||||
if target_type == 'group':
|
||||
@@ -207,14 +187,10 @@ class AiocqhttpAdapter(adapter.MessagePlatformAdapter):
|
||||
message: platform_message.MessageChain,
|
||||
quote_origin: bool = False,
|
||||
):
|
||||
aiocq_event = await AiocqhttpEventConverter.yiri2target(
|
||||
message_source, self.bot_account_id
|
||||
)
|
||||
aiocq_event = await AiocqhttpEventConverter.yiri2target(message_source, self.bot_account_id)
|
||||
aiocq_msg = (await AiocqhttpMessageConverter.yiri2target(message))[0]
|
||||
if quote_origin:
|
||||
aiocq_msg = (
|
||||
aiocqhttp.MessageSegment.reply(aiocq_event.message_id) + aiocq_msg
|
||||
)
|
||||
aiocq_msg = aiocqhttp.MessageSegment.reply(aiocq_event.message_id) + aiocq_msg
|
||||
|
||||
return await self.bot.send(aiocq_event, aiocq_msg)
|
||||
|
||||
@@ -224,16 +200,12 @@ class AiocqhttpAdapter(adapter.MessagePlatformAdapter):
|
||||
def register_listener(
|
||||
self,
|
||||
event_type: typing.Type[platform_events.Event],
|
||||
callback: typing.Callable[
|
||||
[platform_events.Event, adapter.MessagePlatformAdapter], None
|
||||
],
|
||||
callback: typing.Callable[[platform_events.Event, adapter.MessagePlatformAdapter], None],
|
||||
):
|
||||
async def on_message(event: aiocqhttp.Event):
|
||||
self.bot_account_id = event.self_id
|
||||
try:
|
||||
return await callback(
|
||||
await self.event_converter.target2yiri(event), self
|
||||
)
|
||||
return await callback(await self.event_converter.target2yiri(event), self)
|
||||
except Exception:
|
||||
traceback.print_exc()
|
||||
|
||||
@@ -245,9 +217,7 @@ class AiocqhttpAdapter(adapter.MessagePlatformAdapter):
|
||||
def unregister_listener(
|
||||
self,
|
||||
event_type: typing.Type[platform_events.Event],
|
||||
callback: typing.Callable[
|
||||
[platform_events.Event, adapter.MessagePlatformAdapter], None
|
||||
],
|
||||
callback: typing.Callable[[platform_events.Event, adapter.MessagePlatformAdapter], None],
|
||||
):
|
||||
return super().unregister_listener(event_type, callback)
|
||||
|
||||
|
||||
@@ -22,9 +22,7 @@ class DingTalkMessageConverter(adapter.MessageConverter):
|
||||
async def target2yiri(event: DingTalkEvent, bot_name: str):
|
||||
yiri_msg_list = []
|
||||
yiri_msg_list.append(
|
||||
platform_message.Source(
|
||||
id=event.incoming_message.message_id, time=datetime.datetime.now()
|
||||
)
|
||||
platform_message.Source(id=event.incoming_message.message_id, time=datetime.datetime.now())
|
||||
)
|
||||
|
||||
for atUser in event.incoming_message.at_users:
|
||||
@@ -133,9 +131,7 @@ class DingTalkAdapter(adapter.MessagePlatformAdapter):
|
||||
content = await DingTalkMessageConverter.yiri2target(message)
|
||||
await self.bot.send_message(content, incoming_message)
|
||||
|
||||
async def send_message(
|
||||
self, target_type: str, target_id: str, message: platform_message.MessageChain
|
||||
):
|
||||
async def send_message(self, target_type: str, target_id: str, message: platform_message.MessageChain):
|
||||
content = await DingTalkMessageConverter.yiri2target(message)
|
||||
if target_type == 'person':
|
||||
await self.bot.send_proactive_message_to_one(target_id, content)
|
||||
@@ -145,16 +141,12 @@ class DingTalkAdapter(adapter.MessagePlatformAdapter):
|
||||
def register_listener(
|
||||
self,
|
||||
event_type: typing.Type[platform_events.Event],
|
||||
callback: typing.Callable[
|
||||
[platform_events.Event, adapter.MessagePlatformAdapter], None
|
||||
],
|
||||
callback: typing.Callable[[platform_events.Event, adapter.MessagePlatformAdapter], None],
|
||||
):
|
||||
async def on_message(event: DingTalkEvent):
|
||||
try:
|
||||
return await callback(
|
||||
await self.event_converter.target2yiri(
|
||||
event, self.config['robot_name']
|
||||
),
|
||||
await self.event_converter.target2yiri(event, self.config['robot_name']),
|
||||
self,
|
||||
)
|
||||
except Exception:
|
||||
@@ -174,8 +166,6 @@ class DingTalkAdapter(adapter.MessagePlatformAdapter):
|
||||
async def unregister_listener(
|
||||
self,
|
||||
event_type: type,
|
||||
callback: typing.Callable[
|
||||
[platform_events.Event, MessagePlatformAdapter], None
|
||||
],
|
||||
callback: typing.Callable[[platform_events.Event, MessagePlatformAdapter], None],
|
||||
):
|
||||
return super().unregister_listener(event_type, callback)
|
||||
|
||||
@@ -45,9 +45,7 @@ class DiscordMessageConverter(adapter.MessageConverter):
|
||||
with open(ele.path, 'rb') as f:
|
||||
image_bytes = f.read()
|
||||
|
||||
image_files.append(
|
||||
discord.File(fp=image_bytes, filename=f'{uuid.uuid4()}.png')
|
||||
)
|
||||
image_files.append(discord.File(fp=image_bytes, filename=f'{uuid.uuid4()}.png'))
|
||||
elif isinstance(ele, platform_message.Plain):
|
||||
text_string += ele.text
|
||||
elif isinstance(ele, platform_message.Forward):
|
||||
@@ -65,9 +63,7 @@ class DiscordMessageConverter(adapter.MessageConverter):
|
||||
async def target2yiri(message: discord.Message) -> platform_message.MessageChain:
|
||||
lb_msg_list = []
|
||||
|
||||
msg_create_time = datetime.datetime.fromtimestamp(
|
||||
int(message.created_at.timestamp())
|
||||
)
|
||||
msg_create_time = datetime.datetime.fromtimestamp(int(message.created_at.timestamp()))
|
||||
|
||||
lb_msg_list.append(platform_message.Source(id=message.id, time=msg_create_time))
|
||||
|
||||
@@ -97,11 +93,7 @@ class DiscordMessageConverter(adapter.MessageConverter):
|
||||
else:
|
||||
mid_at_component.append(platform_message.At(target=mid_at[2:-1]))
|
||||
|
||||
return (
|
||||
text_element_recur(text_split[0])
|
||||
+ mid_at_component
|
||||
+ text_element_recur(text_split[1])
|
||||
)
|
||||
return text_element_recur(text_split[0]) + mid_at_component + text_element_recur(text_split[1])
|
||||
else:
|
||||
return [platform_message.Plain(text=text_ele)]
|
||||
|
||||
@@ -114,11 +106,7 @@ class DiscordMessageConverter(adapter.MessageConverter):
|
||||
image_data = await response.read()
|
||||
image_base64 = base64.b64encode(image_data).decode('utf-8')
|
||||
image_format = response.headers['Content-Type']
|
||||
element_list.append(
|
||||
platform_message.Image(
|
||||
base64=f'data:{image_format};base64,{image_base64}'
|
||||
)
|
||||
)
|
||||
element_list.append(platform_message.Image(base64=f'data:{image_format};base64,{image_base64}'))
|
||||
|
||||
return platform_message.MessageChain(element_list)
|
||||
|
||||
@@ -208,9 +196,7 @@ class DiscordAdapter(adapter.MessagePlatformAdapter):
|
||||
|
||||
self.bot = MyClient(intents=intents, **args)
|
||||
|
||||
async def send_message(
|
||||
self, target_type: str, target_id: str, message: platform_message.MessageChain
|
||||
):
|
||||
async def send_message(self, target_type: str, target_id: str, message: platform_message.MessageChain):
|
||||
pass
|
||||
|
||||
async def reply_message(
|
||||
@@ -243,18 +229,14 @@ class DiscordAdapter(adapter.MessagePlatformAdapter):
|
||||
def register_listener(
|
||||
self,
|
||||
event_type: typing.Type[platform_events.Event],
|
||||
callback: typing.Callable[
|
||||
[platform_events.Event, adapter.MessagePlatformAdapter], None
|
||||
],
|
||||
callback: typing.Callable[[platform_events.Event, adapter.MessagePlatformAdapter], None],
|
||||
):
|
||||
self.listeners[event_type] = callback
|
||||
|
||||
def unregister_listener(
|
||||
self,
|
||||
event_type: typing.Type[platform_events.Event],
|
||||
callback: typing.Callable[
|
||||
[platform_events.Event, adapter.MessagePlatformAdapter], None
|
||||
],
|
||||
callback: typing.Callable[[platform_events.Event, adapter.MessagePlatformAdapter], None],
|
||||
):
|
||||
self.listeners.pop(event_type)
|
||||
|
||||
|
||||
@@ -40,14 +40,10 @@ class GewechatMessageConverter(adapter.MessageConverter):
|
||||
content_list.append({'type': 'image', 'image': component.url})
|
||||
|
||||
elif isinstance(component, platform_message.Voice):
|
||||
content_list.append(
|
||||
{'type': 'voice', 'url': component.url, 'length': component.length}
|
||||
)
|
||||
content_list.append({'type': 'voice', 'url': component.url, 'length': component.length})
|
||||
elif isinstance(component, platform_message.Forward):
|
||||
for node in component.node_list:
|
||||
content_list.extend(
|
||||
await GewechatMessageConverter.yiri2target(node.message_chain)
|
||||
)
|
||||
content_list.extend(await GewechatMessageConverter.yiri2target(node.message_chain))
|
||||
content_list.append({'type': 'image', 'image': component.url})
|
||||
elif isinstance(component, platform_message.WeChatMiniPrograms):
|
||||
content_list.append(
|
||||
@@ -88,44 +84,26 @@ class GewechatMessageConverter(adapter.MessageConverter):
|
||||
}
|
||||
)
|
||||
elif isinstance(component, platform_message.WeChatForwardLink):
|
||||
content_list.append(
|
||||
{'type': 'WeChatForwardLink', 'xml_data': component.xml_data}
|
||||
)
|
||||
content_list.append({'type': 'WeChatForwardLink', 'xml_data': component.xml_data})
|
||||
elif isinstance(component, platform_message.Voice):
|
||||
content_list.append(
|
||||
{'type': 'voice', 'url': component.url, 'length': component.length}
|
||||
)
|
||||
content_list.append({'type': 'voice', 'url': component.url, 'length': component.length})
|
||||
elif isinstance(component, platform_message.WeChatForwardImage):
|
||||
content_list.append(
|
||||
{'type': 'WeChatForwardImage', 'xml_data': component.xml_data}
|
||||
)
|
||||
content_list.append({'type': 'WeChatForwardImage', 'xml_data': component.xml_data})
|
||||
elif isinstance(component, platform_message.WeChatForwardFile):
|
||||
content_list.append(
|
||||
{'type': 'WeChatForwardFile', 'xml_data': component.xml_data}
|
||||
)
|
||||
content_list.append({'type': 'WeChatForwardFile', 'xml_data': component.xml_data})
|
||||
elif isinstance(component, platform_message.WeChatAppMsg):
|
||||
content_list.append(
|
||||
{'type': 'WeChatAppMsg', 'app_msg': component.app_msg}
|
||||
)
|
||||
content_list.append({'type': 'WeChatAppMsg', 'app_msg': component.app_msg})
|
||||
# 引用消息转发
|
||||
elif isinstance(component, platform_message.WeChatForwardQuote):
|
||||
content_list.append(
|
||||
{'type': 'WeChatAppMsg', 'app_msg': component.app_msg}
|
||||
)
|
||||
content_list.append({'type': 'WeChatAppMsg', 'app_msg': component.app_msg})
|
||||
elif isinstance(component, platform_message.Forward):
|
||||
for node in component.node_list:
|
||||
if node.message_chain:
|
||||
content_list.extend(
|
||||
await GewechatMessageConverter.yiri2target(
|
||||
node.message_chain
|
||||
)
|
||||
)
|
||||
content_list.extend(await GewechatMessageConverter.yiri2target(node.message_chain))
|
||||
|
||||
return content_list
|
||||
|
||||
async def target2yiri(
|
||||
self, message: dict, bot_account_id: str
|
||||
) -> platform_message.MessageChain:
|
||||
async def target2yiri(self, message: dict, bot_account_id: str) -> platform_message.MessageChain:
|
||||
"""外部消息转平台消息"""
|
||||
# 数据预处理
|
||||
message_list = []
|
||||
@@ -163,28 +141,20 @@ class GewechatMessageConverter(adapter.MessageConverter):
|
||||
|
||||
return platform_message.MessageChain(message_list)
|
||||
|
||||
async def _handler_text(
|
||||
self, message: Optional[dict], content_no_preifx: str
|
||||
) -> platform_message.MessageChain:
|
||||
async def _handler_text(self, message: Optional[dict], content_no_preifx: str) -> platform_message.MessageChain:
|
||||
"""处理文本消息 (msg_type=1)"""
|
||||
if message and self._is_group_message(message):
|
||||
pattern = r'@\S{1,20}'
|
||||
content_no_preifx = re.sub(pattern, '', content_no_preifx)
|
||||
|
||||
return platform_message.MessageChain(
|
||||
[platform_message.Plain(content_no_preifx)]
|
||||
)
|
||||
return platform_message.MessageChain([platform_message.Plain(content_no_preifx)])
|
||||
|
||||
async def _handler_image(
|
||||
self, message: Optional[dict], content_no_preifx: str
|
||||
) -> platform_message.MessageChain:
|
||||
async def _handler_image(self, message: Optional[dict], content_no_preifx: str) -> platform_message.MessageChain:
|
||||
"""处理图像消息 (msg_type=3)"""
|
||||
try:
|
||||
image_xml = content_no_preifx
|
||||
if not image_xml:
|
||||
return platform_message.MessageChain(
|
||||
[platform_message.Unknown('[图片内容为空]')]
|
||||
)
|
||||
return platform_message.MessageChain([platform_message.Unknown('[图片内容为空]')])
|
||||
|
||||
base64_str, image_format = await image.get_gewechat_image_base64(
|
||||
gewechat_url=self.config['gewechat_url'],
|
||||
@@ -196,21 +166,15 @@ class GewechatMessageConverter(adapter.MessageConverter):
|
||||
)
|
||||
|
||||
elements = [
|
||||
platform_message.Image(
|
||||
base64=f'data:image/{image_format};base64,{base64_str}'
|
||||
),
|
||||
platform_message.Image(base64=f'data:image/{image_format};base64,{base64_str}'),
|
||||
platform_message.WeChatForwardImage(xml_data=image_xml), # 微信消息转发
|
||||
]
|
||||
return platform_message.MessageChain(elements)
|
||||
except Exception as e:
|
||||
print(f'处理图片失败: {str(e)}')
|
||||
return platform_message.MessageChain(
|
||||
[platform_message.Unknown('[图片处理失败]')]
|
||||
)
|
||||
return platform_message.MessageChain([platform_message.Unknown('[图片处理失败]')])
|
||||
|
||||
async def _handler_voice(
|
||||
self, message: Optional[dict], content_no_preifx: str
|
||||
) -> platform_message.MessageChain:
|
||||
async def _handler_voice(self, message: Optional[dict], content_no_preifx: str) -> platform_message.MessageChain:
|
||||
"""处理语音消息 (msg_type=34)"""
|
||||
message_List = []
|
||||
try:
|
||||
@@ -223,9 +187,7 @@ class GewechatMessageConverter(adapter.MessageConverter):
|
||||
return platform_message.MessageChain(message_List)
|
||||
|
||||
# 转换为平台支持的语音格式(如 Silk 格式)
|
||||
voice_element = platform_message.Voice(
|
||||
base64=f'data:audio/silk;base64,{audio_base64}'
|
||||
)
|
||||
voice_element = platform_message.Voice(base64=f'data:audio/silk;base64,{audio_base64}')
|
||||
message_List.append(voice_element)
|
||||
|
||||
except KeyError as e:
|
||||
@@ -237,9 +199,7 @@ class GewechatMessageConverter(adapter.MessageConverter):
|
||||
|
||||
return platform_message.MessageChain(message_List)
|
||||
|
||||
async def _handler_compound(
|
||||
self, message: Optional[dict], content_no_preifx: str
|
||||
) -> platform_message.MessageChain:
|
||||
async def _handler_compound(self, message: Optional[dict], content_no_preifx: str) -> platform_message.MessageChain:
|
||||
"""处理复合消息 (msg_type=49),根据子类型分派"""
|
||||
try:
|
||||
xml_data = ET.fromstring(content_no_preifx)
|
||||
@@ -254,33 +214,21 @@ class GewechatMessageConverter(adapter.MessageConverter):
|
||||
'6': self._handler_compound_file,
|
||||
'33': self._handler_compound_mini_program,
|
||||
'36': self._handler_compound_mini_program,
|
||||
'2000': partial(
|
||||
self._handler_compound_unsupported, text='[转账消息]'
|
||||
),
|
||||
'2001': partial(
|
||||
self._handler_compound_unsupported, text='[红包消息]'
|
||||
),
|
||||
'51': partial(
|
||||
self._handler_compound_unsupported, text='[视频号消息]'
|
||||
),
|
||||
'2000': partial(self._handler_compound_unsupported, text='[转账消息]'),
|
||||
'2001': partial(self._handler_compound_unsupported, text='[红包消息]'),
|
||||
'51': partial(self._handler_compound_unsupported, text='[视频号消息]'),
|
||||
}
|
||||
|
||||
handler = sub_handler_map.get(
|
||||
data_type, self._handler_compound_unsupported
|
||||
)
|
||||
handler = sub_handler_map.get(data_type, self._handler_compound_unsupported)
|
||||
return await handler(
|
||||
message=message, # 原始msg
|
||||
xml_data=xml_data, # xml数据
|
||||
)
|
||||
else:
|
||||
return platform_message.MessageChain(
|
||||
[platform_message.Unknown(text=content_no_preifx)]
|
||||
)
|
||||
return platform_message.MessageChain([platform_message.Unknown(text=content_no_preifx)])
|
||||
except Exception as e:
|
||||
print(f'解析复合消息失败: {str(e)}')
|
||||
return platform_message.MessageChain(
|
||||
[platform_message.Unknown(text=content_no_preifx)]
|
||||
)
|
||||
return platform_message.MessageChain([platform_message.Unknown(text=content_no_preifx)])
|
||||
|
||||
async def _handler_compound_quote(
|
||||
self, message: Optional[dict], xml_data: ET.Element
|
||||
@@ -296,9 +244,7 @@ class GewechatMessageConverter(adapter.MessageConverter):
|
||||
user_data = appmsg_data.findtext('.//title') or ''
|
||||
quote_data = appmsg_data.find('.//refermsg').findtext('.//content')
|
||||
message_list.append(
|
||||
platform_message.WeChatForwardQuote(
|
||||
app_msg=ET.tostring(appmsg_data, encoding='unicode')
|
||||
)
|
||||
platform_message.WeChatForwardQuote(app_msg=ET.tostring(appmsg_data, encoding='unicode'))
|
||||
)
|
||||
# quote_data原始的消息
|
||||
if quote_data:
|
||||
@@ -311,22 +257,14 @@ class GewechatMessageConverter(adapter.MessageConverter):
|
||||
# 引用消息展开
|
||||
quote_data_xml = ET.fromstring(quote_data)
|
||||
if quote_data_xml.find('img'):
|
||||
quote_data_message_list.extend(
|
||||
await self._handler_image(None, quote_data)
|
||||
)
|
||||
quote_data_message_list.extend(await self._handler_image(None, quote_data))
|
||||
elif quote_data_xml.find('voicemsg'):
|
||||
quote_data_message_list.extend(
|
||||
await self._handler_voice(None, quote_data)
|
||||
)
|
||||
quote_data_message_list.extend(await self._handler_voice(None, quote_data))
|
||||
elif quote_data_xml.find('videomsg'):
|
||||
quote_data_message_list.extend(
|
||||
await self._handler_default(None, quote_data)
|
||||
) # 先不处理
|
||||
quote_data_message_list.extend(await self._handler_default(None, quote_data)) # 先不处理
|
||||
else:
|
||||
# appmsg
|
||||
quote_data_message_list.extend(
|
||||
await self._handler_compound(None, quote_data)
|
||||
)
|
||||
quote_data_message_list.extend(await self._handler_compound(None, quote_data))
|
||||
except Exception as e:
|
||||
print(f'处理引用消息异常 expcetion:{e}')
|
||||
quote_data_message_list.append(platform_message.Plain(quote_data))
|
||||
@@ -351,18 +289,12 @@ class GewechatMessageConverter(adapter.MessageConverter):
|
||||
# print(f"quote_message_chain plain [msg_type={comp.type}][message={comp.text}]")
|
||||
return platform_message.MessageChain(message_list)
|
||||
|
||||
async def _handler_compound_file(
|
||||
self, message: dict, xml_data: ET.Element
|
||||
) -> platform_message.MessageChain:
|
||||
async def _handler_compound_file(self, message: dict, xml_data: ET.Element) -> platform_message.MessageChain:
|
||||
"""处理文件消息 (data_type=6)"""
|
||||
xml_data_str = ET.tostring(xml_data, encoding='unicode')
|
||||
return platform_message.MessageChain(
|
||||
[platform_message.WeChatForwardFile(xml_data=xml_data_str)]
|
||||
)
|
||||
return platform_message.MessageChain([platform_message.WeChatForwardFile(xml_data=xml_data_str)])
|
||||
|
||||
async def _handler_compound_link(
|
||||
self, message: dict, xml_data: ET.Element
|
||||
) -> platform_message.MessageChain:
|
||||
async def _handler_compound_link(self, message: dict, xml_data: ET.Element) -> platform_message.MessageChain:
|
||||
"""处理链接消息(如公众号文章、外部网页)"""
|
||||
message_list = []
|
||||
try:
|
||||
@@ -381,9 +313,7 @@ class GewechatMessageConverter(adapter.MessageConverter):
|
||||
# 转发消息
|
||||
xml_data_str = ET.tostring(xml_data, encoding='unicode')
|
||||
# print(xml_data_str)
|
||||
message_list.append(
|
||||
platform_message.WeChatForwardLink(xml_data=xml_data_str)
|
||||
)
|
||||
message_list.append(platform_message.WeChatForwardLink(xml_data=xml_data_str))
|
||||
except Exception as e:
|
||||
print(f'解析链接消息失败: {str(e)}')
|
||||
return platform_message.MessageChain(message_list)
|
||||
@@ -393,21 +323,15 @@ class GewechatMessageConverter(adapter.MessageConverter):
|
||||
) -> platform_message.MessageChain:
|
||||
"""处理小程序消息(如小程序卡片、服务通知)"""
|
||||
xml_data_str = ET.tostring(xml_data, encoding='unicode')
|
||||
return platform_message.MessageChain(
|
||||
[platform_message.WeChatForwardMiniPrograms(xml_data=xml_data_str)]
|
||||
)
|
||||
return platform_message.MessageChain([platform_message.WeChatForwardMiniPrograms(xml_data=xml_data_str)])
|
||||
|
||||
async def _handler_default(
|
||||
self, message: Optional[dict], content_no_preifx: str
|
||||
) -> platform_message.MessageChain:
|
||||
async def _handler_default(self, message: Optional[dict], content_no_preifx: str) -> platform_message.MessageChain:
|
||||
"""处理未知消息类型"""
|
||||
if message:
|
||||
msg_type = message['Data']['MsgType']
|
||||
else:
|
||||
msg_type = ''
|
||||
return platform_message.MessageChain(
|
||||
[platform_message.Unknown(text=f'[未知消息类型 msg_type:{msg_type}]')]
|
||||
)
|
||||
return platform_message.MessageChain([platform_message.Unknown(text=f'[未知消息类型 msg_type:{msg_type}]')])
|
||||
|
||||
def _handler_compound_unsupported(
|
||||
self, message: dict, xml_data: str, text: Optional[str] = None
|
||||
@@ -416,11 +340,7 @@ class GewechatMessageConverter(adapter.MessageConverter):
|
||||
if not text:
|
||||
text = f'[xml_data={xml_data}]'
|
||||
content_list = []
|
||||
content_list.append(
|
||||
platform_message.Unknown(
|
||||
text=f'[处理未支持复合消息类型[msg_type=49]|{text}'
|
||||
)
|
||||
)
|
||||
content_list.append(platform_message.Unknown(text=f'[处理未支持复合消息类型[msg_type=49]|{text}'))
|
||||
|
||||
return platform_message.MessageChain(content_list)
|
||||
|
||||
@@ -448,9 +368,7 @@ class GewechatMessageConverter(adapter.MessageConverter):
|
||||
appmsg_data = xml_data.find('.//appmsg')
|
||||
tousername = message['Wxid']
|
||||
if appmsg_data: # 接收方: 所属微信的wxid
|
||||
quote_id = appmsg_data.find('.//refermsg').findtext(
|
||||
'.//chatusr'
|
||||
) # 引用消息的原发送者
|
||||
quote_id = appmsg_data.find('.//refermsg').findtext('.//chatusr') # 引用消息的原发送者
|
||||
ats_bot = ats_bot or (quote_id == tousername)
|
||||
except Exception as e:
|
||||
print(f'_ats_bot got except: {e}')
|
||||
@@ -458,9 +376,7 @@ class GewechatMessageConverter(adapter.MessageConverter):
|
||||
return ats_bot
|
||||
|
||||
# 提取一下content前面的sender_id, 和去掉前缀的内容
|
||||
def _extract_content_and_sender(
|
||||
self, raw_content: str
|
||||
) -> Tuple[str, Optional[str]]:
|
||||
def _extract_content_and_sender(self, raw_content: str) -> Tuple[str, Optional[str]]:
|
||||
try:
|
||||
# 检查消息开头,如果有 wxid_sbitaz0mt65n22:\n 则删掉
|
||||
# add: 有些用户的wxid不是上述格式。换成user_name:
|
||||
@@ -490,21 +406,17 @@ class GewechatEventConverter(adapter.EventConverter):
|
||||
async def yiri2target(event: platform_events.MessageEvent) -> dict:
|
||||
pass
|
||||
|
||||
async def target2yiri(
|
||||
self, event: dict, bot_account_id: str
|
||||
) -> platform_events.MessageEvent:
|
||||
async def target2yiri(self, event: dict, bot_account_id: str) -> platform_events.MessageEvent:
|
||||
# print(event)
|
||||
# 排除自己发消息回调回答问题
|
||||
if event['Wxid'] == event['Data']['FromUserName']['string']:
|
||||
return None
|
||||
# 排除公众号以及微信团队消息
|
||||
if event['Data']['FromUserName']['string'].startswith('gh_') or event['Data'][
|
||||
'FromUserName'
|
||||
]['string'].startswith('weixin'):
|
||||
if event['Data']['FromUserName']['string'].startswith('gh_') or event['Data']['FromUserName'][
|
||||
'string'
|
||||
].startswith('weixin'):
|
||||
return None
|
||||
message_chain = await self.message_converter.target2yiri(
|
||||
copy.deepcopy(event), bot_account_id
|
||||
)
|
||||
message_chain = await self.message_converter.target2yiri(copy.deepcopy(event), bot_account_id)
|
||||
|
||||
if not message_chain:
|
||||
return None
|
||||
@@ -589,9 +501,7 @@ class GeWeChatAdapter(adapter.MessagePlatformAdapter):
|
||||
return 'ok'
|
||||
elif 'TypeName' in data and data['TypeName'] == 'AddMsg':
|
||||
try:
|
||||
event = await self.event_converter.target2yiri(
|
||||
data.copy(), self.bot_account_id
|
||||
)
|
||||
event = await self.event_converter.target2yiri(data.copy(), self.bot_account_id)
|
||||
except Exception:
|
||||
traceback.print_exc()
|
||||
|
||||
@@ -600,9 +510,7 @@ class GeWeChatAdapter(adapter.MessagePlatformAdapter):
|
||||
|
||||
return 'ok'
|
||||
|
||||
async def _handle_message(
|
||||
self, message: platform_message.MessageChain, target_id: str
|
||||
):
|
||||
async def _handle_message(self, message: platform_message.MessageChain, target_id: str):
|
||||
"""统一消息处理核心逻辑"""
|
||||
content_list = await self.message_converter.yiri2target(message)
|
||||
at_targets = [item['target'] for item in content_list if item['type'] == 'at']
|
||||
@@ -611,9 +519,9 @@ class GeWeChatAdapter(adapter.MessagePlatformAdapter):
|
||||
at_targets = at_targets or []
|
||||
member_info = []
|
||||
if at_targets:
|
||||
member_info = self.bot.get_chatroom_member_detail(
|
||||
self.config['app_id'], target_id, at_targets[::-1]
|
||||
)['data']
|
||||
member_info = self.bot.get_chatroom_member_detail(self.config['app_id'], target_id, at_targets[::-1])[
|
||||
'data'
|
||||
]
|
||||
|
||||
# 处理消息组件
|
||||
for msg in content_list:
|
||||
@@ -694,9 +602,7 @@ class GeWeChatAdapter(adapter.MessagePlatformAdapter):
|
||||
self.ap.logger.warning(f'未处理的消息类型: {msg["type"]}')
|
||||
continue
|
||||
|
||||
async def send_message(
|
||||
self, target_type: str, target_id: str, message: platform_message.MessageChain
|
||||
):
|
||||
async def send_message(self, target_type: str, target_id: str, message: platform_message.MessageChain):
|
||||
"""主动发送消息"""
|
||||
return await self._handle_message(message, target_id)
|
||||
|
||||
@@ -708,9 +614,7 @@ class GeWeChatAdapter(adapter.MessagePlatformAdapter):
|
||||
):
|
||||
"""回复消息"""
|
||||
if message_source.source_platform_object:
|
||||
target_id = message_source.source_platform_object['Data']['FromUserName'][
|
||||
'string'
|
||||
]
|
||||
target_id = message_source.source_platform_object['Data']['FromUserName']['string']
|
||||
return await self._handle_message(message, target_id)
|
||||
|
||||
async def is_muted(self, group_id: int) -> bool:
|
||||
@@ -719,18 +623,14 @@ class GeWeChatAdapter(adapter.MessagePlatformAdapter):
|
||||
def register_listener(
|
||||
self,
|
||||
event_type: typing.Type[platform_events.Event],
|
||||
callback: typing.Callable[
|
||||
[platform_events.Event, adapter.MessagePlatformAdapter], None
|
||||
],
|
||||
callback: typing.Callable[[platform_events.Event, adapter.MessagePlatformAdapter], None],
|
||||
):
|
||||
self.listeners[event_type] = callback
|
||||
|
||||
def unregister_listener(
|
||||
self,
|
||||
event_type: typing.Type[platform_events.Event],
|
||||
callback: typing.Callable[
|
||||
[platform_events.Event, adapter.MessagePlatformAdapter], None
|
||||
],
|
||||
callback: typing.Callable[[platform_events.Event, adapter.MessagePlatformAdapter], None],
|
||||
):
|
||||
pass
|
||||
|
||||
@@ -742,14 +642,10 @@ class GeWeChatAdapter(adapter.MessagePlatformAdapter):
|
||||
json={'app_id': self.config['app_id']},
|
||||
) as response:
|
||||
if response.status != 200:
|
||||
raise Exception(
|
||||
f'获取gewechat token失败: {await response.text()}'
|
||||
)
|
||||
raise Exception(f'获取gewechat token失败: {await response.text()}')
|
||||
self.config['token'] = (await response.json())['data']
|
||||
|
||||
self.bot = gewechat_client.GewechatClient(
|
||||
f'{self.config["gewechat_url"]}/v2/api', self.config['token']
|
||||
)
|
||||
self.bot = gewechat_client.GewechatClient(f'{self.config["gewechat_url"]}/v2/api', self.config['token'])
|
||||
|
||||
def gewechat_login_process():
|
||||
app_id, error_msg = self.bot.login(self.config['app_id'])
|
||||
|
||||
@@ -71,14 +71,10 @@ class LarkMessageConverter(adapter.MessageConverter):
|
||||
pending_paragraph.append({'tag': 'md', 'text': text})
|
||||
except UnicodeError:
|
||||
# If still fails, replace invalid characters
|
||||
text = msg.text.encode('utf-8', errors='replace').decode(
|
||||
'utf-8'
|
||||
)
|
||||
text = msg.text.encode('utf-8', errors='replace').decode('utf-8')
|
||||
pending_paragraph.append({'tag': 'md', 'text': text})
|
||||
elif isinstance(msg, platform_message.At):
|
||||
pending_paragraph.append(
|
||||
{'tag': 'at', 'user_id': msg.target, 'style': []}
|
||||
)
|
||||
pending_paragraph.append({'tag': 'at', 'user_id': msg.target, 'style': []})
|
||||
elif isinstance(msg, platform_message.AtAll):
|
||||
pending_paragraph.append({'tag': 'at', 'user_id': 'all', 'style': []})
|
||||
elif isinstance(msg, platform_message.Image):
|
||||
@@ -166,11 +162,7 @@ class LarkMessageConverter(adapter.MessageConverter):
|
||||
os.unlink(temp_file.name)
|
||||
elif isinstance(msg, platform_message.Forward):
|
||||
for node in msg.node_list:
|
||||
message_elements.extend(
|
||||
await LarkMessageConverter.yiri2target(
|
||||
node.message_chain, api_client
|
||||
)
|
||||
)
|
||||
message_elements.extend(await LarkMessageConverter.yiri2target(node.message_chain, api_client))
|
||||
|
||||
if pending_paragraph:
|
||||
message_elements.append(pending_paragraph)
|
||||
@@ -186,13 +178,9 @@ class LarkMessageConverter(adapter.MessageConverter):
|
||||
|
||||
lb_msg_list = []
|
||||
|
||||
msg_create_time = datetime.datetime.fromtimestamp(
|
||||
int(message.create_time) / 1000
|
||||
)
|
||||
msg_create_time = datetime.datetime.fromtimestamp(int(message.create_time) / 1000)
|
||||
|
||||
lb_msg_list.append(
|
||||
platform_message.Source(id=message.message_id, time=msg_create_time)
|
||||
)
|
||||
lb_msg_list.append(platform_message.Source(id=message.message_id, time=msg_create_time))
|
||||
|
||||
if message.message_type == 'text':
|
||||
element_list = []
|
||||
@@ -222,9 +210,7 @@ class LarkMessageConverter(adapter.MessageConverter):
|
||||
left_text = text_split[0]
|
||||
right_text = text_split[1]
|
||||
|
||||
new_list.extend(
|
||||
text_element_recur({'tag': 'text', 'text': left_text, 'style': []})
|
||||
)
|
||||
new_list.extend(text_element_recur({'tag': 'text', 'text': left_text, 'style': []}))
|
||||
|
||||
new_list.append(
|
||||
{
|
||||
@@ -235,15 +221,11 @@ class LarkMessageConverter(adapter.MessageConverter):
|
||||
}
|
||||
)
|
||||
|
||||
new_list.extend(
|
||||
text_element_recur({'tag': 'text', 'text': right_text, 'style': []})
|
||||
)
|
||||
new_list.extend(text_element_recur({'tag': 'text', 'text': right_text, 'style': []}))
|
||||
|
||||
return new_list
|
||||
|
||||
element_list = text_element_recur(
|
||||
{'tag': 'text', 'text': message_content['text'], 'style': []}
|
||||
)
|
||||
element_list = text_element_recur({'tag': 'text', 'text': message_content['text'], 'style': []})
|
||||
|
||||
message_content = {'title': '', 'content': element_list}
|
||||
|
||||
@@ -258,9 +240,7 @@ class LarkMessageConverter(adapter.MessageConverter):
|
||||
|
||||
message_content['content'] = new_list
|
||||
elif message.message_type == 'image':
|
||||
message_content['content'] = [
|
||||
{'tag': 'img', 'image_key': message_content['image_key'], 'style': []}
|
||||
]
|
||||
message_content['content'] = [{'tag': 'img', 'image_key': message_content['image_key'], 'style': []}]
|
||||
|
||||
for ele in message_content['content']:
|
||||
if ele['tag'] == 'text':
|
||||
@@ -278,9 +258,7 @@ class LarkMessageConverter(adapter.MessageConverter):
|
||||
.build()
|
||||
)
|
||||
|
||||
response: GetMessageResourceResponse = (
|
||||
await api_client.im.v1.message_resource.aget(request)
|
||||
)
|
||||
response: GetMessageResourceResponse = await api_client.im.v1.message_resource.aget(request)
|
||||
|
||||
if not response.success():
|
||||
raise Exception(
|
||||
@@ -292,11 +270,7 @@ class LarkMessageConverter(adapter.MessageConverter):
|
||||
|
||||
image_format = response.raw.headers['content-type']
|
||||
|
||||
lb_msg_list.append(
|
||||
platform_message.Image(
|
||||
base64=f'data:{image_format};base64,{image_base64}'
|
||||
)
|
||||
)
|
||||
lb_msg_list.append(platform_message.Image(base64=f'data:{image_format};base64,{image_base64}'))
|
||||
|
||||
return platform_message.MessageChain(lb_msg_list)
|
||||
|
||||
@@ -312,9 +286,7 @@ class LarkEventConverter(adapter.EventConverter):
|
||||
async def target2yiri(
|
||||
event: lark_oapi.im.v1.P2ImMessageReceiveV1, api_client: lark_oapi.Client
|
||||
) -> platform_events.Event:
|
||||
message_chain = await LarkMessageConverter.target2yiri(
|
||||
event.event.message, api_client
|
||||
)
|
||||
message_chain = await LarkMessageConverter.target2yiri(event.event.message, api_client)
|
||||
|
||||
if event.event.message.chat_type == 'p2p':
|
||||
return platform_events.FriendMessage(
|
||||
@@ -402,9 +374,7 @@ class LarkAdapter(adapter.MessagePlatformAdapter):
|
||||
p2v1.schema = context.schema
|
||||
if 'im.message.receive_v1' == type:
|
||||
try:
|
||||
event = await self.event_converter.target2yiri(
|
||||
p2v1, self.api_client
|
||||
)
|
||||
event = await self.event_converter.target2yiri(p2v1, self.api_client)
|
||||
except Exception:
|
||||
traceback.print_exc()
|
||||
|
||||
@@ -425,26 +395,15 @@ class LarkAdapter(adapter.MessagePlatformAdapter):
|
||||
asyncio.create_task(on_message(event))
|
||||
|
||||
event_handler = (
|
||||
lark_oapi.EventDispatcherHandler.builder('', '')
|
||||
.register_p2_im_message_receive_v1(sync_on_message)
|
||||
.build()
|
||||
lark_oapi.EventDispatcherHandler.builder('', '').register_p2_im_message_receive_v1(sync_on_message).build()
|
||||
)
|
||||
|
||||
self.bot_account_id = config['bot_name']
|
||||
|
||||
self.bot = lark_oapi.ws.Client(
|
||||
config['app_id'], config['app_secret'], event_handler=event_handler
|
||||
)
|
||||
self.api_client = (
|
||||
lark_oapi.Client.builder()
|
||||
.app_id(config['app_id'])
|
||||
.app_secret(config['app_secret'])
|
||||
.build()
|
||||
)
|
||||
self.bot = lark_oapi.ws.Client(config['app_id'], config['app_secret'], event_handler=event_handler)
|
||||
self.api_client = lark_oapi.Client.builder().app_id(config['app_id']).app_secret(config['app_secret']).build()
|
||||
|
||||
async def send_message(
|
||||
self, target_type: str, target_id: str, message: platform_message.MessageChain
|
||||
):
|
||||
async def send_message(self, target_type: str, target_id: str, message: platform_message.MessageChain):
|
||||
pass
|
||||
|
||||
async def reply_message(
|
||||
@@ -455,9 +414,7 @@ class LarkAdapter(adapter.MessagePlatformAdapter):
|
||||
):
|
||||
# 不再需要了,因为message_id已经被包含到message_chain中
|
||||
# lark_event = await self.event_converter.yiri2target(message_source)
|
||||
lark_message = await self.message_converter.yiri2target(
|
||||
message, self.api_client
|
||||
)
|
||||
lark_message = await self.message_converter.yiri2target(message, self.api_client)
|
||||
|
||||
final_content = {
|
||||
'zh_cn': {
|
||||
@@ -480,9 +437,7 @@ class LarkAdapter(adapter.MessagePlatformAdapter):
|
||||
.build()
|
||||
)
|
||||
|
||||
response: ReplyMessageResponse = await self.api_client.im.v1.message.areply(
|
||||
request
|
||||
)
|
||||
response: ReplyMessageResponse = await self.api_client.im.v1.message.areply(request)
|
||||
|
||||
if not response.success():
|
||||
raise Exception(
|
||||
@@ -495,18 +450,14 @@ class LarkAdapter(adapter.MessagePlatformAdapter):
|
||||
def register_listener(
|
||||
self,
|
||||
event_type: typing.Type[platform_events.Event],
|
||||
callback: typing.Callable[
|
||||
[platform_events.Event, adapter.MessagePlatformAdapter], None
|
||||
],
|
||||
callback: typing.Callable[[platform_events.Event, adapter.MessagePlatformAdapter], None],
|
||||
):
|
||||
self.listeners[event_type] = callback
|
||||
|
||||
def unregister_listener(
|
||||
self,
|
||||
event_type: typing.Type[platform_events.Event],
|
||||
callback: typing.Callable[
|
||||
[platform_events.Event, adapter.MessagePlatformAdapter], None
|
||||
],
|
||||
callback: typing.Callable[[platform_events.Event, adapter.MessagePlatformAdapter], None],
|
||||
):
|
||||
self.listeners.pop(event_type)
|
||||
|
||||
|
||||
@@ -29,9 +29,7 @@ class NakuruProjectMessageConverter(adapter_model.MessageConverter):
|
||||
elif type(message_chain) is str:
|
||||
msg_list = [platform_message.Plain(message_chain)]
|
||||
else:
|
||||
raise Exception(
|
||||
'Unknown message type: ' + str(message_chain) + str(type(message_chain))
|
||||
)
|
||||
raise Exception('Unknown message type: ' + str(message_chain) + str(type(message_chain)))
|
||||
|
||||
nakuru_msg_list = []
|
||||
|
||||
@@ -63,9 +61,7 @@ class NakuruProjectMessageConverter(adapter_model.MessageConverter):
|
||||
# 遍历并转换
|
||||
for yiri_forward_node in yiri_forward_node_list:
|
||||
try:
|
||||
content_list = NakuruProjectMessageConverter.yiri2target(
|
||||
yiri_forward_node.message_chain
|
||||
)
|
||||
content_list = NakuruProjectMessageConverter.yiri2target(yiri_forward_node.message_chain)
|
||||
nakuru_forward_node = nkc.Node(
|
||||
name=yiri_forward_node.sender_name,
|
||||
uin=yiri_forward_node.sender_id,
|
||||
@@ -87,9 +83,7 @@ class NakuruProjectMessageConverter(adapter_model.MessageConverter):
|
||||
return nakuru_msg_list
|
||||
|
||||
@staticmethod
|
||||
def target2yiri(
|
||||
message_chain: typing.Any, message_id: int = -1
|
||||
) -> platform_message.MessageChain:
|
||||
def target2yiri(message_chain: typing.Any, message_id: int = -1) -> platform_message.MessageChain:
|
||||
"""将Yiri的消息链转换为YiriMirai的消息链"""
|
||||
assert type(message_chain) is list
|
||||
|
||||
@@ -97,9 +91,7 @@ class NakuruProjectMessageConverter(adapter_model.MessageConverter):
|
||||
import datetime
|
||||
|
||||
# 添加Source组件以标记message_id等信息
|
||||
yiri_msg_list.append(
|
||||
platform_message.Source(id=message_id, time=datetime.datetime.now())
|
||||
)
|
||||
yiri_msg_list.append(platform_message.Source(id=message_id, time=datetime.datetime.now()))
|
||||
for component in message_chain:
|
||||
if type(component) is nkc.Plain:
|
||||
yiri_msg_list.append(platform_message.Plain(text=component.text))
|
||||
@@ -130,9 +122,7 @@ class NakuruProjectEventConverter(adapter_model.EventConverter):
|
||||
|
||||
@staticmethod
|
||||
def target2yiri(event: typing.Any) -> platform_events.Event:
|
||||
yiri_chain = NakuruProjectMessageConverter.target2yiri(
|
||||
event.message, event.message_id
|
||||
)
|
||||
yiri_chain = NakuruProjectMessageConverter.target2yiri(event.message, event.message_id)
|
||||
if type(event) is nakuru.FriendMessage: # 私聊消息事件
|
||||
return platform_events.FriendMessage(
|
||||
sender=platform_entities.Friend(
|
||||
@@ -206,9 +196,7 @@ class NakuruAdapter(adapter_model.MessagePlatformAdapter):
|
||||
):
|
||||
task = None
|
||||
|
||||
converted_msg = (
|
||||
self.message_converter.yiri2target(message) if not converted else message
|
||||
)
|
||||
converted_msg = self.message_converter.yiri2target(message) if not converted else message
|
||||
|
||||
# 检查是否有转发消息
|
||||
has_forward = False
|
||||
@@ -250,13 +238,9 @@ class NakuruAdapter(adapter_model.MessagePlatformAdapter):
|
||||
),
|
||||
)
|
||||
if type(message_source) is platform_events.GroupMessage:
|
||||
await self.send_message(
|
||||
'group', message_source.sender.group.id, message, converted=True
|
||||
)
|
||||
await self.send_message('group', message_source.sender.group.id, message, converted=True)
|
||||
elif type(message_source) is platform_events.FriendMessage:
|
||||
await self.send_message(
|
||||
'person', message_source.sender.id, message, converted=True
|
||||
)
|
||||
await self.send_message('person', message_source.sender.id, message, converted=True)
|
||||
else:
|
||||
raise Exception('Unknown message source type: ' + str(type(message_source)))
|
||||
|
||||
@@ -264,17 +248,13 @@ class NakuruAdapter(adapter_model.MessagePlatformAdapter):
|
||||
import time
|
||||
|
||||
# 检查是否被禁言
|
||||
group_member_info = asyncio.run(
|
||||
self.bot.getGroupMemberInfo(group_id, self.bot_account_id)
|
||||
)
|
||||
group_member_info = asyncio.run(self.bot.getGroupMemberInfo(group_id, self.bot_account_id))
|
||||
return group_member_info.shut_up_timestamp > int(time.time())
|
||||
|
||||
def register_listener(
|
||||
self,
|
||||
event_type: typing.Type[platform_events.Event],
|
||||
callback: typing.Callable[
|
||||
[platform_events.Event, adapter_model.MessagePlatformAdapter], None
|
||||
],
|
||||
callback: typing.Callable[[platform_events.Event, adapter_model.MessagePlatformAdapter], None],
|
||||
):
|
||||
try:
|
||||
source_cls = NakuruProjectEventConverter.yiri2target(event_type)
|
||||
@@ -301,9 +281,7 @@ class NakuruAdapter(adapter_model.MessagePlatformAdapter):
|
||||
def unregister_listener(
|
||||
self,
|
||||
event_type: typing.Type[platform_events.Event],
|
||||
callback: typing.Callable[
|
||||
[platform_events.Event, adapter_model.MessagePlatformAdapter], None
|
||||
],
|
||||
callback: typing.Callable[[platform_events.Event, adapter_model.MessagePlatformAdapter], None],
|
||||
):
|
||||
nakuru_event_name = self.event_converter.yiri2target(event_type).__name__
|
||||
|
||||
@@ -312,10 +290,7 @@ class NakuruAdapter(adapter_model.MessagePlatformAdapter):
|
||||
# 从本对象的监听器列表中查找并删除
|
||||
target_wrapper = None
|
||||
for listener in self.listener_list:
|
||||
if (
|
||||
listener['event_type'] == event_type
|
||||
and listener['callable'] == callback
|
||||
):
|
||||
if listener['event_type'] == event_type and listener['callable'] == callback:
|
||||
target_wrapper = listener['wrapper']
|
||||
self.listener_list.remove(listener)
|
||||
break
|
||||
@@ -334,14 +309,8 @@ class NakuruAdapter(adapter_model.MessagePlatformAdapter):
|
||||
import requests
|
||||
|
||||
resp = requests.get(
|
||||
url='http://{}:{}/get_login_info'.format(
|
||||
self.cfg['host'], self.cfg['http_port']
|
||||
),
|
||||
headers={
|
||||
'Authorization': 'Bearer ' + self.cfg['token']
|
||||
if 'token' in self.cfg
|
||||
else ''
|
||||
},
|
||||
url='http://{}:{}/get_login_info'.format(self.cfg['host'], self.cfg['http_port']),
|
||||
headers={'Authorization': 'Bearer ' + self.cfg['token'] if 'token' in self.cfg else ''},
|
||||
timeout=5,
|
||||
proxies=None,
|
||||
)
|
||||
@@ -349,9 +318,7 @@ class NakuruAdapter(adapter_model.MessagePlatformAdapter):
|
||||
raise Exception('go-cqhttp拒绝访问,请检查配置文件中nakuru适配器的配置')
|
||||
self.bot_account_id = int(resp.json()['data']['user_id'])
|
||||
except Exception:
|
||||
raise Exception(
|
||||
'获取go-cqhttp账号信息失败, 请检查是否已启动go-cqhttp并配置正确'
|
||||
)
|
||||
raise Exception('获取go-cqhttp账号信息失败, 请检查是否已启动go-cqhttp并配置正确')
|
||||
await self.bot._run()
|
||||
self.ap.logger.info('运行 Nakuru 适配器')
|
||||
while True:
|
||||
|
||||
@@ -25,9 +25,7 @@ class OAMessageConverter(adapter.MessageConverter):
|
||||
@staticmethod
|
||||
async def target2yiri(message: str, message_id=-1):
|
||||
yiri_msg_list = []
|
||||
yiri_msg_list.append(
|
||||
platform_message.Source(id=message_id, time=datetime.datetime.now())
|
||||
)
|
||||
yiri_msg_list.append(platform_message.Source(id=message_id, time=datetime.datetime.now()))
|
||||
|
||||
yiri_msg_list.append(platform_message.Plain(text=message))
|
||||
chain = platform_message.MessageChain(yiri_msg_list)
|
||||
@@ -39,9 +37,7 @@ class OAEventConverter(adapter.EventConverter):
|
||||
@staticmethod
|
||||
async def target2yiri(event: OAEvent):
|
||||
if event.type == 'text':
|
||||
yiri_chain = await OAMessageConverter.target2yiri(
|
||||
event.message, event.message_id
|
||||
)
|
||||
yiri_chain = await OAMessageConverter.target2yiri(event.message, event.message_id)
|
||||
|
||||
friend = platform_entities.Friend(
|
||||
id=event.user_id,
|
||||
@@ -81,9 +77,7 @@ class OfficialAccountAdapter(adapter.MessagePlatformAdapter):
|
||||
]
|
||||
missing_keys = [key for key in required_keys if key not in config]
|
||||
if missing_keys:
|
||||
raise ParamNotEnoughError(
|
||||
'微信公众号缺少相关配置项,请查看文档或联系管理员'
|
||||
)
|
||||
raise ParamNotEnoughError('微信公众号缺少相关配置项,请查看文档或联系管理员')
|
||||
|
||||
if self.config['Mode'] == 'drop':
|
||||
self.bot = OAClient(
|
||||
@@ -114,28 +108,20 @@ class OfficialAccountAdapter(adapter.MessagePlatformAdapter):
|
||||
await self.bot.set_message(message_source.message_chain.message_id, content)
|
||||
elif isinstance(self.bot, OAClientForLongerResponse):
|
||||
from_user = message_source.sender.id
|
||||
await self.bot.set_message(
|
||||
from_user, message_source.message_chain.message_id, content
|
||||
)
|
||||
await self.bot.set_message(from_user, message_source.message_chain.message_id, content)
|
||||
|
||||
async def send_message(
|
||||
self, target_type: str, target_id: str, message: platform_message.MessageChain
|
||||
):
|
||||
async def send_message(self, target_type: str, target_id: str, message: platform_message.MessageChain):
|
||||
pass
|
||||
|
||||
def register_listener(
|
||||
self,
|
||||
event_type: type,
|
||||
callback: typing.Callable[
|
||||
[platform_events.Event, MessagePlatformAdapter], None
|
||||
],
|
||||
callback: typing.Callable[[platform_events.Event, MessagePlatformAdapter], None],
|
||||
):
|
||||
async def on_message(event: OAEvent):
|
||||
self.bot_account_id = event.receiver_id
|
||||
try:
|
||||
return await callback(
|
||||
await self.event_converter.target2yiri(event), self
|
||||
)
|
||||
return await callback(await self.event_converter.target2yiri(event), self)
|
||||
except Exception:
|
||||
traceback.print_exc()
|
||||
|
||||
@@ -161,8 +147,6 @@ class OfficialAccountAdapter(adapter.MessagePlatformAdapter):
|
||||
async def unregister_listener(
|
||||
self,
|
||||
event_type: type,
|
||||
callback: typing.Callable[
|
||||
[platform_events.Event, MessagePlatformAdapter], None
|
||||
],
|
||||
callback: typing.Callable[[platform_events.Event, MessagePlatformAdapter], None],
|
||||
):
|
||||
return super().unregister_listener(event_type, callback)
|
||||
|
||||
@@ -147,9 +147,7 @@ class OfficialMessageConverter(adapter_model.MessageConverter):
|
||||
elif type(message_chain) is str:
|
||||
msg_list = [platform_message.Plain(text=message_chain)]
|
||||
else:
|
||||
raise Exception(
|
||||
'Unknown message type: ' + str(message_chain) + str(type(message_chain))
|
||||
)
|
||||
raise Exception('Unknown message type: ' + str(message_chain) + str(type(message_chain)))
|
||||
|
||||
offcial_messages: list[dict] = []
|
||||
"""
|
||||
@@ -172,19 +170,13 @@ class OfficialMessageConverter(adapter_model.MessageConverter):
|
||||
if component.url is not None:
|
||||
offcial_messages.append({'type': 'image', 'content': component.url})
|
||||
elif component.path is not None:
|
||||
offcial_messages.append(
|
||||
{'type': 'file_image', 'content': component.path}
|
||||
)
|
||||
offcial_messages.append({'type': 'file_image', 'content': component.path})
|
||||
elif type(component) is platform_message.At:
|
||||
offcial_messages.append({'type': 'at', 'content': ''})
|
||||
elif type(component) is platform_message.AtAll:
|
||||
print(
|
||||
'上层组件要求发送 AtAll 消息,但 QQ 官方 API 不支持此消息类型,忽略此消息。'
|
||||
)
|
||||
print('上层组件要求发送 AtAll 消息,但 QQ 官方 API 不支持此消息类型,忽略此消息。')
|
||||
elif type(component) is platform_message.Voice:
|
||||
print(
|
||||
'上层组件要求发送 Voice 消息,但 QQ 官方 API 不支持此消息类型,忽略此消息。'
|
||||
)
|
||||
print('上层组件要求发送 Voice 消息,但 QQ 官方 API 不支持此消息类型,忽略此消息。')
|
||||
elif type(component) is forward.Forward:
|
||||
# 转发消息
|
||||
yiri_forward_node_list = component.node_list
|
||||
@@ -195,9 +187,7 @@ class OfficialMessageConverter(adapter_model.MessageConverter):
|
||||
message_chain = yiri_forward_node.message_chain
|
||||
|
||||
# 平铺
|
||||
offcial_messages.extend(
|
||||
OfficialMessageConverter.yiri2target(message_chain)
|
||||
)
|
||||
offcial_messages.extend(OfficialMessageConverter.yiri2target(message_chain))
|
||||
except Exception:
|
||||
import traceback
|
||||
|
||||
@@ -219,11 +209,7 @@ class OfficialMessageConverter(adapter_model.MessageConverter):
|
||||
yiri_msg_list = []
|
||||
# 存id
|
||||
|
||||
yiri_msg_list.append(
|
||||
platform_message.Source(
|
||||
id=save_msg_id(message_id), time=datetime.datetime.now()
|
||||
)
|
||||
)
|
||||
yiri_msg_list.append(platform_message.Source(id=save_msg_id(message_id), time=datetime.datetime.now()))
|
||||
|
||||
if type(message) not in [botpy_message.DirectMessage, botpy_message.C2CMessage]:
|
||||
yiri_msg_list.append(platform_message.At(target=bot_account_id))
|
||||
@@ -239,9 +225,7 @@ class OfficialMessageConverter(adapter_model.MessageConverter):
|
||||
if attachment.content_type.startswith('image'):
|
||||
yiri_msg_list.append(platform_message.Image(url=attachment.url))
|
||||
else:
|
||||
logging.warning(
|
||||
'不支持的附件类型:' + attachment.content_type + ',忽略此附件。'
|
||||
)
|
||||
logging.warning('不支持的附件类型:' + attachment.content_type + ',忽略此附件。')
|
||||
|
||||
content = re.sub(r'<@!\d+>', '', str(message.content))
|
||||
if content.strip() != '':
|
||||
@@ -264,9 +248,7 @@ class OfficialEventConverter(adapter_model.EventConverter):
|
||||
elif event == platform_events.FriendMessage:
|
||||
return botpy_message.DirectMessage
|
||||
else:
|
||||
raise Exception(
|
||||
'未支持转换的事件类型(YiriMirai -> Official): ' + str(event)
|
||||
)
|
||||
raise Exception('未支持转换的事件类型(YiriMirai -> Official): ' + str(event))
|
||||
|
||||
def target2yiri(
|
||||
self,
|
||||
@@ -297,21 +279,13 @@ class OfficialEventConverter(adapter_model.EventConverter):
|
||||
),
|
||||
special_title='',
|
||||
join_timestamp=int(
|
||||
datetime.datetime.strptime(
|
||||
event.member.joined_at, '%Y-%m-%dT%H:%M:%S%z'
|
||||
).timestamp()
|
||||
datetime.datetime.strptime(event.member.joined_at, '%Y-%m-%dT%H:%M:%S%z').timestamp()
|
||||
),
|
||||
last_speak_timestamp=datetime.datetime.now().timestamp(),
|
||||
mute_time_remaining=0,
|
||||
),
|
||||
message_chain=OfficialMessageConverter.extract_message_chain_from_obj(
|
||||
event, event.id
|
||||
),
|
||||
time=int(
|
||||
datetime.datetime.strptime(
|
||||
event.timestamp, '%Y-%m-%dT%H:%M:%S%z'
|
||||
).timestamp()
|
||||
),
|
||||
message_chain=OfficialMessageConverter.extract_message_chain_from_obj(event, event.id),
|
||||
time=int(datetime.datetime.strptime(event.timestamp, '%Y-%m-%dT%H:%M:%S%z').timestamp()),
|
||||
)
|
||||
elif isinstance(event, botpy_message.DirectMessage): # 频道私聊,转私聊事件
|
||||
return platform_events.FriendMessage(
|
||||
@@ -320,14 +294,8 @@ class OfficialEventConverter(adapter_model.EventConverter):
|
||||
nickname=event.author.username,
|
||||
remark=event.author.username,
|
||||
),
|
||||
message_chain=OfficialMessageConverter.extract_message_chain_from_obj(
|
||||
event, event.id
|
||||
),
|
||||
time=int(
|
||||
datetime.datetime.strptime(
|
||||
event.timestamp, '%Y-%m-%dT%H:%M:%S%z'
|
||||
).timestamp()
|
||||
),
|
||||
message_chain=OfficialMessageConverter.extract_message_chain_from_obj(event, event.id),
|
||||
time=int(datetime.datetime.strptime(event.timestamp, '%Y-%m-%dT%H:%M:%S%z').timestamp()),
|
||||
)
|
||||
elif isinstance(event, botpy_message.GroupMessage): # 群聊,转群聊事件
|
||||
author_member_id = event.author.member_openid
|
||||
@@ -347,14 +315,8 @@ class OfficialEventConverter(adapter_model.EventConverter):
|
||||
last_speak_timestamp=datetime.datetime.now().timestamp(),
|
||||
mute_time_remaining=0,
|
||||
),
|
||||
message_chain=OfficialMessageConverter.extract_message_chain_from_obj(
|
||||
event, event.id
|
||||
),
|
||||
time=int(
|
||||
datetime.datetime.strptime(
|
||||
event.timestamp, '%Y-%m-%dT%H:%M:%S%z'
|
||||
).timestamp()
|
||||
),
|
||||
message_chain=OfficialMessageConverter.extract_message_chain_from_obj(event, event.id),
|
||||
time=int(datetime.datetime.strptime(event.timestamp, '%Y-%m-%dT%H:%M:%S%z').timestamp()),
|
||||
)
|
||||
elif isinstance(event, botpy_message.C2CMessage): # 私聊,转私聊事件
|
||||
user_id_alter = event.author.user_openid
|
||||
@@ -365,14 +327,8 @@ class OfficialEventConverter(adapter_model.EventConverter):
|
||||
nickname=user_id_alter,
|
||||
remark=user_id_alter,
|
||||
),
|
||||
message_chain=OfficialMessageConverter.extract_message_chain_from_obj(
|
||||
event, event.id
|
||||
),
|
||||
time=int(
|
||||
datetime.datetime.strptime(
|
||||
event.timestamp, '%Y-%m-%dT%H:%M:%S%z'
|
||||
).timestamp()
|
||||
),
|
||||
message_chain=OfficialMessageConverter.extract_message_chain_from_obj(event, event.id),
|
||||
time=int(datetime.datetime.strptime(event.timestamp, '%Y-%m-%dT%H:%M:%S%z').timestamp()),
|
||||
)
|
||||
|
||||
|
||||
@@ -420,9 +376,7 @@ class OfficialAdapter(adapter_model.MessagePlatformAdapter):
|
||||
|
||||
self.bot = botpy.Client(intents=intents)
|
||||
|
||||
async def send_message(
|
||||
self, target_type: str, target_id: str, message: platform_message.MessageChain
|
||||
):
|
||||
async def send_message(self, target_type: str, target_id: str, message: platform_message.MessageChain):
|
||||
message_list = self.message_converter.yiri2target(message)
|
||||
|
||||
for msg in message_list:
|
||||
@@ -468,22 +422,16 @@ class OfficialAdapter(adapter_model.MessagePlatformAdapter):
|
||||
|
||||
if quote_origin:
|
||||
args['message_reference'] = botpy_message_type.Reference(
|
||||
message_id=cached_message_ids[
|
||||
str(message_source.message_chain.message_id)
|
||||
]
|
||||
message_id=cached_message_ids[str(message_source.message_chain.message_id)]
|
||||
)
|
||||
|
||||
if isinstance(message_source, platform_events.GroupMessage):
|
||||
args['channel_id'] = str(message_source.sender.group.id)
|
||||
args['msg_id'] = cached_message_ids[
|
||||
str(message_source.message_chain.message_id)
|
||||
]
|
||||
args['msg_id'] = cached_message_ids[str(message_source.message_chain.message_id)]
|
||||
await self.bot.api.post_message(**args)
|
||||
elif isinstance(message_source, platform_events.FriendMessage):
|
||||
args['guild_id'] = str(message_source.sender.id)
|
||||
args['msg_id'] = cached_message_ids[
|
||||
str(message_source.message_chain.message_id)
|
||||
]
|
||||
args['msg_id'] = cached_message_ids[str(message_source.message_chain.message_id)]
|
||||
await self.bot.api.post_dms(**args)
|
||||
elif isinstance(message_source, OfficialGroupMessage):
|
||||
if 'file_image' in args: # 暂不支持发送文件图片
|
||||
@@ -502,9 +450,7 @@ class OfficialAdapter(adapter_model.MessagePlatformAdapter):
|
||||
args['media'] = uploadMedia
|
||||
args['msg_type'] = 7
|
||||
|
||||
args['msg_id'] = cached_message_ids[
|
||||
str(message_source.message_chain.message_id)
|
||||
]
|
||||
args['msg_id'] = cached_message_ids[str(message_source.message_chain.message_id)]
|
||||
args['msg_seq'] = self.group_msg_seq
|
||||
self.group_msg_seq += 1
|
||||
|
||||
@@ -523,9 +469,7 @@ class OfficialAdapter(adapter_model.MessagePlatformAdapter):
|
||||
args['media'] = uploadMedia
|
||||
args['msg_type'] = 7
|
||||
|
||||
args['msg_id'] = cached_message_ids[
|
||||
str(message_source.message_chain.message_id)
|
||||
]
|
||||
args['msg_id'] = cached_message_ids[str(message_source.message_chain.message_id)]
|
||||
|
||||
args['msg_seq'] = self.c2c_msg_seq
|
||||
self.c2c_msg_seq += 1
|
||||
@@ -538,9 +482,7 @@ class OfficialAdapter(adapter_model.MessagePlatformAdapter):
|
||||
def register_listener(
|
||||
self,
|
||||
event_type: typing.Type[platform_events.Event],
|
||||
callback: typing.Callable[
|
||||
[platform_events.Event, adapter_model.MessagePlatformAdapter], None
|
||||
],
|
||||
callback: typing.Callable[[platform_events.Event, adapter_model.MessagePlatformAdapter], None],
|
||||
):
|
||||
try:
|
||||
|
||||
@@ -563,9 +505,7 @@ class OfficialAdapter(adapter_model.MessagePlatformAdapter):
|
||||
def unregister_listener(
|
||||
self,
|
||||
event_type: typing.Type[platform_events.Event],
|
||||
callback: typing.Callable[
|
||||
[platform_events.Event, adapter_model.MessagePlatformAdapter], None
|
||||
],
|
||||
callback: typing.Callable[[platform_events.Event, adapter_model.MessagePlatformAdapter], None],
|
||||
):
|
||||
delattr(self.bot, event_handler_mapping[event_type])
|
||||
|
||||
|
||||
@@ -35,13 +35,9 @@ class QQOfficialMessageConverter(adapter.MessageConverter):
|
||||
@staticmethod
|
||||
async def target2yiri(message: str, message_id: str, pic_url: str, content_type):
|
||||
yiri_msg_list = []
|
||||
yiri_msg_list.append(
|
||||
platform_message.Source(id=message_id, time=datetime.datetime.now())
|
||||
)
|
||||
yiri_msg_list.append(platform_message.Source(id=message_id, time=datetime.datetime.now()))
|
||||
if pic_url is not None:
|
||||
base64_url = await image.get_qq_official_image_base64(
|
||||
pic_url=pic_url, content_type=content_type
|
||||
)
|
||||
base64_url = await image.get_qq_official_image_base64(pic_url=pic_url, content_type=content_type)
|
||||
yiri_msg_list.append(platform_message.Image(base64=base64_url))
|
||||
|
||||
yiri_msg_list.append(platform_message.Plain(text=message))
|
||||
@@ -75,11 +71,7 @@ class QQOfficialEventConverter(adapter.EventConverter):
|
||||
return platform_events.FriendMessage(
|
||||
sender=friend,
|
||||
message_chain=yiri_chain,
|
||||
time=int(
|
||||
datetime.datetime.strptime(
|
||||
event.timestamp, '%Y-%m-%dT%H:%M:%S%z'
|
||||
).timestamp()
|
||||
),
|
||||
time=int(datetime.datetime.strptime(event.timestamp, '%Y-%m-%dT%H:%M:%S%z').timestamp()),
|
||||
source_platform_object=event,
|
||||
)
|
||||
|
||||
@@ -89,9 +81,7 @@ class QQOfficialEventConverter(adapter.EventConverter):
|
||||
nickname=event.t,
|
||||
remark='',
|
||||
)
|
||||
return platform_events.FriendMessage(
|
||||
sender=friend, message_chain=yiri_chain, source_platform_object=event
|
||||
)
|
||||
return platform_events.FriendMessage(sender=friend, message_chain=yiri_chain, source_platform_object=event)
|
||||
if event.t == 'GROUP_AT_MESSAGE_CREATE':
|
||||
yiri_chain.insert(0, platform_message.At(target='justbot'))
|
||||
|
||||
@@ -109,11 +99,7 @@ class QQOfficialEventConverter(adapter.EventConverter):
|
||||
last_speak_timestamp=0,
|
||||
mute_time_remaining=0,
|
||||
)
|
||||
time = int(
|
||||
datetime.datetime.strptime(
|
||||
event.timestamp, '%Y-%m-%dT%H:%M:%S%z'
|
||||
).timestamp()
|
||||
)
|
||||
time = int(datetime.datetime.strptime(event.timestamp, '%Y-%m-%dT%H:%M:%S%z').timestamp())
|
||||
return platform_events.GroupMessage(
|
||||
sender=sender,
|
||||
message_chain=yiri_chain,
|
||||
@@ -136,11 +122,7 @@ class QQOfficialEventConverter(adapter.EventConverter):
|
||||
last_speak_timestamp=0,
|
||||
mute_time_remaining=0,
|
||||
)
|
||||
time = int(
|
||||
datetime.datetime.strptime(
|
||||
event.timestamp, '%Y-%m-%dT%H:%M:%S%z'
|
||||
).timestamp()
|
||||
)
|
||||
time = int(datetime.datetime.strptime(event.timestamp, '%Y-%m-%dT%H:%M:%S%z').timestamp())
|
||||
return platform_events.GroupMessage(
|
||||
sender=sender,
|
||||
message_chain=yiri_chain,
|
||||
@@ -167,9 +149,7 @@ class QQOfficialAdapter(adapter.MessagePlatformAdapter):
|
||||
]
|
||||
missing_keys = [key for key in required_keys if key not in config]
|
||||
if missing_keys:
|
||||
raise ParamNotEnoughError(
|
||||
'QQ官方机器人缺少相关配置项,请查看文档或联系管理员'
|
||||
)
|
||||
raise ParamNotEnoughError('QQ官方机器人缺少相关配置项,请查看文档或联系管理员')
|
||||
|
||||
self.bot = QQOfficialClient(
|
||||
app_id=config['appid'],
|
||||
@@ -229,24 +209,18 @@ class QQOfficialAdapter(adapter.MessagePlatformAdapter):
|
||||
qq_official_event.d_id,
|
||||
)
|
||||
|
||||
async def send_message(
|
||||
self, target_type: str, target_id: str, message: platform_message.MessageChain
|
||||
):
|
||||
async def send_message(self, target_type: str, target_id: str, message: platform_message.MessageChain):
|
||||
pass
|
||||
|
||||
def register_listener(
|
||||
self,
|
||||
event_type: typing.Type[platform_events.Event],
|
||||
callback: typing.Callable[
|
||||
[platform_events.Event, adapter.MessagePlatformAdapter], None
|
||||
],
|
||||
callback: typing.Callable[[platform_events.Event, adapter.MessagePlatformAdapter], None],
|
||||
):
|
||||
async def on_message(event: QQOfficialEvent):
|
||||
self.bot_account_id = 'justbot'
|
||||
try:
|
||||
return await callback(
|
||||
await self.event_converter.target2yiri(event), self
|
||||
)
|
||||
return await callback(await self.event_converter.target2yiri(event), self)
|
||||
except Exception:
|
||||
traceback.print_exc()
|
||||
|
||||
@@ -274,8 +248,6 @@ class QQOfficialAdapter(adapter.MessagePlatformAdapter):
|
||||
def unregister_listener(
|
||||
self,
|
||||
event_type: type,
|
||||
callback: typing.Callable[
|
||||
[platform_events.Event, MessagePlatformAdapter], None
|
||||
],
|
||||
callback: typing.Callable[[platform_events.Event, MessagePlatformAdapter], None],
|
||||
):
|
||||
return super().unregister_listener(event_type, callback)
|
||||
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user