style: restrict line-length

This commit is contained in:
Junyan Qin
2025-05-10 18:04:58 +08:00
parent b30016ed08
commit 055b389353
134 changed files with 1096 additions and 2595 deletions

View File

@@ -1,3 +1,6 @@
line-length = 120
[lint]
ignore = [

View File

@@ -25,9 +25,7 @@ class DingTalkClient:
self.secret = client_secret
# 在 DingTalkClient 中传入自己作为参数,避免循环导入
self.EchoTextHandler = EchoTextHandler(self)
self.client.register_callback_handler(
dingtalk_stream.chatbot.ChatbotMessage.TOPIC, self.EchoTextHandler
)
self.client.register_callback_handler(dingtalk_stream.chatbot.ChatbotMessage.TOPIC, self.EchoTextHandler)
self._message_handlers = {
'example': [],
}
@@ -86,9 +84,7 @@ class DingTalkClient:
if response.status_code == 200:
file_bytes = response.content
base64_str = base64.b64encode(file_bytes).decode(
'utf-8'
) # 返回字符串格式
base64_str = base64.b64encode(file_bytes).decode('utf-8') # 返回字符串格式
return base64_str
else:
raise Exception('获取文件失败')
@@ -151,9 +147,7 @@ class DingTalkClient:
for handler in self._message_handlers[msg_type]:
await handler(event)
async def get_message(
self, incoming_message: dingtalk_stream.chatbot.ChatbotMessage
):
async def get_message(self, incoming_message: dingtalk_stream.chatbot.ChatbotMessage):
try:
# print(json.dumps(incoming_message.to_dict(), indent=4, ensure_ascii=False))
message_data = {
@@ -170,9 +164,7 @@ class DingTalkClient:
if 'text' in item:
message_data['Content'] = item['text']
if incoming_message.get_image_list()[0]:
message_data['Picture'] = await self.download_image(
incoming_message.get_image_list()[0]
)
message_data['Picture'] = await self.download_image(incoming_message.get_image_list()[0])
message_data['Type'] = 'text'
elif incoming_message.message_type == 'text':
@@ -180,15 +172,11 @@ class DingTalkClient:
message_data['Type'] = 'text'
elif incoming_message.message_type == 'picture':
message_data['Picture'] = await self.download_image(
incoming_message.get_image_list()[0]
)
message_data['Picture'] = await self.download_image(incoming_message.get_image_list()[0])
message_data['Type'] = 'image'
elif incoming_message.message_type == 'audio':
message_data['Audio'] = await self.get_audio_url(
incoming_message.to_dict()['content']['downloadCode']
)
message_data['Audio'] = await self.get_audio_url(incoming_message.to_dict()['content']['downloadCode'])
message_data['Type'] = 'audio'

View File

@@ -68,9 +68,7 @@ class OAClient:
elif request.method == 'POST':
encryt_msg = await request.data
wxcpt = WXBizMsgCrypt(self.token, self.aes, self.appid)
ret, xml_msg = wxcpt.DecryptMsg(
encryt_msg, msg_signature, timestamp, nonce
)
ret, xml_msg = wxcpt.DecryptMsg(encryt_msg, msg_signature, timestamp, nonce)
xml_msg = xml_msg.decode('utf-8')
if ret != 0:
@@ -112,9 +110,7 @@ class OAClient:
# create_time=int(time.time()),
# content = "请求失效暂不支持公众号超过15秒的请求如有需求请联系 LangBot 团队。"
# )
print(
'请求失效暂不支持公众号超过15秒的请求如有需求请联系 LangBot 团队。'
)
print('请求失效暂不支持公众号超过15秒的请求如有需求请联系 LangBot 团队。')
return ''
except Exception:
@@ -128,12 +124,8 @@ class OAClient:
'FromUserName': root.find('FromUserName').text,
'CreateTime': int(root.find('CreateTime').text),
'MsgType': root.find('MsgType').text,
'Content': root.find('Content').text
if root.find('Content') is not None
else None,
'MsgId': int(root.find('MsgId').text)
if root.find('MsgId') is not None
else None,
'Content': root.find('Content').text if root.find('Content') is not None else None,
'MsgId': int(root.find('MsgId').text) if root.find('MsgId') is not None else None,
}
return message_data
@@ -225,9 +217,7 @@ class OAClientForLongerResponse:
elif request.method == 'POST':
encryt_msg = await request.data
wxcpt = WXBizMsgCrypt(self.token, self.aes, self.appid)
ret, xml_msg = wxcpt.DecryptMsg(
encryt_msg, msg_signature, timestamp, nonce
)
ret, xml_msg = wxcpt.DecryptMsg(encryt_msg, msg_signature, timestamp, nonce)
xml_msg = xml_msg.decode('utf-8')
if ret != 0:
@@ -238,18 +228,12 @@ class OAClientForLongerResponse:
from_user = root.find('FromUserName').text
to_user = root.find('ToUserName').text
if (
self.msg_queue.get(from_user)
and self.msg_queue[from_user][0]['content']
):
if self.msg_queue.get(from_user) and self.msg_queue[from_user][0]['content']:
queue_top = self.msg_queue[from_user].pop(0)
queue_content = queue_top['content']
# 弹出用户消息
if (
self.user_msg_queue.get(from_user)
and self.user_msg_queue[from_user]
):
if self.user_msg_queue.get(from_user) and self.user_msg_queue[from_user]:
self.user_msg_queue[from_user].pop(0)
response_xml = xml_template.format(
@@ -268,10 +252,7 @@ class OAClientForLongerResponse:
content=self.loading_message,
)
if (
self.user_msg_queue.get(from_user)
and self.user_msg_queue[from_user][0]['content']
):
if self.user_msg_queue.get(from_user) and self.user_msg_queue[from_user][0]['content']:
return response_xml
else:
message_data = await self.get_message(xml_msg)
@@ -299,12 +280,8 @@ class OAClientForLongerResponse:
'FromUserName': root.find('FromUserName').text,
'CreateTime': int(root.find('CreateTime').text),
'MsgType': root.find('MsgType').text,
'Content': root.find('Content').text
if root.find('Content') is not None
else None,
'MsgId': int(root.find('MsgId').text)
if root.find('MsgId') is not None
else None,
'Content': root.find('Content').text if root.find('Content') is not None else None,
'MsgId': int(root.find('MsgId').text) if root.find('MsgId') is not None else None,
}
return message_data

View File

@@ -144,15 +144,9 @@ class QQOfficialClient:
'group_openid': msg.get('d', {}).get('group_openid', {}),
}
attachments = msg.get('d', {}).get('attachments', [])
image_attachments = [
attachment['url']
for attachment in attachments
if await self.is_image(attachment)
]
image_attachments = [attachment['url'] for attachment in attachments if await self.is_image(attachment)]
image_attachments_type = [
attachment['content_type']
for attachment in attachments
if await self.is_image(attachment)
attachment['content_type'] for attachment in attachments if await self.is_image(attachment)
]
if image_attachments:
message_data['image_attachments'] = image_attachments[0]
@@ -211,9 +205,7 @@ class QQOfficialClient:
else:
raise Exception(response.read().decode())
async def send_channle_group_text_msg(
self, channel_id: str, content: str, msg_id: str
):
async def send_channle_group_text_msg(self, channel_id: str, content: str, msg_id: str):
"""发送频道群聊消息"""
if not await self.check_access_token():
await self.get_access_token()
@@ -235,9 +227,7 @@ class QQOfficialClient:
else:
raise Exception(response)
async def send_channle_private_text_msg(
self, guild_id: str, content: str, msg_id: str
):
async def send_channle_private_text_msg(self, guild_id: str, content: str, msg_id: str):
"""发送频道私聊消息"""
if not await self.check_access_token():
await self.get_access_token()

View File

@@ -2,20 +2,21 @@ import json
from quart import Quart, jsonify, request
from slack_sdk.web.async_client import AsyncWebClient
from .slackevent import SlackEvent
from typing import Callable, Dict, Any
from pkg.platform.types import events as platform_events, message as platform_message
from typing import Callable
from pkg.platform.types import events as platform_events
class SlackClient():
class SlackClient:
def __init__(self, bot_token: str, signing_secret: str):
self.bot_token = bot_token
self.signing_secret = signing_secret
self.app = Quart(__name__)
self.client = AsyncWebClient(self.bot_token)
self.app.add_url_rule('/callback/command', 'handle_callback', self.handle_callback_request, methods=['GET', 'POST'])
self.app.add_url_rule(
'/callback/command', 'handle_callback', self.handle_callback_request, methods=['GET', 'POST']
)
self._message_handlers = {
"example":[],
'example': [],
}
self.bot_user_id = None # 避免机器人回复自己的消息
@@ -27,21 +28,20 @@ class SlackClient():
if data['type'] == 'url_verification':
return data['challenge']
bot_user_id = data.get("event",{}).get("bot_id","")
bot_user_id = data.get('event', {}).get('bot_id', '')
if self.bot_user_id and bot_user_id == self.bot_user_id:
return jsonify({'status': 'ok'})
# 处理私信
if data and data.get("event", {}).get("channel_type") in ["im"]:
if data and data.get('event', {}).get('channel_type') in ['im']:
event = SlackEvent.from_payload(data)
await self._handle_message(event)
return jsonify({'status': 'ok'})
# 处理群聊
if data.get("event",{}).get("type") == 'app_mention':
data.setdefault("event", {})["channel_type"] = "channel"
if data.get('event', {}).get('type') == 'app_mention':
data.setdefault('event', {})['channel_type'] = 'channel'
event = SlackEvent.from_payload(data)
await self._handle_message(event)
return jsonify({'status': 'ok'})
@@ -51,8 +51,6 @@ class SlackClient():
except Exception as e:
raise (e)
async def _handle_message(self, event: SlackEvent):
"""
处理消息事件。
@@ -64,33 +62,29 @@ class SlackClient():
def on_message(self, msg_type: str):
"""注册消息类型处理器"""
def decorator(func: Callable[[platform_events.Event], None]):
if msg_type not in self._message_handlers:
self._message_handlers[msg_type] = []
self._message_handlers[msg_type].append(func)
return func
return decorator
async def send_message_to_channel(self, text: str, channel_id: str):
try:
response = await self.client.chat_postMessage(
channel=channel_id,
text=text
)
if self.bot_user_id is None and response.get("ok"):
self.bot_user_id = response["message"]["bot_id"]
response = await self.client.chat_postMessage(channel=channel_id, text=text)
if self.bot_user_id is None and response.get('ok'):
self.bot_user_id = response['message']['bot_id']
return
except Exception as e:
raise e
async def send_message_to_one(self, text: str, user_id: str):
try:
response = await self.client.chat_postMessage(
channel = '@'+user_id,
text= text
)
if self.bot_user_id is None and response.get("ok"):
self.bot_user_id = response["message"]["bot_id"]
response = await self.client.chat_postMessage(channel='@' + user_id, text=text)
if self.bot_user_id is None and response.get('ok'):
self.bot_user_id = response['message']['bot_id']
return
except Exception as e:
@@ -101,11 +95,3 @@ class SlackClient():
启动 Quart 应用。
"""
await self.app.run_task(host=host, port=port, *args, **kwargs)

View File

@@ -1,8 +1,9 @@
from typing import Dict, Any, Optional
class SlackEvent(dict):
@staticmethod
def from_payload(payload: Dict[str, Any]) -> Optional["SlackEvent"]:
def from_payload(payload: Dict[str, Any]) -> Optional['SlackEvent']:
try:
event = SlackEvent(payload)
return event
@@ -11,75 +12,70 @@ class SlackEvent(dict):
@property
def text(self) -> str:
if self.get("event", {}).get("channel_type") == "im":
blocks = self.get("event", {}).get("blocks", [])
if self.get('event', {}).get('channel_type') == 'im':
blocks = self.get('event', {}).get('blocks', [])
if not blocks:
return ""
return ''
elements = blocks[0].get("elements", [])
elements = blocks[0].get('elements', [])
if not elements:
return ""
return ''
elements = elements[0].get("elements", [])
text = ""
elements = elements[0].get('elements', [])
text = ''
for el in elements:
if el.get("type") == "text":
text += el.get("text", "")
elif el.get("type") == "link":
text += el.get("url", "")
if el.get('type') == 'text':
text += el.get('text', '')
elif el.get('type') == 'link':
text += el.get('url', '')
return text
if self.get("event",{}).get("channel_type") == 'channel':
message_text = ""
for block in self.get("event", {}).get("blocks", []):
if block.get("type") == "rich_text":
for element in block.get("elements", []):
if element.get("type") == "rich_text_section":
if self.get('event', {}).get('channel_type') == 'channel':
message_text = ''
for block in self.get('event', {}).get('blocks', []):
if block.get('type') == 'rich_text':
for element in block.get('elements', []):
if element.get('type') == 'rich_text_section':
parts = []
for el in element.get("elements", []):
if el.get("type") == "text":
parts.append(el["text"])
elif el.get("type") == "link":
parts.append(el["url"])
message_text = "".join(parts)
for el in element.get('elements', []):
if el.get('type') == 'text':
parts.append(el['text'])
elif el.get('type') == 'link':
parts.append(el['url'])
message_text = ''.join(parts)
return message_text
@property
def user_id(self) -> Optional[str]:
return self.get("event", {}).get("user","")
return self.get('event', {}).get('user', '')
@property
def channel_id(self) -> Optional[str]:
return self.get("event", {}).get("channel","")
return self.get('event', {}).get('channel', '')
@property
def type(self) -> str:
"""message对应私聊app_mention对应频道at"""
return self.get("event", {}).get("channel_type", "")
return self.get('event', {}).get('channel_type', '')
@property
def message_id(self) -> str:
return self.get("event_id","")
return self.get('event_id', '')
@property
def pic_url(self) -> str:
"""提取 Slack 事件中的图片 URL"""
files = self.get("event", {}).get("files", [])
files = self.get('event', {}).get('files', [])
if files:
return files[0].get("url_private", "")
return files[0].get('url_private', '')
return None
@property
def sender_name(self) -> str:
return self.get("event", {}).get("user","")
return self.get('event', {}).get('user', '')
def __getattr__(self, key: str) -> Optional[Any]:
return self.get(key)
@@ -88,4 +84,4 @@ class SlackEvent(dict):
self[key] = value
def __repr__(self) -> str:
return f"<SlackEvent {super().__repr__()}>"
return f'<SlackEvent {super().__repr__()}>'

View File

@@ -147,12 +147,7 @@ class Prpcrypt(object):
"""
# 16位随机字符串添加到明文开头
text = text.encode()
text = (
self.get_random_str()
+ struct.pack('I', socket.htonl(len(text)))
+ text
+ receiveid.encode()
)
text = self.get_random_str() + struct.pack('I', socket.htonl(len(text))) + text + receiveid.encode()
# 使用自定义的填充方式对明文进行补位填充
pkcs7 = PKCS7Encoder()

View File

@@ -45,9 +45,7 @@ class WecomClient:
return bool(self.access_token and self.access_token.strip())
async def check_access_token_for_contacts(self):
return bool(
self.access_token_for_contacts and self.access_token_for_contacts.strip()
)
return bool(self.access_token_for_contacts and self.access_token_for_contacts.strip())
async def get_access_token(self, secret):
url = f'https://qyapi.weixin.qq.com/cgi-bin/gettoken?corpid={self.corpid}&corpsecret={secret}'
@@ -61,15 +59,9 @@ class WecomClient:
async def get_users(self):
if not self.check_access_token_for_contacts():
self.access_token_for_contacts = await self.get_access_token(
self.secret_for_contacts
)
self.access_token_for_contacts = await self.get_access_token(self.secret_for_contacts)
url = (
self.base_url
+ '/user/list_id?access_token='
+ self.access_token_for_contacts
)
url = self.base_url + '/user/list_id?access_token=' + self.access_token_for_contacts
async with httpx.AsyncClient() as client:
params = {
'cursor': '',
@@ -88,15 +80,9 @@ class WecomClient:
async def send_to_all(self, content: str, agent_id: int):
if not self.check_access_token_for_contacts():
self.access_token_for_contacts = await self.get_access_token(
self.secret_for_contacts
)
self.access_token_for_contacts = await self.get_access_token(self.secret_for_contacts)
url = (
self.base_url
+ '/message/send?access_token='
+ self.access_token_for_contacts
)
url = self.base_url + '/message/send?access_token=' + self.access_token_for_contacts
user_ids = await self.get_users()
user_ids_string = '|'.join(user_ids)
async with httpx.AsyncClient() as client:
@@ -187,27 +173,21 @@ class WecomClient:
if request.method == 'GET':
echostr = request.args.get('echostr')
ret, reply_echo_str = self.wxcpt.VerifyURL(
msg_signature, timestamp, nonce, echostr
)
ret, reply_echo_str = self.wxcpt.VerifyURL(msg_signature, timestamp, nonce, echostr)
if ret != 0:
raise Exception(f'验证失败,错误码: {ret}')
return reply_echo_str
elif request.method == 'POST':
encrypt_msg = await request.data
ret, xml_msg = self.wxcpt.DecryptMsg(
encrypt_msg, msg_signature, timestamp, nonce
)
ret, xml_msg = self.wxcpt.DecryptMsg(encrypt_msg, msg_signature, timestamp, nonce)
if ret != 0:
raise Exception(f'消息解密失败,错误码: {ret}')
# 解析消息并处理
message_data = await self.get_message(xml_msg)
if message_data:
event = WecomEvent.from_payload(
message_data
) # 转换为 WecomEvent 对象
event = WecomEvent.from_payload(message_data) # 转换为 WecomEvent 对象
if event:
await self._handle_message(event)
@@ -253,23 +233,13 @@ class WecomClient:
'FromUserName': root.find('FromUserName').text,
'CreateTime': int(root.find('CreateTime').text),
'MsgType': root.find('MsgType').text,
'Content': root.find('Content').text
if root.find('Content') is not None
else None,
'MsgId': int(root.find('MsgId').text)
if root.find('MsgId') is not None
else None,
'AgentID': int(root.find('AgentID').text)
if root.find('AgentID') is not None
else None,
'Content': root.find('Content').text if root.find('Content') is not None else None,
'MsgId': int(root.find('MsgId').text) if root.find('MsgId') is not None else None,
'AgentID': int(root.find('AgentID').text) if root.find('AgentID') is not None else None,
}
if message_data['MsgType'] == 'image':
message_data['MediaId'] = (
root.find('MediaId').text if root.find('MediaId') is not None else None
)
message_data['PicUrl'] = (
root.find('PicUrl').text if root.find('PicUrl') is not None else None
)
message_data['MediaId'] = root.find('MediaId').text if root.find('MediaId') is not None else None
message_data['PicUrl'] = root.find('PicUrl').text if root.find('PicUrl') is not None else None
return message_data
@@ -298,12 +268,7 @@ class WecomClient:
if not await self.check_access_token():
self.access_token = await self.get_access_token(self.secret)
url = (
self.base_url
+ '/media/upload?access_token='
+ self.access_token
+ '&type=file'
)
url = self.base_url + '/media/upload?access_token=' + self.access_token + '&type=file'
file_bytes = None
file_name = 'uploaded_file.txt'

View File

@@ -6,13 +6,13 @@ import httpx
import traceback
from quart import Quart
import xml.etree.ElementTree as ET
from typing import Callable, Dict, Any
from typing import Callable
from .wecomcsevent import WecomCSEvent
from pkg.platform.types import events as platform_events, message as platform_message
from pkg.platform.types import message as platform_message
import aiofiles
class WecomCSClient():
class WecomCSClient:
def __init__(self, corpid: str, secret: str, token: str, EncodingAESKey: str):
self.corpid = corpid
self.secret = secret
@@ -23,35 +23,36 @@ class WecomCSClient():
self.access_token = ''
self.app = Quart(__name__)
self.wxcpt = WXBizMsgCrypt(self.token, self.aes, self.corpid)
self.app.add_url_rule('/callback/command', 'handle_callback', self.handle_callback_request, methods=['GET', 'POST'])
self.app.add_url_rule(
'/callback/command', 'handle_callback', self.handle_callback_request, methods=['GET', 'POST']
)
self._message_handlers = {
"example":[],
'example': [],
}
async def get_pic_url(self, media_id: str):
if not await self.check_access_token():
self.access_token = await self.get_access_token(self.secret)
url = f"{self.base_url}/media/get?access_token={self.access_token}&media_id={media_id}"
url = f'{self.base_url}/media/get?access_token={self.access_token}&media_id={media_id}'
async with httpx.AsyncClient() as client:
response = await client.get(url)
if response.headers.get("Content-Type", "").startswith("application/json"):
if response.headers.get('Content-Type', '').startswith('application/json'):
data = response.json()
if data.get('errcode') in [40014, 42001]:
self.access_token = await self.get_access_token(self.secret)
return await self.get_pic_url(media_id)
else:
raise Exception("Failed to get image: " + str(data))
raise Exception('Failed to get image: ' + str(data))
# 否则是图片,转成 base64
image_bytes = response.content
content_type = response.headers.get("Content-Type", "")
base64_str = base64.b64encode(image_bytes).decode("utf-8")
base64_str = f"data:{content_type};base64,{base64_str}"
content_type = response.headers.get('Content-Type', '')
base64_str = base64.b64encode(image_bytes).decode('utf-8')
base64_str = f'data:{content_type};base64,{base64_str}'
return base64_str
# access——token操作
async def check_access_token(self):
return bool(self.access_token and self.access_token.strip())
@@ -67,15 +68,15 @@ class WecomCSClient():
if 'access_token' in data:
return data['access_token']
else:
raise Exception(f"未获取access token: {data}")
raise Exception(f'未获取access token: {data}')
async def get_detailed_message_list(self, xml_msg: str):
# 在本方法中解析消息,并且获得消息的具体内容
if isinstance(xml_msg, bytes):
xml_msg = xml_msg.decode('utf-8')
root = ET.fromstring(xml_msg)
token = root.find("Token").text
open_kfid = root.find("OpenKfId").text
token = root.find('Token').text
open_kfid = root.find('OpenKfId').text
# if open_kfid in self.openkfid_list:
# return None
@@ -88,9 +89,9 @@ class WecomCSClient():
url = self.base_url + '/kf/sync_msg?access_token=' + self.access_token
async with httpx.AsyncClient() as client:
params = {
"token": token,
"voice_format": 0,
"open_kfid": open_kfid,
'token': token,
'voice_format': 0,
'open_kfid': open_kfid,
}
response = await client.post(url, json=params)
data = response.json()
@@ -98,29 +99,28 @@ class WecomCSClient():
self.access_token = await self.get_access_token(self.secret)
return await self.get_detailed_message_list(xml_msg)
if data['errcode'] != 0:
raise Exception("Failed to get message")
raise Exception('Failed to get message')
last_msg_data = data['msg_list'][-1]
open_kfid = last_msg_data.get("open_kfid")
open_kfid = last_msg_data.get('open_kfid')
# 进行获取图片操作
if last_msg_data.get("msgtype") == "image":
media_id = last_msg_data.get("image").get("media_id")
if last_msg_data.get('msgtype') == 'image':
media_id = last_msg_data.get('image').get('media_id')
picurl = await self.get_pic_url(media_id)
last_msg_data["picurl"] = picurl
last_msg_data['picurl'] = picurl
# await self.change_service_status(userid=external_userid,openkfid=open_kfid,servicer=servicer)
return last_msg_data
async def change_service_status(self, userid: str, openkfid: str, servicer: str):
if not await self.check_access_token():
self.access_token = await self.get_access_token(self.secret)
url = self.base_url+"/kf/service_state/get?access_token="+self.access_token
url = self.base_url + '/kf/service_state/get?access_token=' + self.access_token
async with httpx.AsyncClient() as client:
params = {
"open_kfid" : openkfid,
"external_userid" : userid,
"service_state" : 1,
"servicer_userid" : servicer,
'open_kfid': openkfid,
'external_userid': userid,
'service_state': 1,
'servicer_userid': servicer,
}
response = await client.post(url, json=params)
data = response.json()
@@ -128,8 +128,7 @@ class WecomCSClient():
self.access_token = await self.get_access_token(self.secret)
return await self.change_service_status(userid, openkfid)
if data['errcode'] != 0:
raise Exception("Failed to change service status: "+str(data))
raise Exception('Failed to change service status: ' + str(data))
async def send_image(self, user_id: str, agent_id: int, media_id: str):
if not await self.check_access_token():
@@ -137,24 +136,24 @@ class WecomCSClient():
url = self.base_url + '/media/upload?access_token=' + self.access_token
async with httpx.AsyncClient() as client:
params = {
"touser" : user_id,
"toparty" : "",
"totag":"",
"agentid" : agent_id,
"msgtype" : "image",
"image" : {
"media_id" : media_id,
'touser': user_id,
'toparty': '',
'totag': '',
'agentid': agent_id,
'msgtype': 'image',
'image': {
'media_id': media_id,
},
"safe":0,
"enable_id_trans": 0,
"enable_duplicate_check": 0,
"duplicate_check_interval": 1800
'safe': 0,
'enable_id_trans': 0,
'enable_duplicate_check': 0,
'duplicate_check_interval': 1800,
}
try:
response = await client.post(url, json=params)
data = response.json()
except Exception as e:
raise Exception("Failed to send image: "+str(e))
raise Exception('Failed to send image: ' + str(e))
# 企业微信错误码40014和42001代表accesstoken问题
if data['errcode'] == 40014 or data['errcode'] == 42001:
@@ -162,23 +161,22 @@ class WecomCSClient():
return await self.send_image(user_id, agent_id, media_id)
if data['errcode'] != 0:
raise Exception("Failed to send image: "+str(data))
raise Exception('Failed to send image: ' + str(data))
async def send_text_msg(self, open_kfid: str, external_userid: str, msgid: str, content: str):
if not await self.check_access_token():
self.access_token = await self.get_access_token(self.secret)
url = f"https://qyapi.weixin.qq.com/cgi-bin/kf/send_msg?access_token={self.access_token}"
url = f'https://qyapi.weixin.qq.com/cgi-bin/kf/send_msg?access_token={self.access_token}'
payload = {
"touser": external_userid,
"open_kfid": open_kfid,
"msgid": msgid,
"msgtype": "text",
"text": {
"content": content,
}
'touser': external_userid,
'open_kfid': open_kfid,
'msgid': msgid,
'msgtype': 'text',
'text': {
'content': content,
},
}
async with httpx.AsyncClient() as client:
@@ -189,32 +187,30 @@ class WecomCSClient():
self.access_token = await self.get_access_token(self.secret)
return await self.send_text_msg(open_kfid, external_userid, msgid, content)
if data['errcode'] != 0:
raise Exception("Failed to send message")
raise Exception('Failed to send message')
return data
async def handle_callback_request(self):
"""
处理回调请求,包括 GET 验证和 POST 消息接收。
"""
try:
msg_signature = request.args.get('msg_signature')
timestamp = request.args.get('timestamp')
nonce = request.args.get('nonce')
msg_signature = request.args.get("msg_signature")
timestamp = request.args.get("timestamp")
nonce = request.args.get("nonce")
if request.method == "GET":
echostr = request.args.get("echostr")
if request.method == 'GET':
echostr = request.args.get('echostr')
ret, reply_echo_str = self.wxcpt.VerifyURL(msg_signature, timestamp, nonce, echostr)
if ret != 0:
raise Exception(f"验证失败,错误码: {ret}")
raise Exception(f'验证失败,错误码: {ret}')
return reply_echo_str
elif request.method == "POST":
elif request.method == 'POST':
encrypt_msg = await request.data
ret, xml_msg = self.wxcpt.DecryptMsg(encrypt_msg, msg_signature, timestamp, nonce)
if ret != 0:
raise Exception(f"消息解密失败,错误码: {ret}")
raise Exception(f'消息解密失败,错误码: {ret}')
# 解析消息并处理
message_data = await self.get_detailed_message_list(xml_msg)
@@ -223,10 +219,10 @@ class WecomCSClient():
if event:
await self._handle_message(event)
return "success"
return 'success'
except Exception as e:
traceback.print_exc()
return f"Error processing request: {str(e)}", 400
return f'Error processing request: {str(e)}', 400
async def run_task(self, host: str, port: int, *args, **kwargs):
"""
@@ -238,11 +234,13 @@ class WecomCSClient():
"""
注册消息类型处理器。
"""
def decorator(func: Callable[[WecomCSEvent], None]):
if msg_type not in self._message_handlers:
self._message_handlers[msg_type] = []
self._message_handlers[msg_type].append(func)
return func
return decorator
async def _handle_message(self, event: WecomCSEvent):
@@ -254,18 +252,17 @@ class WecomCSClient():
for handler in self._message_handlers[msg_type]:
await handler(event)
@staticmethod
async def get_image_type(image_bytes: bytes) -> str:
"""
通过图片的magic numbers判断图片类型
"""
magic_numbers = {
b'\xFF\xD8\xFF': 'jpg',
b'\x89\x50\x4E\x47': 'png',
b'\xff\xd8\xff': 'jpg',
b'\x89\x50\x4e\x47': 'png',
b'\x47\x49\x46': 'gif',
b'\x42\x4D': 'bmp',
b'\x00\x00\x01\x00': 'ico'
b'\x42\x4d': 'bmp',
b'\x00\x00\x01\x00': 'ico',
}
for magic, ext in magic_numbers.items():
@@ -273,7 +270,6 @@ class WecomCSClient():
return ext
return 'jpg' # 默认返回jpg
async def upload_to_work(self, image: platform_message.Image):
"""
获取 media_id
@@ -283,7 +279,7 @@ class WecomCSClient():
url = self.base_url + '/media/upload?access_token=' + self.access_token + '&type=file'
file_bytes = None
file_name = "uploaded_file.txt"
file_name = 'uploaded_file.txt'
# 获取文件的二进制数据
if image.path:
@@ -302,20 +298,22 @@ class WecomCSClient():
padded_base64 = base64_data + '=' * padding
file_bytes = base64.b64decode(padded_base64)
except binascii.Error as e:
raise ValueError(f"Invalid base64 string: {str(e)}")
raise ValueError(f'Invalid base64 string: {str(e)}')
else:
raise ValueError("image对象出错")
raise ValueError('image对象出错')
# 设置 multipart/form-data 格式的文件
boundary = "-------------------------acebdf13572468"
headers = {
'Content-Type': f'multipart/form-data; boundary={boundary}'
}
boundary = '-------------------------acebdf13572468'
headers = {'Content-Type': f'multipart/form-data; boundary={boundary}'}
body = (
f"--{boundary}\r\n"
f"Content-Disposition: form-data; name=\"media\"; filename=\"{file_name}\"; filelength={len(file_bytes)}\r\n"
f"Content-Type: application/octet-stream\r\n\r\n"
).encode('utf-8') + file_bytes + f"\r\n--{boundary}--\r\n".encode('utf-8')
(
f'--{boundary}\r\n'
f'Content-Disposition: form-data; name="media"; filename="{file_name}"; filelength={len(file_bytes)}\r\n'
f'Content-Type: application/octet-stream\r\n\r\n'
).encode('utf-8')
+ file_bytes
+ f'\r\n--{boundary}--\r\n'.encode('utf-8')
)
# 上传文件
async with httpx.AsyncClient() as client:
@@ -325,7 +323,7 @@ class WecomCSClient():
self.access_token = await self.get_access_token(self.secret)
media_id = await self.upload_to_work(image)
if data.get('errcode', 0) != 0:
raise Exception("failed to upload file")
raise Exception('failed to upload file')
media_id = data.get('media_id')
return media_id
@@ -338,6 +336,5 @@ class WecomCSClient():
# 进行media_id的获取
async def get_media_id(self, image: platform_message.Image):
media_id = await self.upload_to_work(image=image)
return media_id

View File

@@ -9,7 +9,7 @@ class WecomCSEvent(dict):
"""
@staticmethod
def from_payload(payload: Dict[str, Any]) -> Optional["WecomCSEvent"]:
def from_payload(payload: Dict[str, Any]) -> Optional['WecomCSEvent']:
"""
从企业微信(客服会话)事件数据构造 `WecomEvent` 对象。
@@ -21,7 +21,7 @@ class WecomCSEvent(dict):
"""
try:
event = WecomCSEvent(payload)
_ = event.type,
_ = (event.type,)
return event
except KeyError:
return None
@@ -34,7 +34,7 @@ class WecomCSEvent(dict):
Returns:
str: 事件类型。
"""
return self.get("msgtype", "")
return self.get('msgtype', '')
@property
def user_id(self) -> Optional[str]:
@@ -44,7 +44,7 @@ class WecomCSEvent(dict):
Returns:
Optional[str]: 用户 ID。
"""
return self.get("external_userid")
return self.get('external_userid')
@property
def receiver_id(self) -> Optional[str]:
@@ -54,7 +54,7 @@ class WecomCSEvent(dict):
Returns:
Optional[str]: 接收者 ID。
"""
return self.get("open_kfid","")
return self.get('open_kfid', '')
@property
def picurl(self) -> Optional[str]:
@@ -65,7 +65,7 @@ class WecomCSEvent(dict):
Optional[str]: 图片 URL。
"""
return self.get("picurl","")
return self.get('picurl', '')
@property
def message_id(self) -> Optional[str]:
@@ -75,7 +75,7 @@ class WecomCSEvent(dict):
Returns:
Optional[str]: 消息 ID。
"""
return self.get("msgid")
return self.get('msgid')
@property
def message(self) -> Optional[str]:
@@ -85,12 +85,11 @@ class WecomCSEvent(dict):
Returns:
Optional[str]: 消息内容。
"""
if self.get("msgtype") == 'text':
return self.get("text").get("content","")
if self.get('msgtype') == 'text':
return self.get('text').get('content', '')
else:
return None
@property
def timestamp(self) -> Optional[int]:
"""
@@ -99,8 +98,7 @@ class WecomCSEvent(dict):
Returns:
Optional[int]: 时间戳。
"""
return self.get("send_time")
return self.get('send_time')
def __getattr__(self, key: str) -> Optional[Any]:
"""
@@ -131,4 +129,4 @@ class WecomCSEvent(dict):
Returns:
str: 字符串表示。
"""
return f"<WecomEvent {super().__repr__()}>"
return f'<WecomEvent {super().__repr__()}>'

View File

@@ -65,9 +65,7 @@ class RouterGroup(abc.ABC):
async def handler_error(*args, **kwargs):
if auth_type == AuthType.USER_TOKEN:
# 从Authorization头中获取token
token = quart.request.headers.get('Authorization', '').replace(
'Bearer ', ''
)
token = quart.request.headers.get('Authorization', '').replace('Bearer ', '')
if not token:
return self.http_status(401, -1, '未提供有效的用户令牌')

View File

@@ -14,11 +14,9 @@ class LogsRouterGroup(group.RouterGroup):
start_page_number = int(quart.request.args.get('start_page_number', 0))
start_offset = int(quart.request.args.get('start_offset', 0))
logs_str, end_page_number, end_offset = (
self.ap.log_cache.get_log_by_pointer(
logs_str, end_page_number, end_offset = self.ap.log_cache.get_log_by_pointer(
start_page_number=start_page_number, start_offset=start_offset
)
)
return self.success(
data={

View File

@@ -11,23 +11,17 @@ class PipelinesRouterGroup(group.RouterGroup):
@self.route('', methods=['GET', 'POST'])
async def _() -> str:
if quart.request.method == 'GET':
return self.success(
data={'pipelines': await self.ap.pipeline_service.get_pipelines()}
)
return self.success(data={'pipelines': await self.ap.pipeline_service.get_pipelines()})
elif quart.request.method == 'POST':
json_data = await quart.request.json
pipeline_uuid = await self.ap.pipeline_service.create_pipeline(
json_data
)
pipeline_uuid = await self.ap.pipeline_service.create_pipeline(json_data)
return self.success(data={'uuid': pipeline_uuid})
@self.route('/_/metadata', methods=['GET'])
async def _() -> str:
return self.success(
data={'configs': await self.ap.pipeline_service.get_pipeline_metadata()}
)
return self.success(data={'configs': await self.ap.pipeline_service.get_pipeline_metadata()})
@self.route('/<pipeline_uuid>', methods=['GET', 'PUT', 'DELETE'])
async def _(pipeline_uuid: str) -> str:

View File

@@ -8,30 +8,20 @@ class AdaptersRouterGroup(group.RouterGroup):
async def initialize(self) -> None:
@self.route('', methods=['GET'])
async def _() -> str:
return self.success(
data={'adapters': self.ap.platform_mgr.get_available_adapters_info()}
)
return self.success(data={'adapters': self.ap.platform_mgr.get_available_adapters_info()})
@self.route('/<adapter_name>', methods=['GET'])
async def _(adapter_name: str) -> str:
adapter_info = self.ap.platform_mgr.get_available_adapter_info_by_name(
adapter_name
)
adapter_info = self.ap.platform_mgr.get_available_adapter_info_by_name(adapter_name)
if adapter_info is None:
return self.http_status(404, -1, 'adapter not found')
return self.success(data={'adapter': adapter_info})
@self.route(
'/<adapter_name>/icon', methods=['GET'], auth_type=group.AuthType.NONE
)
@self.route('/<adapter_name>/icon', methods=['GET'], auth_type=group.AuthType.NONE)
async def _(adapter_name: str) -> quart.Response:
adapter_manifest = (
self.ap.platform_mgr.get_available_adapter_manifest_by_name(
adapter_name
)
)
adapter_manifest = self.ap.platform_mgr.get_available_adapter_manifest_by_name(adapter_name)
if adapter_manifest is None:
return self.http_status(404, -1, 'adapter not found')

View File

@@ -92,9 +92,7 @@ class PluginsRouterGroup(group.RouterGroup):
await self.ap.plugin_mgr.reorder_plugins(data.get('plugins'))
return self.success()
@self.route(
'/install/github', methods=['POST'], auth_type=group.AuthType.USER_TOKEN
)
@self.route('/install/github', methods=['POST'], auth_type=group.AuthType.USER_TOKEN)
async def _() -> str:
data = await quart.request.json

View File

@@ -9,9 +9,7 @@ class LLMModelsRouterGroup(group.RouterGroup):
@self.route('', methods=['GET', 'POST'])
async def _() -> str:
if quart.request.method == 'GET':
return self.success(
data={'models': await self.ap.model_service.get_llm_models()}
)
return self.success(data={'models': await self.ap.model_service.get_llm_models()})
elif quart.request.method == 'POST':
json_data = await quart.request.json

View File

@@ -8,30 +8,20 @@ class RequestersRouterGroup(group.RouterGroup):
async def initialize(self) -> None:
@self.route('', methods=['GET'])
async def _() -> quart.Response:
return self.success(
data={'requesters': self.ap.model_mgr.get_available_requesters_info()}
)
return self.success(data={'requesters': self.ap.model_mgr.get_available_requesters_info()})
@self.route('/<requester_name>', methods=['GET'])
async def _(requester_name: str) -> quart.Response:
requester_info = self.ap.model_mgr.get_available_requester_info_by_name(
requester_name
)
requester_info = self.ap.model_mgr.get_available_requester_info_by_name(requester_name)
if requester_info is None:
return self.http_status(404, -1, 'requester not found')
return self.success(data={'requester': requester_info})
@self.route(
'/<requester_name>/icon', methods=['GET'], auth_type=group.AuthType.NONE
)
@self.route('/<requester_name>/icon', methods=['GET'], auth_type=group.AuthType.NONE)
async def _(requester_name: str) -> quart.Response:
requester_manifest = (
self.ap.model_mgr.get_available_requester_manifest_by_name(
requester_name
)
)
requester_manifest = self.ap.model_mgr.get_available_requester_manifest_by_name(requester_name)
if requester_manifest is None:
return self.http_status(404, -1, 'requester not found')

View File

@@ -8,9 +8,7 @@ class StatsRouterGroup(group.RouterGroup):
async def _() -> str:
conv_count = 0
for session in self.ap.sess_mgr.session_list:
conv_count += len(
session.conversations if session.conversations is not None else []
)
conv_count += len(session.conversations if session.conversations is not None else [])
return self.success(
data={

View File

@@ -13,9 +13,7 @@ class SystemRouterGroup(group.RouterGroup):
data={
'version': constants.semantic_version,
'debug': constants.debug_mode,
'enabled_platform_count': len(
self.ap.platform_mgr.get_running_adapters()
),
'enabled_platform_count': len(self.ap.platform_mgr.get_running_adapters()),
}
)
@@ -28,9 +26,7 @@ class SystemRouterGroup(group.RouterGroup):
return self.success(data=self.ap.task_mgr.get_tasks_dict(task_type))
@self.route(
'/tasks/<task_id>', methods=['GET'], auth_type=group.AuthType.USER_TOKEN
)
@self.route('/tasks/<task_id>', methods=['GET'], auth_type=group.AuthType.USER_TOKEN)
async def _(task_id: str) -> str:
task = self.ap.task_mgr.get_task_by_id(int(task_id))
@@ -48,9 +44,7 @@ class SystemRouterGroup(group.RouterGroup):
await self.ap.reload(scope=scope)
return self.success()
@self.route(
'/_debug/exec', methods=['POST'], auth_type=group.AuthType.USER_TOKEN
)
@self.route('/_debug/exec', methods=['POST'], auth_type=group.AuthType.USER_TOKEN)
async def _() -> str:
if not constants.debug_mode:
return self.http_status(403, 403, 'Forbidden')

View File

@@ -10,9 +10,7 @@ class UserRouterGroup(group.RouterGroup):
@self.route('/init', methods=['GET', 'POST'], auth_type=group.AuthType.NONE)
async def _() -> str:
if quart.request.method == 'GET':
return self.success(
data={'initialized': await self.ap.user_service.is_initialized()}
)
return self.success(data={'initialized': await self.ap.user_service.is_initialized()})
if await self.ap.user_service.is_initialized():
return self.fail(1, '系统已初始化')
@@ -31,17 +29,13 @@ class UserRouterGroup(group.RouterGroup):
json_data = await quart.request.json
try:
token = await self.ap.user_service.authenticate(
json_data['user'], json_data['password']
)
token = await self.ap.user_service.authenticate(json_data['user'], json_data['password'])
except argon2.exceptions.VerifyMismatchError:
return self.fail(1, '用户名或密码错误')
return self.success(data={'token': token})
@self.route(
'/check-token', methods=['GET'], auth_type=group.AuthType.USER_TOKEN
)
@self.route('/check-token', methods=['GET'], auth_type=group.AuthType.USER_TOKEN)
async def _(user_email: str) -> str:
token = await self.ap.user_service.generate_jwt_token(user_email)

View File

@@ -70,15 +70,12 @@ class HTTPController:
@self.quart_app.route('/')
async def index():
return await quart.send_from_directory(
frontend_path, 'index.html', mimetype='text/html'
)
return await quart.send_from_directory(frontend_path, 'index.html', mimetype='text/html')
@self.quart_app.route('/<path:path>')
async def static_file(path: str):
if not (
os.path.exists(os.path.join(frontend_path, path))
and os.path.isfile(os.path.join(frontend_path, path))
os.path.exists(os.path.join(frontend_path, path)) and os.path.isfile(os.path.join(frontend_path, path))
):
if os.path.exists(os.path.join(frontend_path, path + '.html')):
path += '.html'
@@ -110,6 +107,4 @@ class HTTPController:
elif path.endswith('.txt'):
mimetype = 'text/plain'
return await quart.send_from_directory(
frontend_path, path, mimetype=mimetype
)
return await quart.send_from_directory(frontend_path, path, mimetype=mimetype)

View File

@@ -18,23 +18,16 @@ class BotService:
async def get_bots(self) -> list[dict]:
"""获取所有机器人"""
result = await self.ap.persistence_mgr.execute_async(
sqlalchemy.select(persistence_bot.Bot)
)
result = await self.ap.persistence_mgr.execute_async(sqlalchemy.select(persistence_bot.Bot))
bots = result.all()
return [
self.ap.persistence_mgr.serialize_model(persistence_bot.Bot, bot)
for bot in bots
]
return [self.ap.persistence_mgr.serialize_model(persistence_bot.Bot, bot) for bot in bots]
async def get_bot(self, bot_uuid: str) -> dict | None:
"""获取机器人"""
result = await self.ap.persistence_mgr.execute_async(
sqlalchemy.select(persistence_bot.Bot).where(
persistence_bot.Bot.uuid == bot_uuid
)
sqlalchemy.select(persistence_bot.Bot).where(persistence_bot.Bot.uuid == bot_uuid)
)
bot = result.first()
@@ -60,9 +53,7 @@ class BotService:
bot_data['use_pipeline_uuid'] = pipeline.uuid
bot_data['use_pipeline_name'] = pipeline.name
await self.ap.persistence_mgr.execute_async(
sqlalchemy.insert(persistence_bot.Bot).values(bot_data)
)
await self.ap.persistence_mgr.execute_async(sqlalchemy.insert(persistence_bot.Bot).values(bot_data))
bot = await self.get_bot(bot_data['uuid'])
@@ -79,8 +70,7 @@ class BotService:
if 'use_pipeline_uuid' in bot_data:
result = await self.ap.persistence_mgr.execute_async(
sqlalchemy.select(persistence_pipeline.LegacyPipeline).where(
persistence_pipeline.LegacyPipeline.uuid
== bot_data['use_pipeline_uuid']
persistence_pipeline.LegacyPipeline.uuid == bot_data['use_pipeline_uuid']
)
)
pipeline = result.first()
@@ -90,9 +80,7 @@ class BotService:
raise Exception('Pipeline not found')
await self.ap.persistence_mgr.execute_async(
sqlalchemy.update(persistence_bot.Bot)
.values(bot_data)
.where(persistence_bot.Bot.uuid == bot_uuid)
sqlalchemy.update(persistence_bot.Bot).values(bot_data).where(persistence_bot.Bot.uuid == bot_uuid)
)
await self.ap.platform_mgr.remove_bot(bot_uuid)
@@ -108,7 +96,5 @@ class BotService:
"""删除机器人"""
await self.ap.platform_mgr.remove_bot(bot_uuid)
await self.ap.persistence_mgr.execute_async(
sqlalchemy.delete(persistence_bot.Bot).where(
persistence_bot.Bot.uuid == bot_uuid
)
sqlalchemy.delete(persistence_bot.Bot).where(persistence_bot.Bot.uuid == bot_uuid)
)

View File

@@ -15,22 +15,15 @@ class ModelsService:
self.ap = ap
async def get_llm_models(self) -> list[dict]:
result = await self.ap.persistence_mgr.execute_async(
sqlalchemy.select(persistence_model.LLMModel)
)
result = await self.ap.persistence_mgr.execute_async(sqlalchemy.select(persistence_model.LLMModel))
models = result.all()
return [
self.ap.persistence_mgr.serialize_model(persistence_model.LLMModel, model)
for model in models
]
return [self.ap.persistence_mgr.serialize_model(persistence_model.LLMModel, model) for model in models]
async def create_llm_model(self, model_data: dict) -> str:
model_data['uuid'] = str(uuid.uuid4())
await self.ap.persistence_mgr.execute_async(
sqlalchemy.insert(persistence_model.LLMModel).values(**model_data)
)
await self.ap.persistence_mgr.execute_async(sqlalchemy.insert(persistence_model.LLMModel).values(**model_data))
llm_model = await self.get_llm_model(model_data['uuid'])
@@ -53,9 +46,7 @@ class ModelsService:
async def get_llm_model(self, model_uuid: str) -> dict | None:
result = await self.ap.persistence_mgr.execute_async(
sqlalchemy.select(persistence_model.LLMModel).where(
persistence_model.LLMModel.uuid == model_uuid
)
sqlalchemy.select(persistence_model.LLMModel).where(persistence_model.LLMModel.uuid == model_uuid)
)
model = result.first()
@@ -63,9 +54,7 @@ class ModelsService:
if model is None:
return None
return self.ap.persistence_mgr.serialize_model(
persistence_model.LLMModel, model
)
return self.ap.persistence_mgr.serialize_model(persistence_model.LLMModel, model)
async def update_llm_model(self, model_uuid: str, model_data: dict) -> None:
if 'uuid' in model_data:
@@ -85,9 +74,7 @@ class ModelsService:
async def delete_llm_model(self, model_uuid: str) -> None:
await self.ap.persistence_mgr.execute_async(
sqlalchemy.delete(persistence_model.LLMModel).where(
persistence_model.LLMModel.uuid == model_uuid
)
sqlalchemy.delete(persistence_model.LLMModel).where(persistence_model.LLMModel.uuid == model_uuid)
)
await self.ap.model_mgr.remove_llm_model(model_uuid)

View File

@@ -39,15 +39,11 @@ class PipelineService:
]
async def get_pipelines(self) -> list[dict]:
result = await self.ap.persistence_mgr.execute_async(
sqlalchemy.select(persistence_pipeline.LegacyPipeline)
)
result = await self.ap.persistence_mgr.execute_async(sqlalchemy.select(persistence_pipeline.LegacyPipeline))
pipelines = result.all()
return [
self.ap.persistence_mgr.serialize_model(
persistence_pipeline.LegacyPipeline, pipeline
)
self.ap.persistence_mgr.serialize_model(persistence_pipeline.LegacyPipeline, pipeline)
for pipeline in pipelines
]
@@ -63,23 +59,17 @@ class PipelineService:
if pipeline is None:
return None
return self.ap.persistence_mgr.serialize_model(
persistence_pipeline.LegacyPipeline, pipeline
)
return self.ap.persistence_mgr.serialize_model(persistence_pipeline.LegacyPipeline, pipeline)
async def create_pipeline(self, pipeline_data: dict, default: bool = False) -> str:
pipeline_data['uuid'] = str(uuid.uuid4())
pipeline_data['for_version'] = self.ap.ver_mgr.get_current_version()
pipeline_data['stages'] = default_stage_order.copy()
pipeline_data['is_default'] = default
pipeline_data['config'] = json.load(
open('templates/default-pipeline-config.json', 'r', encoding='utf-8')
)
pipeline_data['config'] = json.load(open('templates/default-pipeline-config.json', 'r', encoding='utf-8'))
await self.ap.persistence_mgr.execute_async(
sqlalchemy.insert(persistence_pipeline.LegacyPipeline).values(
**pipeline_data
)
sqlalchemy.insert(persistence_pipeline.LegacyPipeline).values(**pipeline_data)
)
pipeline = await self.get_pipeline(pipeline_data['uuid'])

View File

@@ -17,9 +17,7 @@ class UserService:
self.ap = ap
async def is_initialized(self) -> bool:
result = await self.ap.persistence_mgr.execute_async(
sqlalchemy.select(user.User).limit(1)
)
result = await self.ap.persistence_mgr.execute_async(sqlalchemy.select(user.User).limit(1))
result_list = result.all()
return result_list is not None and len(result_list) > 0
@@ -30,9 +28,7 @@ class UserService:
hashed_password = ph.hash(password)
await self.ap.persistence_mgr.execute_async(
sqlalchemy.insert(user.User).values(
user=user_email, password=hashed_password
)
sqlalchemy.insert(user.User).values(user=user_email, password=hashed_password)
)
async def get_user_by_email(self, user_email: str) -> user.User | None:
@@ -41,9 +37,7 @@ class UserService:
)
result_list = result.all()
return (
result_list[0] if result_list is not None and len(result_list) > 0 else None
)
return result_list[0] if result_list is not None and len(result_list) > 0 else None
async def authenticate(self, user_email: str, password: str) -> str | None:
result = await self.ap.persistence_mgr.execute_async(

View File

@@ -40,18 +40,14 @@ class CommandManager:
# 应用命令权限配置
for cls in operator.preregistered_operators:
if cls.path in self.ap.instance_config.data['command']['privilege']:
cls.lowest_privilege = self.ap.instance_config.data['command'][
'privilege'
][cls.path]
cls.lowest_privilege = self.ap.instance_config.data['command']['privilege'][cls.path]
# 实例化所有类
self.cmd_list = [cls(self.ap) for cls in operator.preregistered_operators]
# 设置所有类的子节点
for cmd in self.cmd_list:
cmd.children = [
child for child in self.cmd_list if child.parent_class == cmd.__class__
]
cmd.children = [child for child in self.cmd_list if child.parent_class == cmd.__class__]
# 初始化所有类
for cmd in self.cmd_list:
@@ -68,10 +64,7 @@ class CommandManager:
found = False
if len(context.crt_params) > 0: # 查找下一个参数是否对应此节点的某个子节点名
for oper in operator_list:
if (
context.crt_params[0] == oper.name
or context.crt_params[0] in oper.alias
) and (
if (context.crt_params[0] == oper.name or context.crt_params[0] in oper.alias) and (
oper.parent_class is None or oper.parent_class == operator.__class__
):
found = True
@@ -85,14 +78,10 @@ class CommandManager:
if not found: # 如果下一个参数未在此节点的子节点中找到,则执行此节点或者报错
if operator is None:
yield entities.CommandReturn(
error=errors.CommandNotFoundError(context.crt_params[0])
)
yield entities.CommandReturn(error=errors.CommandNotFoundError(context.crt_params[0]))
else:
if operator.lowest_privilege > context.privilege:
yield entities.CommandReturn(
error=errors.CommandPrivilegeError(operator.name)
)
yield entities.CommandReturn(error=errors.CommandPrivilegeError(operator.name))
else:
async for ret in operator.execute(context):
yield ret
@@ -107,10 +96,7 @@ class CommandManager:
privilege = 1
if (
f'{query.launcher_type.value}_{query.launcher_id}'
in self.ap.instance_config.data['admins']
):
if f'{query.launcher_type.value}_{query.launcher_id}' in self.ap.instance_config.data['admins']:
privilege = 2
ctx = entities.ExecuteContext(

View File

@@ -95,9 +95,7 @@ class CommandOperator(metaclass=abc.ABCMeta):
pass
@abc.abstractmethod
async def execute(
self, context: entities.ExecuteContext
) -> typing.AsyncGenerator[entities.CommandReturn, None]:
async def execute(self, context: entities.ExecuteContext) -> typing.AsyncGenerator[entities.CommandReturn, None]:
"""实现此方法以执行命令
支持多次yield以返回多个结果。

View File

@@ -9,9 +9,7 @@ from .. import operator, entities, errors
class CmdOperator(operator.CommandOperator):
"""命令列表"""
async def execute(
self, context: entities.ExecuteContext
) -> typing.AsyncGenerator[entities.CommandReturn, None]:
async def execute(self, context: entities.ExecuteContext) -> typing.AsyncGenerator[entities.CommandReturn, None]:
"""执行"""
if len(context.crt_params) == 0:
reply_str = '当前所有命令: \n\n'
@@ -30,16 +28,12 @@ class CmdOperator(operator.CommandOperator):
cmd = None
for _cmd in self.ap.cmd_mgr.cmd_list:
if (cmd_name == _cmd.name or cmd_name in _cmd.alias) and (
_cmd.parent_class is None
):
if (cmd_name == _cmd.name or cmd_name in _cmd.alias) and (_cmd.parent_class is None):
cmd = _cmd
break
if cmd is None:
yield entities.CommandReturn(
error=errors.CommandNotFoundError(cmd_name)
)
yield entities.CommandReturn(error=errors.CommandNotFoundError(cmd_name))
else:
reply_str = f'{cmd.name}: {cmd.help}\n\n'
reply_str += f'使用方法: \n{cmd.usage}'

View File

@@ -5,55 +5,38 @@ import typing
from .. import operator, entities, errors
@operator.operator_class(
name='del', help='删除当前会话的历史记录', usage='!del <序号>\n!del all'
)
@operator.operator_class(name='del', help='删除当前会话的历史记录', usage='!del <序号>\n!del all')
class DelOperator(operator.CommandOperator):
async def execute(
self, context: entities.ExecuteContext
) -> typing.AsyncGenerator[entities.CommandReturn, None]:
async def execute(self, context: entities.ExecuteContext) -> typing.AsyncGenerator[entities.CommandReturn, None]:
if context.session.conversations:
delete_index = 0
if len(context.crt_params) > 0:
try:
delete_index = int(context.crt_params[0])
except Exception:
yield entities.CommandReturn(
error=errors.CommandOperationError('索引必须是整数')
)
yield entities.CommandReturn(error=errors.CommandOperationError('索引必须是整数'))
return
if delete_index < 0 or delete_index >= len(context.session.conversations):
yield entities.CommandReturn(
error=errors.CommandOperationError('索引超出范围')
)
yield entities.CommandReturn(error=errors.CommandOperationError('索引超出范围'))
return
# 倒序
to_delete_index = len(context.session.conversations) - 1 - delete_index
if (
context.session.conversations[to_delete_index]
== context.session.using_conversation
):
if context.session.conversations[to_delete_index] == context.session.using_conversation:
context.session.using_conversation = None
del context.session.conversations[to_delete_index]
yield entities.CommandReturn(text=f'已删除对话: {delete_index}')
else:
yield entities.CommandReturn(
error=errors.CommandOperationError('当前没有对话')
)
yield entities.CommandReturn(error=errors.CommandOperationError('当前没有对话'))
@operator.operator_class(
name='all', help='删除此会话的所有历史记录', parent_class=DelOperator
)
@operator.operator_class(name='all', help='删除此会话的所有历史记录', parent_class=DelOperator)
class DelAllOperator(operator.CommandOperator):
async def execute(
self, context: entities.ExecuteContext
) -> typing.AsyncGenerator[entities.CommandReturn, None]:
async def execute(self, context: entities.ExecuteContext) -> typing.AsyncGenerator[entities.CommandReturn, None]:
context.session.conversations = []
context.session.using_conversation = None

View File

@@ -6,9 +6,7 @@ from .. import operator, entities
@operator.operator_class(name='func', help='查看所有已注册的内容函数', usage='!func')
class FuncOperator(operator.CommandOperator):
async def execute(
self, context: entities.ExecuteContext
) -> AsyncGenerator[entities.CommandReturn, None]:
async def execute(self, context: entities.ExecuteContext) -> AsyncGenerator[entities.CommandReturn, None]:
reply_str = '当前已启用的内容函数: \n\n'
index = 1

View File

@@ -7,9 +7,7 @@ from .. import operator, entities
@operator.operator_class(name='help', help='显示帮助', usage='!help\n!help <命令名称>')
class HelpOperator(operator.CommandOperator):
async def execute(
self, context: entities.ExecuteContext
) -> typing.AsyncGenerator[entities.CommandReturn, None]:
async def execute(self, context: entities.ExecuteContext) -> typing.AsyncGenerator[entities.CommandReturn, None]:
help = 'LangBot - 大语言模型原生即时通信机器人平台\n链接https://langbot.app'
help += '\n发送命令 !cmd 可查看命令列表'

View File

@@ -8,36 +8,21 @@ from .. import operator, entities, errors
@operator.operator_class(name='last', help='切换到前一个对话', usage='!last')
class LastOperator(operator.CommandOperator):
async def execute(
self, context: entities.ExecuteContext
) -> typing.AsyncGenerator[entities.CommandReturn, None]:
async def execute(self, context: entities.ExecuteContext) -> typing.AsyncGenerator[entities.CommandReturn, None]:
if context.session.conversations:
# 找到当前会话的上一个会话
for index in range(len(context.session.conversations) - 1, -1, -1):
if (
context.session.conversations[index]
== context.session.using_conversation
):
if context.session.conversations[index] == context.session.using_conversation:
if index == 0:
yield entities.CommandReturn(
error=errors.CommandOperationError('已经是第一个对话了')
)
yield entities.CommandReturn(error=errors.CommandOperationError('已经是第一个对话了'))
return
else:
context.session.using_conversation = (
context.session.conversations[index - 1]
)
time_str = (
context.session.using_conversation.create_time.strftime(
'%Y-%m-%d %H:%M:%S'
)
)
context.session.using_conversation = context.session.conversations[index - 1]
time_str = context.session.using_conversation.create_time.strftime('%Y-%m-%d %H:%M:%S')
yield entities.CommandReturn(
text=f'已切换到上一个对话: {index} {time_str}: {context.session.using_conversation.messages[0].readable_str()}'
)
return
else:
yield entities.CommandReturn(
error=errors.CommandOperationError('当前没有对话')
)
yield entities.CommandReturn(error=errors.CommandOperationError('当前没有对话'))

View File

@@ -5,22 +5,16 @@ import typing
from .. import operator, entities, errors
@operator.operator_class(
name='list', help='列出此会话中的所有历史对话', usage='!list\n!list <页码>'
)
@operator.operator_class(name='list', help='列出此会话中的所有历史对话', usage='!list\n!list <页码>')
class ListOperator(operator.CommandOperator):
async def execute(
self, context: entities.ExecuteContext
) -> typing.AsyncGenerator[entities.CommandReturn, None]:
async def execute(self, context: entities.ExecuteContext) -> typing.AsyncGenerator[entities.CommandReturn, None]:
page = 0
if len(context.crt_params) > 0:
try:
page = int(context.crt_params[0] - 1)
except Exception:
yield entities.CommandReturn(
error=errors.CommandOperationError('页码应为整数')
)
yield entities.CommandReturn(error=errors.CommandOperationError('页码应为整数'))
return
record_per_page = 10
@@ -38,7 +32,9 @@ class ListOperator(operator.CommandOperator):
using_conv_index = index
if index >= page * record_per_page and index < (page + 1) * record_per_page:
content += f'{index} {time_str}: {conv.messages[0].readable_str() if len(conv.messages) > 0 else "无内容"}\n'
content += (
f'{index} {time_str}: {conv.messages[0].readable_str() if len(conv.messages) > 0 else "无内容"}\n'
)
index += 1
if content == '':

View File

@@ -14,9 +14,7 @@ from .. import operator, entities, errors
class ModelOperator(operator.CommandOperator):
"""Model命令"""
async def execute(
self, context: entities.ExecuteContext
) -> typing.AsyncGenerator[entities.CommandReturn, None]:
async def execute(self, context: entities.ExecuteContext) -> typing.AsyncGenerator[entities.CommandReturn, None]:
content = '模型列表:\n'
model_list = self.ap.model_mgr.model_list
@@ -31,15 +29,11 @@ class ModelOperator(operator.CommandOperator):
yield entities.CommandReturn(text=content.strip())
@operator.operator_class(
name='show', help='显示模型详情', privilege=2, parent_class=ModelOperator
)
@operator.operator_class(name='show', help='显示模型详情', privilege=2, parent_class=ModelOperator)
class ModelShowOperator(operator.CommandOperator):
"""Model Show命令"""
async def execute(
self, context: entities.ExecuteContext
) -> typing.AsyncGenerator[entities.CommandReturn, None]:
async def execute(self, context: entities.ExecuteContext) -> typing.AsyncGenerator[entities.CommandReturn, None]:
model_name = context.crt_params[0]
model = None
@@ -49,9 +43,7 @@ class ModelShowOperator(operator.CommandOperator):
break
if model is None:
yield entities.CommandReturn(
error=errors.CommandError(f'未找到模型 {model_name}')
)
yield entities.CommandReturn(error=errors.CommandError(f'未找到模型 {model_name}'))
else:
content = '模型详情\n'
content += f'名称: {model.name}\n'
@@ -65,15 +57,11 @@ class ModelShowOperator(operator.CommandOperator):
yield entities.CommandReturn(text=content.strip())
@operator.operator_class(
name='set', help='设置默认使用模型', privilege=2, parent_class=ModelOperator
)
@operator.operator_class(name='set', help='设置默认使用模型', privilege=2, parent_class=ModelOperator)
class ModelSetOperator(operator.CommandOperator):
"""Model Set命令"""
async def execute(
self, context: entities.ExecuteContext
) -> typing.AsyncGenerator[entities.CommandReturn, None]:
async def execute(self, context: entities.ExecuteContext) -> typing.AsyncGenerator[entities.CommandReturn, None]:
model_name = context.crt_params[0]
model = None
@@ -83,12 +71,8 @@ class ModelSetOperator(operator.CommandOperator):
break
if model is None:
yield entities.CommandReturn(
error=errors.CommandError(f'未找到模型 {model_name}')
)
yield entities.CommandReturn(error=errors.CommandError(f'未找到模型 {model_name}'))
else:
self.ap.provider_cfg.data['model'] = model_name
await self.ap.provider_cfg.dump_config()
yield entities.CommandReturn(
text=f'已设置当前使用模型为 {model_name},重置会话以生效'
)
yield entities.CommandReturn(text=f'已设置当前使用模型为 {model_name},重置会话以生效')

View File

@@ -7,36 +7,21 @@ from .. import operator, entities, errors
@operator.operator_class(name='next', help='切换到后一个对话', usage='!next')
class NextOperator(operator.CommandOperator):
async def execute(
self, context: entities.ExecuteContext
) -> typing.AsyncGenerator[entities.CommandReturn, None]:
async def execute(self, context: entities.ExecuteContext) -> typing.AsyncGenerator[entities.CommandReturn, None]:
if context.session.conversations:
# 找到当前会话的下一个会话
for index in range(len(context.session.conversations)):
if (
context.session.conversations[index]
== context.session.using_conversation
):
if context.session.conversations[index] == context.session.using_conversation:
if index == len(context.session.conversations) - 1:
yield entities.CommandReturn(
error=errors.CommandOperationError('已经是最后一个对话了')
)
yield entities.CommandReturn(error=errors.CommandOperationError('已经是最后一个对话了'))
return
else:
context.session.using_conversation = (
context.session.conversations[index + 1]
)
time_str = (
context.session.using_conversation.create_time.strftime(
'%Y-%m-%d %H:%M:%S'
)
)
context.session.using_conversation = context.session.conversations[index + 1]
time_str = context.session.using_conversation.create_time.strftime('%Y-%m-%d %H:%M:%S')
yield entities.CommandReturn(
text=f'已切换到后一个对话: {index} {time_str}: {context.session.using_conversation.messages[0].content}'
)
return
else:
yield entities.CommandReturn(
error=errors.CommandOperationError('当前没有对话')
)
yield entities.CommandReturn(error=errors.CommandOperationError('当前没有对话'))

View File

@@ -13,9 +13,7 @@ from .. import operator, entities, errors
usage='!ollama\n!ollama show <模型名>\n!ollama pull <模型名>\n!ollama del <模型名>',
)
class OllamaOperator(operator.CommandOperator):
async def execute(
self, context: entities.ExecuteContext
) -> typing.AsyncGenerator[entities.CommandReturn, None]:
async def execute(self, context: entities.ExecuteContext) -> typing.AsyncGenerator[entities.CommandReturn, None]:
try:
content: str = '模型列表:\n'
model_list: list = ollama.list().get('models', [])
@@ -25,9 +23,7 @@ class OllamaOperator(operator.CommandOperator):
content += f'大小: {bytes_to_mb(model["size"])}MB\n\n'
yield entities.CommandReturn(text=f'{content.strip()}')
except ollama.ResponseError:
yield entities.CommandReturn(
error=errors.CommandError('无法获取模型列表,请确认 Ollama 服务正常')
)
yield entities.CommandReturn(error=errors.CommandError('无法获取模型列表,请确认 Ollama 服务正常'))
def bytes_to_mb(num_bytes):
@@ -35,13 +31,9 @@ def bytes_to_mb(num_bytes):
return format(mb, '.2f')
@operator.operator_class(
name='show', help='ollama模型详情', privilege=2, parent_class=OllamaOperator
)
@operator.operator_class(name='show', help='ollama模型详情', privilege=2, parent_class=OllamaOperator)
class OllamaShowOperator(operator.CommandOperator):
async def execute(
self, context: entities.ExecuteContext
) -> typing.AsyncGenerator[entities.CommandReturn, None]:
async def execute(self, context: entities.ExecuteContext) -> typing.AsyncGenerator[entities.CommandReturn, None]:
content: str = '模型详情:\n'
try:
show: dict = ollama.show(model=context.crt_params[0])
@@ -60,27 +52,19 @@ class OllamaShowOperator(operator.CommandOperator):
content += json.dumps(show, indent=4)
yield entities.CommandReturn(text=content.strip())
except ollama.ResponseError:
yield entities.CommandReturn(
error=errors.CommandError('无法获取模型详情,请确认 Ollama 服务正常')
)
yield entities.CommandReturn(error=errors.CommandError('无法获取模型详情,请确认 Ollama 服务正常'))
@operator.operator_class(
name='pull', help='ollama模型拉取', privilege=2, parent_class=OllamaOperator
)
@operator.operator_class(name='pull', help='ollama模型拉取', privilege=2, parent_class=OllamaOperator)
class OllamaPullOperator(operator.CommandOperator):
async def execute(
self, context: entities.ExecuteContext
) -> typing.AsyncGenerator[entities.CommandReturn, None]:
async def execute(self, context: entities.ExecuteContext) -> typing.AsyncGenerator[entities.CommandReturn, None]:
try:
model_list: list = ollama.list().get('models', [])
if context.crt_params[0] in [model['name'] for model in model_list]:
yield entities.CommandReturn(text='模型已存在')
return
except ollama.ResponseError:
yield entities.CommandReturn(
error=errors.CommandError('无法获取模型列表,请确认 Ollama 服务正常')
)
yield entities.CommandReturn(error=errors.CommandError('无法获取模型列表,请确认 Ollama 服务正常'))
return
on_progress: bool = False
@@ -108,13 +92,9 @@ class OllamaPullOperator(operator.CommandOperator):
yield entities.CommandReturn(text=f'拉取失败: {e.error}')
@operator.operator_class(
name='del', help='ollama模型删除', privilege=2, parent_class=OllamaOperator
)
@operator.operator_class(name='del', help='ollama模型删除', privilege=2, parent_class=OllamaOperator)
class OllamaDelOperator(operator.CommandOperator):
async def execute(
self, context: entities.ExecuteContext
) -> typing.AsyncGenerator[entities.CommandReturn, None]:
async def execute(self, context: entities.ExecuteContext) -> typing.AsyncGenerator[entities.CommandReturn, None]:
try:
ret: str = ollama.delete(model=context.crt_params[0])['status']
except ollama.ResponseError as e:

View File

@@ -11,9 +11,7 @@ from .. import operator, entities, errors
usage='!plugin\n!plugin get <插件仓库地址>\n!plugin update\n!plugin del <插件名>\n!plugin on <插件名>\n!plugin off <插件名>',
)
class PluginOperator(operator.CommandOperator):
async def execute(
self, context: entities.ExecuteContext
) -> typing.AsyncGenerator[entities.CommandReturn, None]:
async def execute(self, context: entities.ExecuteContext) -> typing.AsyncGenerator[entities.CommandReturn, None]:
plugin_list = self.ap.plugin_mgr.plugins()
reply_str = '所有插件({}):\n'.format(len(plugin_list))
idx = 0
@@ -32,17 +30,11 @@ class PluginOperator(operator.CommandOperator):
yield entities.CommandReturn(text=reply_str)
@operator.operator_class(
name='get', help='安装插件', privilege=2, parent_class=PluginOperator
)
@operator.operator_class(name='get', help='安装插件', privilege=2, parent_class=PluginOperator)
class PluginGetOperator(operator.CommandOperator):
async def execute(
self, context: entities.ExecuteContext
) -> typing.AsyncGenerator[entities.CommandReturn, None]:
async def execute(self, context: entities.ExecuteContext) -> typing.AsyncGenerator[entities.CommandReturn, None]:
if len(context.crt_params) == 0:
yield entities.CommandReturn(
error=errors.ParamNotEnoughError('请提供插件仓库地址')
)
yield entities.CommandReturn(error=errors.ParamNotEnoughError('请提供插件仓库地址'))
else:
repo = context.crt_params[0]
@@ -53,22 +45,14 @@ class PluginGetOperator(operator.CommandOperator):
yield entities.CommandReturn(text='插件安装成功,请重启程序以加载插件')
except Exception as e:
traceback.print_exc()
yield entities.CommandReturn(
error=errors.CommandError('插件安装失败: ' + str(e))
)
yield entities.CommandReturn(error=errors.CommandError('插件安装失败: ' + str(e)))
@operator.operator_class(
name='update', help='更新插件', privilege=2, parent_class=PluginOperator
)
@operator.operator_class(name='update', help='更新插件', privilege=2, parent_class=PluginOperator)
class PluginUpdateOperator(operator.CommandOperator):
async def execute(
self, context: entities.ExecuteContext
) -> typing.AsyncGenerator[entities.CommandReturn, None]:
async def execute(self, context: entities.ExecuteContext) -> typing.AsyncGenerator[entities.CommandReturn, None]:
if len(context.crt_params) == 0:
yield entities.CommandReturn(
error=errors.ParamNotEnoughError('请提供插件名称')
)
yield entities.CommandReturn(error=errors.ParamNotEnoughError('请提供插件名称'))
else:
plugin_name = context.crt_params[0]
@@ -78,27 +62,17 @@ class PluginUpdateOperator(operator.CommandOperator):
if plugin_container is not None:
yield entities.CommandReturn(text='正在更新插件...')
await self.ap.plugin_mgr.update_plugin(plugin_name)
yield entities.CommandReturn(
text='插件更新成功,请重启程序以加载插件'
)
yield entities.CommandReturn(text='插件更新成功,请重启程序以加载插件')
else:
yield entities.CommandReturn(
error=errors.CommandError('插件更新失败: 未找到插件')
)
yield entities.CommandReturn(error=errors.CommandError('插件更新失败: 未找到插件'))
except Exception as e:
traceback.print_exc()
yield entities.CommandReturn(
error=errors.CommandError('插件更新失败: ' + str(e))
)
yield entities.CommandReturn(error=errors.CommandError('插件更新失败: ' + str(e)))
@operator.operator_class(
name='all', help='更新所有插件', privilege=2, parent_class=PluginUpdateOperator
)
@operator.operator_class(name='all', help='更新所有插件', privilege=2, parent_class=PluginUpdateOperator)
class PluginUpdateAllOperator(operator.CommandOperator):
async def execute(
self, context: entities.ExecuteContext
) -> typing.AsyncGenerator[entities.CommandReturn, None]:
async def execute(self, context: entities.ExecuteContext) -> typing.AsyncGenerator[entities.CommandReturn, None]:
try:
plugins = [p.plugin_name for p in self.ap.plugin_mgr.plugins()]
@@ -111,32 +85,20 @@ class PluginUpdateAllOperator(operator.CommandOperator):
updated.append(plugin_name)
except Exception as e:
traceback.print_exc()
yield entities.CommandReturn(
error=errors.CommandError('插件更新失败: ' + str(e))
)
yield entities.CommandReturn(
text='已更新插件: {}'.format(', '.join(updated))
)
yield entities.CommandReturn(error=errors.CommandError('插件更新失败: ' + str(e)))
yield entities.CommandReturn(text='已更新插件: {}'.format(', '.join(updated)))
else:
yield entities.CommandReturn(text='没有可更新的插件')
except Exception as e:
traceback.print_exc()
yield entities.CommandReturn(
error=errors.CommandError('插件更新失败: ' + str(e))
)
yield entities.CommandReturn(error=errors.CommandError('插件更新失败: ' + str(e)))
@operator.operator_class(
name='del', help='删除插件', privilege=2, parent_class=PluginOperator
)
@operator.operator_class(name='del', help='删除插件', privilege=2, parent_class=PluginOperator)
class PluginDelOperator(operator.CommandOperator):
async def execute(
self, context: entities.ExecuteContext
) -> typing.AsyncGenerator[entities.CommandReturn, None]:
async def execute(self, context: entities.ExecuteContext) -> typing.AsyncGenerator[entities.CommandReturn, None]:
if len(context.crt_params) == 0:
yield entities.CommandReturn(
error=errors.ParamNotEnoughError('请提供插件名称')
)
yield entities.CommandReturn(error=errors.ParamNotEnoughError('请提供插件名称'))
else:
plugin_name = context.crt_params[0]
@@ -146,79 +108,49 @@ class PluginDelOperator(operator.CommandOperator):
if plugin_container is not None:
yield entities.CommandReturn(text='正在删除插件...')
await self.ap.plugin_mgr.uninstall_plugin(plugin_name)
yield entities.CommandReturn(
text='插件删除成功,请重启程序以加载插件'
)
yield entities.CommandReturn(text='插件删除成功,请重启程序以加载插件')
else:
yield entities.CommandReturn(
error=errors.CommandError('插件删除失败: 未找到插件')
)
yield entities.CommandReturn(error=errors.CommandError('插件删除失败: 未找到插件'))
except Exception as e:
traceback.print_exc()
yield entities.CommandReturn(
error=errors.CommandError('插件删除失败: ' + str(e))
)
yield entities.CommandReturn(error=errors.CommandError('插件删除失败: ' + str(e)))
@operator.operator_class(
name='on', help='启用插件', privilege=2, parent_class=PluginOperator
)
@operator.operator_class(name='on', help='启用插件', privilege=2, parent_class=PluginOperator)
class PluginEnableOperator(operator.CommandOperator):
async def execute(
self, context: entities.ExecuteContext
) -> typing.AsyncGenerator[entities.CommandReturn, None]:
async def execute(self, context: entities.ExecuteContext) -> typing.AsyncGenerator[entities.CommandReturn, None]:
if len(context.crt_params) == 0:
yield entities.CommandReturn(
error=errors.ParamNotEnoughError('请提供插件名称')
)
yield entities.CommandReturn(error=errors.ParamNotEnoughError('请提供插件名称'))
else:
plugin_name = context.crt_params[0]
try:
if await self.ap.plugin_mgr.update_plugin_switch(plugin_name, True):
yield entities.CommandReturn(
text='已启用插件: {}'.format(plugin_name)
)
yield entities.CommandReturn(text='已启用插件: {}'.format(plugin_name))
else:
yield entities.CommandReturn(
error=errors.CommandError(
'插件状态修改失败: 未找到插件 {}'.format(plugin_name)
)
error=errors.CommandError('插件状态修改失败: 未找到插件 {}'.format(plugin_name))
)
except Exception as e:
traceback.print_exc()
yield entities.CommandReturn(
error=errors.CommandError('插件状态修改失败: ' + str(e))
)
yield entities.CommandReturn(error=errors.CommandError('插件状态修改失败: ' + str(e)))
@operator.operator_class(
name='off', help='禁用插件', privilege=2, parent_class=PluginOperator
)
@operator.operator_class(name='off', help='禁用插件', privilege=2, parent_class=PluginOperator)
class PluginDisableOperator(operator.CommandOperator):
async def execute(
self, context: entities.ExecuteContext
) -> typing.AsyncGenerator[entities.CommandReturn, None]:
async def execute(self, context: entities.ExecuteContext) -> typing.AsyncGenerator[entities.CommandReturn, None]:
if len(context.crt_params) == 0:
yield entities.CommandReturn(
error=errors.ParamNotEnoughError('请提供插件名称')
)
yield entities.CommandReturn(error=errors.ParamNotEnoughError('请提供插件名称'))
else:
plugin_name = context.crt_params[0]
try:
if await self.ap.plugin_mgr.update_plugin_switch(plugin_name, False):
yield entities.CommandReturn(
text='已禁用插件: {}'.format(plugin_name)
)
yield entities.CommandReturn(text='已禁用插件: {}'.format(plugin_name))
else:
yield entities.CommandReturn(
error=errors.CommandError(
'插件状态修改失败: 未找到插件 {}'.format(plugin_name)
)
error=errors.CommandError('插件状态修改失败: 未找到插件 {}'.format(plugin_name))
)
except Exception as e:
traceback.print_exc()
yield entities.CommandReturn(
error=errors.CommandError('插件状态修改失败: ' + str(e))
)
yield entities.CommandReturn(error=errors.CommandError('插件状态修改失败: ' + str(e)))

View File

@@ -7,14 +7,10 @@ from .. import operator, entities, errors
@operator.operator_class(name='prompt', help='查看当前对话的前文', usage='!prompt')
class PromptOperator(operator.CommandOperator):
async def execute(
self, context: entities.ExecuteContext
) -> typing.AsyncGenerator[entities.CommandReturn, None]:
async def execute(self, context: entities.ExecuteContext) -> typing.AsyncGenerator[entities.CommandReturn, None]:
"""执行"""
if context.session.using_conversation is None:
yield entities.CommandReturn(
error=errors.CommandOperationError('当前没有对话')
)
yield entities.CommandReturn(error=errors.CommandOperationError('当前没有对话'))
else:
reply_str = '当前对话所有内容:\n\n'

View File

@@ -5,13 +5,9 @@ import typing
from .. import operator, entities, errors
@operator.operator_class(
name='resend', help='重发当前会话的最后一条消息', usage='!resend'
)
@operator.operator_class(name='resend', help='重发当前会话的最后一条消息', usage='!resend')
class ResendOperator(operator.CommandOperator):
async def execute(
self, context: entities.ExecuteContext
) -> typing.AsyncGenerator[entities.CommandReturn, None]:
async def execute(self, context: entities.ExecuteContext) -> typing.AsyncGenerator[entities.CommandReturn, None]:
# 回滚到最后一条用户message前
if context.session.using_conversation is None:
yield entities.CommandReturn(error=errors.CommandError('当前没有对话'))

View File

@@ -7,9 +7,7 @@ from .. import operator, entities
@operator.operator_class(name='reset', help='重置当前会话', usage='!reset')
class ResetOperator(operator.CommandOperator):
async def execute(
self, context: entities.ExecuteContext
) -> typing.AsyncGenerator[entities.CommandReturn, None]:
async def execute(self, context: entities.ExecuteContext) -> typing.AsyncGenerator[entities.CommandReturn, None]:
"""执行"""
context.session.using_conversation = None

View File

@@ -8,9 +8,7 @@ from .. import operator, entities, errors
@operator.operator_class(name='update', help='更新程序', usage='!update', privilege=2)
class UpdateCommand(operator.CommandOperator):
async def execute(
self, context: entities.ExecuteContext
) -> typing.AsyncGenerator[entities.CommandReturn, None]:
async def execute(self, context: entities.ExecuteContext) -> typing.AsyncGenerator[entities.CommandReturn, None]:
try:
yield entities.CommandReturn(text='正在进行更新...')
if await self.ap.ver_mgr.update_all():
@@ -19,6 +17,4 @@ class UpdateCommand(operator.CommandOperator):
yield entities.CommandReturn(text='当前已是最新版本')
except Exception as e:
traceback.print_exc()
yield entities.CommandReturn(
error=errors.CommandError('更新失败: ' + str(e))
)
yield entities.CommandReturn(error=errors.CommandError('更新失败: ' + str(e)))

View File

@@ -7,9 +7,7 @@ from .. import operator, entities
@operator.operator_class(name='version', help='显示版本信息', usage='!version')
class VersionCommand(operator.CommandOperator):
async def execute(
self, context: entities.ExecuteContext
) -> typing.AsyncGenerator[entities.CommandReturn, None]:
async def execute(self, context: entities.ExecuteContext) -> typing.AsyncGenerator[entities.CommandReturn, None]:
reply_str = f'当前版本: \n{self.ap.ver_mgr.get_current_version()}'
try:

View File

@@ -41,9 +41,7 @@ class ConfigManager:
self.file.save_sync(self.data)
async def load_python_module_config(
config_name: str, template_name: str, completion: bool = True
) -> ConfigManager:
async def load_python_module_config(config_name: str, template_name: str, completion: bool = True) -> ConfigManager:
"""加载Python模块配置文件
Args:

View File

@@ -160,9 +160,7 @@ class Application:
"""打印访问 webui 的提示"""
if not os.path.exists(os.path.join('.', 'web/out')):
self.logger.warning(
'WebUI 文件缺失请根据文档获取https://docs.langbot.app/webui/intro.html'
)
self.logger.warning('WebUI 文件缺失请根据文档获取https://docs.langbot.app/webui/intro.html')
return
host_ip = '127.0.0.1'

View File

@@ -26,9 +26,7 @@ async def init_logging(extra_handlers: list[logging.Handler] = None) -> logging.
if constants.debug_mode:
level = logging.DEBUG
log_file_name = 'data/logs/langbot-%s.log' % time.strftime(
'%Y-%m-%d', time.localtime()
)
log_file_name = 'data/logs/langbot-%s.log' % time.strftime('%Y-%m-%d', time.localtime())
qcg_logger = logging.getLogger('langbot')
@@ -43,9 +41,7 @@ async def init_logging(extra_handlers: list[logging.Handler] = None) -> logging.
stream_handler = logging.StreamHandler(sys.stdout)
# stream_handler.setLevel(level)
# stream_handler.setFormatter(color_formatter)
stream_handler.stream = open(
sys.stdout.fileno(), mode='w', encoding='utf-8', buffering=1
)
stream_handler.stream = open(sys.stdout.fileno(), mode='w', encoding='utf-8', buffering=1)
log_handlers: list[logging.Handler] = [
stream_handler,

View File

@@ -87,8 +87,7 @@ class Query(pydantic.BaseModel):
"""使用的函数,由前置处理器阶段设置"""
resp_messages: (
typing.Optional[list[llm_entities.Message]]
| typing.Optional[list[platform_message.MessageChain]]
typing.Optional[list[llm_entities.Message]] | typing.Optional[list[platform_message.MessageChain]]
) = []
"""由Process阶段生成的回复消息对象列表"""
@@ -130,13 +129,9 @@ class Conversation(pydantic.BaseModel):
messages: list[llm_entities.Message]
create_time: typing.Optional[datetime.datetime] = pydantic.Field(
default_factory=datetime.datetime.now
)
create_time: typing.Optional[datetime.datetime] = pydantic.Field(default_factory=datetime.datetime.now)
update_time: typing.Optional[datetime.datetime] = pydantic.Field(
default_factory=datetime.datetime.now
)
update_time: typing.Optional[datetime.datetime] = pydantic.Field(default_factory=datetime.datetime.now)
use_llm_model: requester.RuntimeLLMModel
@@ -162,17 +157,11 @@ class Session(pydantic.BaseModel):
using_conversation: typing.Optional[Conversation] = None
conversations: typing.Optional[list[Conversation]] = pydantic.Field(
default_factory=list
)
conversations: typing.Optional[list[Conversation]] = pydantic.Field(default_factory=list)
create_time: typing.Optional[datetime.datetime] = pydantic.Field(
default_factory=datetime.datetime.now
)
create_time: typing.Optional[datetime.datetime] = pydantic.Field(default_factory=datetime.datetime.now)
update_time: typing.Optional[datetime.datetime] = pydantic.Field(
default_factory=datetime.datetime.now
)
update_time: typing.Optional[datetime.datetime] = pydantic.Field(default_factory=datetime.datetime.now)
semaphore: typing.Optional[asyncio.Semaphore] = None
"""当前会话的信号量,用于限制并发"""

View File

@@ -11,16 +11,14 @@ class SensitiveWordMigration(migration.Migration):
async def need_migrate(self) -> bool:
"""判断当前环境是否需要运行此迁移"""
return os.path.exists(
'data/config/sensitive-words.json'
) and not os.path.exists('data/metadata/sensitive-words.json')
return os.path.exists('data/config/sensitive-words.json') and not os.path.exists(
'data/metadata/sensitive-words.json'
)
async def run(self):
"""执行迁移"""
# 移动文件
os.rename(
'data/config/sensitive-words.json', 'data/metadata/sensitive-words.json'
)
os.rename('data/config/sensitive-words.json', 'data/metadata/sensitive-words.json')
# 重新加载配置
await self.ap.sensitive_meta.load_config()

View File

@@ -23,9 +23,7 @@ class OpenAIConfigMigration(migration.Migration):
self.ap.provider_cfg.data['keys']['openai'] = old_openai_config['api-keys']
self.ap.provider_cfg.data['model'] = old_openai_config[
'chat-completions-params'
]['model']
self.ap.provider_cfg.data['model'] = old_openai_config['chat-completions-params']['model']
del old_openai_config['chat-completions-params']['model']

View File

@@ -15,8 +15,6 @@ class QCGCenterURLConfigMigration(migration.Migration):
"""执行迁移"""
if 'qcg-center-url' not in self.ap.system_cfg.data:
self.ap.system_cfg.data['qcg-center-url'] = (
'https://api.qchatgpt.rockchin.top/api/v2'
)
self.ap.system_cfg.data['qcg-center-url'] = 'https://api.qchatgpt.rockchin.top/api/v2'
await self.ap.system_cfg.dump_config()

View File

@@ -9,9 +9,7 @@ class AdFixwinConfigMigration(migration.Migration):
async def need_migrate(self) -> bool:
"""判断当前环境是否需要运行此迁移"""
return isinstance(
self.ap.pipeline_cfg.data['rate-limit']['fixwin']['default'], int
)
return isinstance(self.ap.pipeline_cfg.data['rate-limit']['fixwin']['default'], int)
async def run(self):
"""执行迁移"""
@@ -19,9 +17,7 @@ class AdFixwinConfigMigration(migration.Migration):
for session_name in self.ap.pipeline_cfg.data['rate-limit']['fixwin']:
temp_dict = {
'window-size': 60,
'limit': self.ap.pipeline_cfg.data['rate-limit']['fixwin'][
session_name
],
'limit': self.ap.pipeline_cfg.data['rate-limit']['fixwin'][session_name],
}
self.ap.pipeline_cfg.data['rate-limit']['fixwin'][session_name] = temp_dict

View File

@@ -9,10 +9,7 @@ class HttpApiConfigMigration(migration.Migration):
async def need_migrate(self) -> bool:
"""判断当前环境是否需要运行此迁移"""
return (
'http-api' not in self.ap.system_cfg.data
or 'persistence' not in self.ap.system_cfg.data
)
return 'http-api' not in self.ap.system_cfg.data or 'persistence' not in self.ap.system_cfg.data
async def run(self):
"""执行迁移"""

View File

@@ -11,8 +11,7 @@ class DifyAPITimeoutParamsMigration(migration.Migration):
"""判断当前环境是否需要运行此迁移"""
return (
'timeout' not in self.ap.provider_cfg.data['dify-service-api']['chat']
or 'timeout'
not in self.ap.provider_cfg.data['dify-service-api']['workflow']
or 'timeout' not in self.ap.provider_cfg.data['dify-service-api']['workflow']
or 'agent' not in self.ap.provider_cfg.data['dify-service-api']
)

View File

@@ -10,9 +10,7 @@ class SiliconFlowConfigMigration(migration.Migration):
async def need_migrate(self) -> bool:
"""判断当前环境是否需要运行此迁移"""
return (
'siliconflow-chat-completions' not in self.ap.provider_cfg.data['requester']
)
return 'siliconflow-chat-completions' not in self.ap.provider_cfg.data['requester']
async def run(self):
"""执行迁移"""

View File

@@ -13,17 +13,12 @@ class DifyThinkingConfigMigration(migration.Migration):
if 'options' not in self.ap.provider_cfg.data['dify-service-api']:
return True
if (
'convert-thinking-tips'
not in self.ap.provider_cfg.data['dify-service-api']['options']
):
if 'convert-thinking-tips' not in self.ap.provider_cfg.data['dify-service-api']['options']:
return True
return False
async def run(self):
"""执行迁移"""
self.ap.provider_cfg.data['dify-service-api']['options'] = {
'convert-thinking-tips': 'plain'
}
self.ap.provider_cfg.data['dify-service-api']['options'] = {'convert-thinking-tips': 'plain'}
await self.ap.provider_cfg.dump_config()

View File

@@ -24,8 +24,6 @@ class GewechatFileUrlConfigMigration(migration.Migration):
if adapter['adapter'] == 'gewechat':
if 'gewechat_file_url' not in adapter:
parsed_url = urlparse(adapter['gewechat_url'])
adapter['gewechat_file_url'] = (
f'{parsed_url.scheme}://{parsed_url.hostname}:2532'
)
adapter['gewechat_file_url'] = f'{parsed_url.scheme}://{parsed_url.hostname}:2532'
await self.ap.platform_cfg.dump_config()

View File

@@ -3,7 +3,7 @@ from __future__ import annotations
from .. import migration
@migration.migration_class("tg-dingtalk-markdown", 38)
@migration.migration_class('tg-dingtalk-markdown', 38)
class TgDingtalkMarkdownMigration(migration.Migration):
"""迁移"""
@@ -23,4 +23,3 @@ class TgDingtalkMarkdownMigration(migration.Migration):
if 'markdown_card' not in adapter:
adapter['markdown_card'] = False
await self.ap.platform_cfg.dump_config()

View File

@@ -3,20 +3,19 @@ from __future__ import annotations
from .. import migration
@migration.migration_class("modelscope-config-completion", 39)
@migration.migration_class('modelscope-config-completion', 39)
class ModelScopeConfigCompletionMigration(migration.Migration):
"""ModelScope配置迁移
"""
"""ModelScope配置迁移"""
async def need_migrate(self) -> bool:
"""判断当前环境是否需要运行此迁移
"""
return 'modelscope-chat-completions' not in self.ap.provider_cfg.data['requester'] \
"""判断当前环境是否需要运行此迁移"""
return (
'modelscope-chat-completions' not in self.ap.provider_cfg.data['requester']
or 'modelscope' not in self.ap.provider_cfg.data['keys']
)
async def run(self):
"""执行迁移
"""
"""执行迁移"""
if 'modelscope-chat-completions' not in self.ap.provider_cfg.data['requester']:
self.ap.provider_cfg.data['requester']['modelscope-chat-completions'] = {
'base-url': 'https://api-inference.modelscope.cn/v1',

View File

@@ -3,20 +3,19 @@ from __future__ import annotations
from .. import migration
@migration.migration_class("ppio-config", 40)
@migration.migration_class('ppio-config', 40)
class PPIOConfigMigration(migration.Migration):
"""PPIO配置迁移
"""
"""PPIO配置迁移"""
async def need_migrate(self) -> bool:
"""判断当前环境是否需要运行此迁移
"""
return 'ppio-chat-completions' not in self.ap.provider_cfg.data['requester'] \
"""判断当前环境是否需要运行此迁移"""
return (
'ppio-chat-completions' not in self.ap.provider_cfg.data['requester']
or 'ppio' not in self.ap.provider_cfg.data['keys']
)
async def run(self):
"""执行迁移
"""
"""执行迁移"""
if 'ppio-chat-completions' not in self.ap.provider_cfg.data['requester']:
self.ap.provider_cfg.data['requester']['ppio-chat-completions'] = {
'base-url': 'https://api.ppinfra.com/v3/openai',

View File

@@ -35,9 +35,7 @@ class TaskContext:
if action is not None:
self.set_current_action(action)
self._log(
f'{datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S")} | {self.current_action} | {msg}'
)
self._log(f'{datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S")} | {self.current_action} | {msg}')
def to_dict(self) -> dict:
return {'current_action': self.current_action, 'log': self.log}
@@ -104,9 +102,7 @@ class TaskWrapper:
name: str = '',
label: str = '',
context: TaskContext = None,
scopes: list[core_entities.LifecycleControlScope] = [
core_entities.LifecycleControlScope.APPLICATION
],
scopes: list[core_entities.LifecycleControlScope] = [core_entities.LifecycleControlScope.APPLICATION],
):
self.id = TaskWrapper._id_index
TaskWrapper._id_index += 1
@@ -141,7 +137,9 @@ class TaskWrapper:
exception_traceback = 'Traceback (most recent call last):\n'
for frame in self.task_stack:
exception_traceback += f' File "{frame.f_code.co_filename}", line {frame.f_lineno}, in {frame.f_code.co_name}\n'
exception_traceback += (
f' File "{frame.f_code.co_filename}", line {frame.f_lineno}, in {frame.f_code.co_name}\n'
)
exception_traceback += f' {self.assume_exception().__str__()}\n'
@@ -156,13 +154,9 @@ class TaskWrapper:
'runtime': {
'done': self.task.done(),
'state': self.task._state,
'exception': self.assume_exception().__str__()
if self.assume_exception() is not None
else None,
'exception': self.assume_exception().__str__() if self.assume_exception() is not None else None,
'exception_traceback': exception_traceback,
'result': self.assume_result().__str__()
if self.assume_result() is not None
else None,
'result': self.assume_result().__str__() if self.assume_result() is not None else None,
},
}
@@ -191,13 +185,9 @@ class AsyncTaskManager:
name: str = '',
label: str = '',
context: TaskContext = None,
scopes: list[core_entities.LifecycleControlScope] = [
core_entities.LifecycleControlScope.APPLICATION
],
scopes: list[core_entities.LifecycleControlScope] = [core_entities.LifecycleControlScope.APPLICATION],
) -> TaskWrapper:
wrapper = TaskWrapper(
self.ap, coro, task_type, kind, name, label, context, scopes
)
wrapper = TaskWrapper(self.ap, coro, task_type, kind, name, label, context, scopes)
self.tasks.append(wrapper)
return wrapper
@@ -208,9 +198,7 @@ class AsyncTaskManager:
name: str = '',
label: str = '',
context: TaskContext = None,
scopes: list[core_entities.LifecycleControlScope] = [
core_entities.LifecycleControlScope.APPLICATION
],
scopes: list[core_entities.LifecycleControlScope] = [core_entities.LifecycleControlScope.APPLICATION],
) -> TaskWrapper:
return self.create_task(coro, 'user', kind, name, label, context, scopes)
@@ -225,9 +213,7 @@ class AsyncTaskManager:
type: str = None,
) -> dict:
return {
'tasks': [
t.to_dict() for t in self.tasks if type is None or t.task_type == type
],
'tasks': [t.to_dict() for t in self.tasks if type is None or t.task_type == type],
'id_index': TaskWrapper._id_index,
}

View File

@@ -114,9 +114,7 @@ class Component(pydantic.BaseModel):
_execution: Execution
"""组件执行"""
def __init__(
self, owner: str, manifest: typing.Dict[str, typing.Any], rel_path: str
):
def __init__(self, owner: str, manifest: typing.Dict[str, typing.Any], rel_path: str):
super().__init__(
owner=owner,
manifest=manifest,
@@ -125,19 +123,12 @@ class Component(pydantic.BaseModel):
)
self._metadata = Metadata(**manifest['metadata'])
self._spec = manifest['spec']
self._execution = (
Execution(**manifest['execution']) if 'execution' in manifest else None
)
self._execution = Execution(**manifest['execution']) if 'execution' in manifest else None
@classmethod
def is_component_manifest(cls, manifest: typing.Dict[str, typing.Any]) -> bool:
"""判断是否为组件清单"""
return (
'apiVersion' in manifest
and 'kind' in manifest
and 'metadata' in manifest
and 'spec' in manifest
)
return 'apiVersion' in manifest and 'kind' in manifest and 'metadata' in manifest and 'spec' in manifest
@property
def kind(self) -> str:
@@ -200,9 +191,7 @@ class ComponentDiscoveryEngine:
def __init__(self, ap: app.Application):
self.ap = ap
def load_component_manifest(
self, path: str, owner: str = 'builtin', no_save: bool = False
) -> Component | None:
def load_component_manifest(self, path: str, owner: str = 'builtin', no_save: bool = False) -> Component | None:
"""加载组件清单"""
with open(path, 'r', encoding='utf-8') as f:
manifest = yaml.safe_load(f)
@@ -229,18 +218,12 @@ class ComponentDiscoveryEngine:
if depth > max_depth:
return
for file in os.listdir(path):
if (not os.path.isdir(os.path.join(path, file))) and (
file.endswith('.yaml') or file.endswith('.yml')
):
comp = self.load_component_manifest(
os.path.join(path, file), owner, no_save
)
if (not os.path.isdir(os.path.join(path, file))) and (file.endswith('.yaml') or file.endswith('.yml')):
comp = self.load_component_manifest(os.path.join(path, file), owner, no_save)
if comp is not None:
components.append(comp)
elif os.path.isdir(os.path.join(path, file)):
recursive_load_component_manifests_in_dir(
os.path.join(path, file), depth + 1
)
recursive_load_component_manifests_in_dir(os.path.join(path, file), depth + 1)
recursive_load_component_manifests_in_dir(path)
return components
@@ -259,18 +242,12 @@ class ComponentDiscoveryEngine:
for dir in group['fromDirs']:
path = dir['path']
max_depth = dir['maxDepth'] if 'maxDepth' in dir else 1
components.extend(
self.load_component_manifests_in_dir(
path, owner, no_save, max_depth
)
)
components.extend(self.load_component_manifests_in_dir(path, owner, no_save, max_depth))
return components
def discover_blueprint(self, blueprint_manifest_path: str, owner: str = 'builtin'):
"""发现蓝图"""
blueprint_manifest = self.load_component_manifest(
blueprint_manifest_path, owner, no_save=True
)
blueprint_manifest = self.load_component_manifest(blueprint_manifest_path, owner, no_save=True)
if blueprint_manifest is None:
raise ValueError(f'Invalid blueprint manifest: {blueprint_manifest_path}')
assert blueprint_manifest.kind == 'Blueprint', '`Kind` must be `Blueprint`'
@@ -297,9 +274,7 @@ class ComponentDiscoveryEngine:
return []
return self.components[kind]
def find_components(
self, kind: str, component_list: typing.List[Component]
) -> typing.List[Component]:
def find_components(self, kind: str, component_list: typing.List[Component]) -> typing.List[Component]:
"""查找组件"""
result: typing.List[Component] = []
for component in component_list:

View File

@@ -16,9 +16,7 @@ class Bot(Base):
enable = sqlalchemy.Column(sqlalchemy.Boolean, nullable=False, default=False)
use_pipeline_name = sqlalchemy.Column(sqlalchemy.String(255), nullable=True)
use_pipeline_uuid = sqlalchemy.Column(sqlalchemy.String(255), nullable=True)
created_at = sqlalchemy.Column(
sqlalchemy.DateTime, nullable=False, server_default=sqlalchemy.func.now()
)
created_at = sqlalchemy.Column(sqlalchemy.DateTime, nullable=False, server_default=sqlalchemy.func.now())
updated_at = sqlalchemy.Column(
sqlalchemy.DateTime,
nullable=False,

View File

@@ -16,9 +16,7 @@ class LLMModel(Base):
api_keys = sqlalchemy.Column(sqlalchemy.JSON, nullable=False)
abilities = sqlalchemy.Column(sqlalchemy.JSON, nullable=False, default=[])
extra_args = sqlalchemy.Column(sqlalchemy.JSON, nullable=False, default={})
created_at = sqlalchemy.Column(
sqlalchemy.DateTime, nullable=False, server_default=sqlalchemy.func.now()
)
created_at = sqlalchemy.Column(sqlalchemy.DateTime, nullable=False, server_default=sqlalchemy.func.now())
updated_at = sqlalchemy.Column(
sqlalchemy.DateTime,
nullable=False,

View File

@@ -11,9 +11,7 @@ class LegacyPipeline(Base):
uuid = sqlalchemy.Column(sqlalchemy.String(255), primary_key=True, unique=True)
name = sqlalchemy.Column(sqlalchemy.String(255), nullable=False)
description = sqlalchemy.Column(sqlalchemy.String(255), nullable=False)
created_at = sqlalchemy.Column(
sqlalchemy.DateTime, nullable=False, server_default=sqlalchemy.func.now()
)
created_at = sqlalchemy.Column(sqlalchemy.DateTime, nullable=False, server_default=sqlalchemy.func.now())
updated_at = sqlalchemy.Column(
sqlalchemy.DateTime,
nullable=False,
@@ -35,9 +33,7 @@ class PipelineRunRecord(Base):
uuid = sqlalchemy.Column(sqlalchemy.String(255), primary_key=True, unique=True)
pipeline_uuid = sqlalchemy.Column(sqlalchemy.String(255), nullable=False)
status = sqlalchemy.Column(sqlalchemy.String(255), nullable=False)
created_at = sqlalchemy.Column(
sqlalchemy.DateTime, nullable=False, server_default=sqlalchemy.func.now()
)
created_at = sqlalchemy.Column(sqlalchemy.DateTime, nullable=False, server_default=sqlalchemy.func.now())
updated_at = sqlalchemy.Column(
sqlalchemy.DateTime,
nullable=False,

View File

@@ -13,9 +13,7 @@ class PluginSetting(Base):
enabled = sqlalchemy.Column(sqlalchemy.Boolean, nullable=False, default=True)
priority = sqlalchemy.Column(sqlalchemy.Integer, nullable=False, default=0)
config = sqlalchemy.Column(sqlalchemy.JSON, nullable=False, default=dict)
created_at = sqlalchemy.Column(
sqlalchemy.DateTime, nullable=False, server_default=sqlalchemy.func.now()
)
created_at = sqlalchemy.Column(sqlalchemy.DateTime, nullable=False, server_default=sqlalchemy.func.now())
updated_at = sqlalchemy.Column(
sqlalchemy.DateTime,
nullable=False,

View File

@@ -9,9 +9,7 @@ class User(Base):
id = sqlalchemy.Column(sqlalchemy.Integer, primary_key=True)
user = sqlalchemy.Column(sqlalchemy.String(255), nullable=False)
password = sqlalchemy.Column(sqlalchemy.String(255), nullable=False)
created_at = sqlalchemy.Column(
sqlalchemy.DateTime, nullable=False, server_default=sqlalchemy.func.now()
)
created_at = sqlalchemy.Column(sqlalchemy.DateTime, nullable=False, server_default=sqlalchemy.func.now())
updated_at = sqlalchemy.Column(
sqlalchemy.DateTime,
nullable=False,

View File

@@ -11,6 +11,4 @@ class SQLiteDatabaseManager(database.BaseDatabaseManager):
async def initialize(self) -> None:
sqlite_path = 'data/langbot.db'
self.engine = sqlalchemy_asyncio.create_async_engine(
f'sqlite+aiosqlite:///{sqlite_path}'
)
self.engine = sqlalchemy_asyncio.create_async_engine(f'sqlite+aiosqlite:///{sqlite_path}')

View File

@@ -58,24 +58,18 @@ class PersistenceManager:
for item in metadata.initial_metadata:
# check if the item exists
result = await self.execute_async(
sqlalchemy.select(metadata.Metadata).where(
metadata.Metadata.key == item['key']
)
sqlalchemy.select(metadata.Metadata).where(metadata.Metadata.key == item['key'])
)
row = result.first()
if row is None:
await self.execute_async(
sqlalchemy.insert(metadata.Metadata).values(item)
)
await self.execute_async(sqlalchemy.insert(metadata.Metadata).values(item))
# write default pipeline
result = await self.execute_async(sqlalchemy.select(pipeline.LegacyPipeline))
if result.first() is None:
self.ap.logger.info('Creating default pipeline...')
pipeline_config = json.load(
open('templates/default-pipeline-config.json', 'r', encoding='utf-8')
)
pipeline_config = json.load(open('templates/default-pipeline-config.json', 'r', encoding='utf-8'))
pipeline_data = {
'uuid': str(uuid.uuid4()),
@@ -87,16 +81,12 @@ class PersistenceManager:
'config': pipeline_config,
}
await self.execute_async(
sqlalchemy.insert(pipeline.LegacyPipeline).values(pipeline_data)
)
await self.execute_async(sqlalchemy.insert(pipeline.LegacyPipeline).values(pipeline_data))
# =================================
# run migrations
database_version = await self.execute_async(
sqlalchemy.select(metadata.Metadata).where(
metadata.Metadata.key == 'database_version'
)
sqlalchemy.select(metadata.Metadata).where(metadata.Metadata.key == 'database_version')
)
database_version = int(database_version.fetchone()[1])
@@ -122,17 +112,11 @@ class PersistenceManager:
.values({'value': str(migration_instance.number)})
)
last_migration_number = migration_instance.number
self.ap.logger.info(
f'Migration {migration_instance.number} completed.'
)
self.ap.logger.info(f'Migration {migration_instance.number} completed.')
self.ap.logger.info(
f'Successfully upgraded database to version {last_migration_number}.'
)
self.ap.logger.info(f'Successfully upgraded database to version {last_migration_number}.')
async def execute_async(
self, *args, **kwargs
) -> sqlalchemy.engine.cursor.CursorResult:
async def execute_async(self, *args, **kwargs) -> sqlalchemy.engine.cursor.CursorResult:
async with self.get_db_engine().connect() as conn:
result = await conn.execute(*args, **kwargs)
await conn.commit()
@@ -141,9 +125,7 @@ class PersistenceManager:
def get_db_engine(self) -> sqlalchemy_asyncio.AsyncEngine:
return self.db.get_engine()
def serialize_model(
self, model: typing.Type[sqlalchemy.Base], data: sqlalchemy.Base
) -> dict:
def serialize_model(self, model: typing.Type[sqlalchemy.Base], data: sqlalchemy.Base) -> dict:
return {
column.name: getattr(data, column.name)
if not isinstance(getattr(data, column.name), (datetime.datetime))

View File

@@ -14,9 +14,7 @@ class BanSessionCheckStage(stage.PipelineStage):
async def initialize(self, pipeline_config: dict):
pass
async def process(
self, query: core_entities.Query, stage_inst_name: str
) -> entities.StageProcessResult:
async def process(self, query: core_entities.Query, stage_inst_name: str) -> entities.StageProcessResult:
found = False
mode = query.pipeline_config['trigger']['access-control']['mode']
@@ -41,11 +39,7 @@ class BanSessionCheckStage(stage.PipelineStage):
ctn = not found
return entities.StageProcessResult(
result_type=entities.ResultType.CONTINUE
if ctn
else entities.ResultType.INTERRUPT,
result_type=entities.ResultType.CONTINUE if ctn else entities.ResultType.INTERRUPT,
new_query=query,
console_notice=f'根据访问控制忽略消息: {query.launcher_type.value}_{query.launcher_id}'
if not ctn
else '',
console_notice=f'根据访问控制忽略消息: {query.launcher_type.value}_{query.launcher_id}' if not ctn else '',
)

View File

@@ -65,9 +65,7 @@ class ContentFilterStage(stage.PipelineStage):
"""
if query.pipeline_config['safety']['content-filter']['scope'] == 'output-msg':
return entities.StageProcessResult(
result_type=entities.ResultType.CONTINUE, new_query=query
)
return entities.StageProcessResult(result_type=entities.ResultType.CONTINUE, new_query=query)
else:
for filter in self.filter_chain:
if filter_entities.EnableStage.PRE in filter.enable_stages:
@@ -86,13 +84,9 @@ class ContentFilterStage(stage.PipelineStage):
elif result.level == filter_entities.ResultLevel.PASS: # 传到下一个
message = result.replacement
query.message_chain = platform_message.MessageChain(
platform_message.Plain(message)
)
query.message_chain = platform_message.MessageChain(platform_message.Plain(message))
return entities.StageProcessResult(
result_type=entities.ResultType.CONTINUE, new_query=query
)
return entities.StageProcessResult(result_type=entities.ResultType.CONTINUE, new_query=query)
async def _post_process(
self,
@@ -103,9 +97,7 @@ class ContentFilterStage(stage.PipelineStage):
只要是 PASS 或者 MASKED 的就通过此 filter将其 replacement 设置为message进入下一个 filter
"""
if query.pipeline_config['safety']['content-filter']['scope'] == 'income-msg':
return entities.StageProcessResult(
result_type=entities.ResultType.CONTINUE, new_query=query
)
return entities.StageProcessResult(result_type=entities.ResultType.CONTINUE, new_query=query)
else:
message = message.strip()
for filter in self.filter_chain:
@@ -127,13 +119,9 @@ class ContentFilterStage(stage.PipelineStage):
query.resp_messages[-1].content = message
return entities.StageProcessResult(
result_type=entities.ResultType.CONTINUE, new_query=query
)
return entities.StageProcessResult(result_type=entities.ResultType.CONTINUE, new_query=query)
async def process(
self, query: core_entities.Query, stage_inst_name: str
) -> entities.StageProcessResult:
async def process(self, query: core_entities.Query, stage_inst_name: str) -> entities.StageProcessResult:
"""处理"""
if stage_inst_name == 'PreContentFilterStage':
contain_non_text = False
@@ -147,9 +135,7 @@ class ContentFilterStage(stage.PipelineStage):
if contain_non_text:
self.ap.logger.debug('消息中包含非文本消息,跳过内容过滤器检查。')
return entities.StageProcessResult(
result_type=entities.ResultType.CONTINUE, new_query=query
)
return entities.StageProcessResult(result_type=entities.ResultType.CONTINUE, new_query=query)
return await self._pre_process(str(query.message_chain).strip(), query)
elif stage_inst_name == 'PostContentFilterStage':
@@ -162,8 +148,6 @@ class ContentFilterStage(stage.PipelineStage):
self.ap.logger.debug(
'resp_messages[-1] 不是 Message 类型或 query.resp_messages[-1].content 不是 str 类型,跳过内容过滤器检查。'
)
return entities.StageProcessResult(
result_type=entities.ResultType.CONTINUE, new_query=query
)
return entities.StageProcessResult(result_type=entities.ResultType.CONTINUE, new_query=query)
else:
raise ValueError(f'未知的 stage_inst_name: {stage_inst_name}')

View File

@@ -60,9 +60,7 @@ class ContentFilter(metaclass=abc.ABCMeta):
pass
@abc.abstractmethod
async def process(
self, query: core_entities.Query, message: str = None, image_url=None
) -> entities.FilterResult:
async def process(self, query: core_entities.Query, message: str = None, image_url=None) -> entities.FilterResult:
"""处理消息
分为前后阶段,具体取决于 enable_stages 的值。

View File

@@ -21,19 +21,13 @@ class BaiduCloudExamine(filter_model.ContentFilter):
BAIDU_EXAMINE_TOKEN_URL,
params={
'grant_type': 'client_credentials',
'client_id': self.ap.pipeline_cfg.data['baidu-cloud-examine'][
'api-key'
],
'client_secret': self.ap.pipeline_cfg.data['baidu-cloud-examine'][
'api-secret'
],
'client_id': self.ap.pipeline_cfg.data['baidu-cloud-examine']['api-key'],
'client_secret': self.ap.pipeline_cfg.data['baidu-cloud-examine']['api-secret'],
},
) as resp:
return (await resp.json())['access_token']
async def process(
self, query: core_entities.Query, message: str
) -> entities.FilterResult:
async def process(self, query: core_entities.Query, message: str) -> entities.FilterResult:
async with aiohttp.ClientSession() as session:
async with session.post(
BAIDU_EXAMINE_URL.format(await self._get_token()),

View File

@@ -13,9 +13,7 @@ class BanWordFilter(filter_model.ContentFilter):
async def initialize(self):
pass
async def process(
self, query: core_entities.Query, message: str
) -> entities.FilterResult:
async def process(self, query: core_entities.Query, message: str) -> entities.FilterResult:
found = False
for word in self.ap.sensitive_meta.data['words']:
@@ -31,9 +29,7 @@ class BanWordFilter(filter_model.ContentFilter):
self.ap.sensitive_meta.data['mask'] * len(match[i]),
)
else:
message = message.replace(
match[i], self.ap.sensitive_meta.data['mask_word']
)
message = message.replace(match[i], self.ap.sensitive_meta.data['mask_word'])
return entities.FilterResult(
level=entities.ResultLevel.MASKED if found else entities.ResultLevel.PASS,

View File

@@ -16,9 +16,7 @@ class ContentIgnore(filter_model.ContentFilter):
entities.EnableStage.PRE,
]
async def process(
self, query: core_entities.Query, message: str
) -> entities.FilterResult:
async def process(self, query: core_entities.Query, message: str) -> entities.FilterResult:
if 'prefix' in query.pipeline_config['trigger']['ignore-rules']:
for rule in query.pipeline_config['trigger']['ignore-rules']['prefix']:
if message.startswith(rule):

View File

@@ -16,9 +16,7 @@ class Controller:
def __init__(self, ap: app.Application):
self.ap = ap
self.semaphore = asyncio.Semaphore(
self.ap.instance_config.data['concurrency']['pipeline']
)
self.semaphore = asyncio.Semaphore(self.ap.instance_config.data['concurrency']['pipeline'])
async def consumer(self):
"""事件处理循环"""
@@ -32,9 +30,7 @@ class Controller:
for query in queries:
session = await self.ap.sess_mgr.get_session(query)
self.ap.logger.debug(
f'Checking query {query} session {session}'
)
self.ap.logger.debug(f'Checking query {query} session {session}')
if not session.semaphore.locked():
selected_query = query
@@ -55,22 +51,16 @@ class Controller:
# find pipeline
# Here firstly find the bot, then find the pipeline, in case the bot adapter's config is not the latest one.
# Like aiocqhttp, once a client is connected, even the adapter was updated and restarted, the existing client connection will not be affected.
bot = await self.ap.platform_mgr.get_bot_by_uuid(
selected_query.bot_uuid
)
bot = await self.ap.platform_mgr.get_bot_by_uuid(selected_query.bot_uuid)
if bot:
pipeline = (
await self.ap.pipeline_mgr.get_pipeline_by_uuid(
pipeline = await self.ap.pipeline_mgr.get_pipeline_by_uuid(
bot.bot_entity.use_pipeline_uuid
)
)
if pipeline:
await pipeline.run(selected_query)
async with self.ap.query_pool:
(
await self.ap.sess_mgr.get_session(selected_query)
).semaphore.release()
(await self.ap.sess_mgr.get_session(selected_query)).semaphore.release()
# 通知其他协程,有新的请求可以处理了
self.ap.query_pool.condition.notify_all()

View File

@@ -47,9 +47,7 @@ class LongTextProcessStage(stage.PipelineStage):
'未找到字体文件,且无法使用系统自带字体,更换为转发消息组件以发送长消息,您可以在配置文件中调整相关设置。'
)
pipeline_config['output']['long-text-processing'][
'strategy'
] = 'forward'
pipeline_config['output']['long-text-processing']['strategy'] = 'forward'
except Exception:
traceback.print_exc()
self.ap.logger.error(
@@ -58,9 +56,7 @@ class LongTextProcessStage(stage.PipelineStage):
)
)
pipeline_config['output']['long-text-processing']['strategy'] = (
'forward'
)
pipeline_config['output']['long-text-processing']['strategy'] = 'forward'
for strategy_cls in strategy.preregistered_strategies:
if strategy_cls.name == config['strategy']:
@@ -71,9 +67,7 @@ class LongTextProcessStage(stage.PipelineStage):
await self.strategy_impl.initialize()
async def process(
self, query: core_entities.Query, stage_inst_name: str
) -> entities.StageProcessResult:
async def process(self, query: core_entities.Query, stage_inst_name: str) -> entities.StageProcessResult:
# 检查是否包含非 Plain 组件
contains_non_plain = False
@@ -89,11 +83,7 @@ class LongTextProcessStage(stage.PipelineStage):
> query.pipeline_config['output']['long-text-processing']['threshold']
):
query.resp_message_chain[-1] = platform_message.MessageChain(
await self.strategy_impl.process(
str(query.resp_message_chain[-1]), query
)
await self.strategy_impl.process(str(query.resp_message_chain[-1]), query)
)
return entities.StageProcessResult(
result_type=entities.ResultType.CONTINUE, new_query=query
)
return entities.StageProcessResult(result_type=entities.ResultType.CONTINUE, new_query=query)

View File

@@ -13,9 +13,7 @@ Forward = platform_message.Forward
@strategy_model.strategy_class('forward')
class ForwardComponentStrategy(strategy_model.LongTextStrategy):
async def process(
self, message: str, query: core_entities.Query
) -> list[platform_message.MessageComponent]:
async def process(self, message: str, query: core_entities.Query) -> list[platform_message.MessageComponent]:
display = ForwardMessageDiaplay(
title='群聊的聊天记录',
brief='[聊天记录]',

View File

@@ -27,18 +27,14 @@ class Text2ImageStrategy(strategy_model.LongTextStrategy):
encoding='utf-8',
)
async def process(
self, message: str, query: core_entities.Query
) -> list[platform_message.MessageComponent]:
async def process(self, message: str, query: core_entities.Query) -> list[platform_message.MessageComponent]:
img_path = self.text_to_image(
text_str=message,
save_as='temp/{}.png'.format(int(time.time())),
query=query,
)
compressed_path, size = self.compress_image(
img_path, outfile='temp/{}_compressed.png'.format(int(time.time()))
)
compressed_path, size = self.compress_image(img_path, outfile='temp/{}_compressed.png'.format(int(time.time())))
with open(compressed_path, 'rb') as f:
img = f.read()
@@ -165,10 +161,7 @@ class Text2ImageStrategy(strategy_model.LongTextStrategy):
numbers = self.indexNumber(rest_text)
for number in numbers:
if (
number[1] < point < number[1] + len(number[0])
and number[1] != 0
):
if number[1] < point < number[1] + len(number[0]) and number[1] != 0:
point = number[1]
break
@@ -181,9 +174,7 @@ class Text2ImageStrategy(strategy_model.LongTextStrategy):
else:
continue
# 准备画布
img = Image.new(
'RGBA', (width, max(280, len(final_lines) * 35 + 65)), (255, 255, 255, 255)
)
img = Image.new('RGBA', (width, max(280, len(final_lines) * 35 + 65)), (255, 255, 255, 255))
draw = ImageDraw.Draw(img, mode='RGBA')
self.ap.logger.debug('正在绘制图片...')

View File

@@ -49,9 +49,7 @@ class LongTextStrategy(metaclass=abc.ABCMeta):
pass
@abc.abstractmethod
async def process(
self, message: str, query: core_entities.Query
) -> list[platform_message.MessageComponent]:
async def process(self, message: str, query: core_entities.Query) -> list[platform_message.MessageComponent]:
"""处理长文本
在 platform.json 中配置 long-text-process 字段,只要 文本长度超过了 threshold 就会调用此方法

View File

@@ -29,12 +29,8 @@ class ConversationMessageTruncator(stage.PipelineStage):
else:
raise ValueError(f'未知的截断器: {use_method}')
async def process(
self, query: core_entities.Query, stage_inst_name: str
) -> entities.StageProcessResult:
async def process(self, query: core_entities.Query, stage_inst_name: str) -> entities.StageProcessResult:
"""处理"""
query = await self.trun.truncate(query)
return entities.StageProcessResult(
result_type=entities.ResultType.CONTINUE, new_query=query
)
return entities.StageProcessResult(result_type=entities.ResultType.CONTINUE, new_query=query)

View File

@@ -79,26 +79,20 @@ class RuntimePipeline:
query.pipeline_config = self.pipeline_entity.config
await self.process_query(query)
async def _check_output(
self, query: entities.Query, result: pipeline_entities.StageProcessResult
):
async def _check_output(self, query: entities.Query, result: pipeline_entities.StageProcessResult):
"""检查输出"""
if result.user_notice:
# 处理str类型
if isinstance(result.user_notice, str):
result.user_notice = platform_message.MessageChain(
platform_message.Plain(result.user_notice)
)
result.user_notice = platform_message.MessageChain(platform_message.Plain(result.user_notice))
elif isinstance(result.user_notice, list):
result.user_notice = platform_message.MessageChain(*result.user_notice)
if query.pipeline_config['output']['misc']['at-sender'] and isinstance(
query.message_event, platform_events.GroupMessage
):
result.user_notice.insert(
0, platform_message.At(query.message_event.sender.id)
)
result.user_notice.insert(0, platform_message.At(query.message_event.sender.id))
await query.adapter.reply_message(
message_source=query.message_event,
@@ -150,37 +144,25 @@ class RuntimePipeline:
result = await result
if isinstance(result, pipeline_entities.StageProcessResult): # 直接返回结果
self.ap.logger.debug(
f'Stage {stage_container.inst_name} processed query {query} res {result}'
)
self.ap.logger.debug(f'Stage {stage_container.inst_name} processed query {query} res {result}')
await self._check_output(query, result)
if result.result_type == pipeline_entities.ResultType.INTERRUPT:
self.ap.logger.debug(
f'Stage {stage_container.inst_name} interrupted query {query}'
)
self.ap.logger.debug(f'Stage {stage_container.inst_name} interrupted query {query}')
break
elif result.result_type == pipeline_entities.ResultType.CONTINUE:
query = result.new_query
elif isinstance(result, typing.AsyncGenerator): # 生成器
self.ap.logger.debug(
f'Stage {stage_container.inst_name} processed query {query} gen'
)
self.ap.logger.debug(f'Stage {stage_container.inst_name} processed query {query} gen')
async for sub_result in result:
self.ap.logger.debug(
f'Stage {stage_container.inst_name} processed query {query} res {sub_result}'
)
self.ap.logger.debug(f'Stage {stage_container.inst_name} processed query {query} res {sub_result}')
await self._check_output(query, sub_result)
if sub_result.result_type == pipeline_entities.ResultType.INTERRUPT:
self.ap.logger.debug(
f'Stage {stage_container.inst_name} interrupted query {query}'
)
self.ap.logger.debug(f'Stage {stage_container.inst_name} interrupted query {query}')
break
elif (
sub_result.result_type == pipeline_entities.ResultType.CONTINUE
):
elif sub_result.result_type == pipeline_entities.ResultType.CONTINUE:
query = sub_result.new_query
await self._execute_from_stage(i + 1, query)
break
@@ -214,12 +196,8 @@ class RuntimePipeline:
await self._execute_from_stage(0, query)
except Exception as e:
inst_name = (
query.current_stage.inst_name if query.current_stage else 'unknown'
)
self.ap.logger.error(
f'处理请求时出错 query_id={query.query_id} stage={inst_name} : {e}'
)
inst_name = query.current_stage.inst_name if query.current_stage else 'unknown'
self.ap.logger.error(f'处理请求时出错 query_id={query.query_id} stage={inst_name} : {e}')
self.ap.logger.debug(f'Traceback: {traceback.format_exc()}')
finally:
self.ap.logger.debug(f'Query {query} processed')
@@ -241,18 +219,14 @@ class PipelineManager:
self.pipelines = []
async def initialize(self):
self.stage_dict = {
name: cls for name, cls in stage.preregistered_stages.items()
}
self.stage_dict = {name: cls for name, cls in stage.preregistered_stages.items()}
await self.load_pipelines_from_db()
async def load_pipelines_from_db(self):
self.ap.logger.info('Loading pipelines from db...')
result = await self.ap.persistence_mgr.execute_async(
sqlalchemy.select(persistence_pipeline.LegacyPipeline)
)
result = await self.ap.persistence_mgr.execute_async(sqlalchemy.select(persistence_pipeline.LegacyPipeline))
pipelines = result.all()
@@ -267,20 +241,14 @@ class PipelineManager:
| dict,
):
if isinstance(pipeline_entity, sqlalchemy.Row):
pipeline_entity = persistence_pipeline.LegacyPipeline(
**pipeline_entity._mapping
)
pipeline_entity = persistence_pipeline.LegacyPipeline(**pipeline_entity._mapping)
elif isinstance(pipeline_entity, dict):
pipeline_entity = persistence_pipeline.LegacyPipeline(**pipeline_entity)
# initialize stage containers according to pipeline_entity.stages
stage_containers: list[StageInstContainer] = []
for stage_name in pipeline_entity.stages:
stage_containers.append(
StageInstContainer(
inst_name=stage_name, inst=self.stage_dict[stage_name](self.ap)
)
)
stage_containers.append(StageInstContainer(inst_name=stage_name, inst=self.stage_dict[stage_name](self.ap)))
for stage_container in stage_containers:
await stage_container.inst.initialize(pipeline_entity.config)

View File

@@ -44,9 +44,7 @@ class PreProcessor(stage.PipelineStage):
query.use_llm_model = conversation.use_llm_model
query.use_funcs = (
conversation.use_funcs
if query.use_llm_model.model_entity.abilities.__contains__('tool_call')
else None
conversation.use_funcs if query.use_llm_model.model_entity.abilities.__contains__('tool_call') else None
)
query.variables = {
@@ -59,10 +57,9 @@ class PreProcessor(stage.PipelineStage):
# Check if this model supports vision, if not, remove all images
# TODO this checking should be performed in runner, and in this stage, the image should be reserved
if (
query.pipeline_config['ai']['runner']['runner'] == 'local-agent'
and not query.use_llm_model.model_entity.abilities.__contains__('vision')
):
if query.pipeline_config['ai']['runner'][
'runner'
] == 'local-agent' and not query.use_llm_model.model_entity.abilities.__contains__('vision'):
for msg in query.messages:
if isinstance(msg.content, list):
for me in msg.content:
@@ -78,14 +75,11 @@ class PreProcessor(stage.PipelineStage):
content_list.append(llm_entities.ContentElement.from_text(me.text))
plain_text += me.text
elif isinstance(me, platform_message.Image):
if (
query.pipeline_config['ai']['runner']['runner'] != 'local-agent'
or query.use_llm_model.model_entity.abilities.__contains__('vision')
):
if query.pipeline_config['ai']['runner'][
'runner'
] != 'local-agent' or query.use_llm_model.model_entity.abilities.__contains__('vision'):
if me.base64 is not None:
content_list.append(
llm_entities.ContentElement.from_image_base64(me.base64)
)
content_list.append(llm_entities.ContentElement.from_image_base64(me.base64))
query.variables['user_message_text'] = plain_text
@@ -104,6 +98,4 @@ class PreProcessor(stage.PipelineStage):
query.prompt.messages = event_ctx.event.default_prompt
query.messages = event_ctx.event.prompt
return entities.StageProcessResult(
result_type=entities.ResultType.CONTINUE, new_query=query
)
return entities.StageProcessResult(result_type=entities.ResultType.CONTINUE, new_query=query)

View File

@@ -49,13 +49,9 @@ class ChatMessageHandler(handler.MessageHandler):
query.resp_messages.append(mc)
yield entities.StageProcessResult(
result_type=entities.ResultType.CONTINUE, new_query=query
)
yield entities.StageProcessResult(result_type=entities.ResultType.CONTINUE, new_query=query)
else:
yield entities.StageProcessResult(
result_type=entities.ResultType.INTERRUPT, new_query=query
)
yield entities.StageProcessResult(result_type=entities.ResultType.INTERRUPT, new_query=query)
else:
if event_ctx.event.alter is not None:
# if isinstance(event_ctx.event, str): # 现在暂时不考虑多模态alter
@@ -69,34 +65,24 @@ class ChatMessageHandler(handler.MessageHandler):
runner = r(self.ap, query.pipeline_config)
break
else:
raise ValueError(
f'未找到请求运行器: {query.pipeline_config["ai"]["runner"]["runner"]}'
)
raise ValueError(f'未找到请求运行器: {query.pipeline_config["ai"]["runner"]["runner"]}')
async for result in runner.run(query):
query.resp_messages.append(result)
self.ap.logger.info(
f'对话({query.query_id})响应: {self.cut_str(result.readable_str())}'
)
self.ap.logger.info(f'对话({query.query_id})响应: {self.cut_str(result.readable_str())}')
if result.content is not None:
text_length += len(result.content)
yield entities.StageProcessResult(
result_type=entities.ResultType.CONTINUE, new_query=query
)
yield entities.StageProcessResult(result_type=entities.ResultType.CONTINUE, new_query=query)
query.session.using_conversation.messages.append(query.user_message)
query.session.using_conversation.messages.extend(query.resp_messages)
except Exception as e:
self.ap.logger.error(
f'对话({query.query_id})请求失败: {type(e).__name__} {str(e)}'
)
self.ap.logger.error(f'对话({query.query_id})请求失败: {type(e).__name__} {str(e)}')
hide_exception_info = query.pipeline_config['output']['misc'][
'hide-exception'
]
hide_exception_info = query.pipeline_config['output']['misc']['hide-exception']
yield entities.StageProcessResult(
result_type=entities.ResultType.INTERRUPT,

View File

@@ -21,10 +21,7 @@ class CommandHandler(handler.MessageHandler):
privilege = 1
if (
f'{query.launcher_type.value}_{query.launcher_id}'
in self.ap.instance_config.data['admins']
):
if f'{query.launcher_type.value}_{query.launcher_id}' in self.ap.instance_config.data['admins']:
privilege = 2
spt = command_text.split(' ')
@@ -54,25 +51,17 @@ class CommandHandler(handler.MessageHandler):
query.resp_messages.append(mc)
yield entities.StageProcessResult(
result_type=entities.ResultType.CONTINUE, new_query=query
)
yield entities.StageProcessResult(result_type=entities.ResultType.CONTINUE, new_query=query)
else:
yield entities.StageProcessResult(
result_type=entities.ResultType.INTERRUPT, new_query=query
)
yield entities.StageProcessResult(result_type=entities.ResultType.INTERRUPT, new_query=query)
else:
if event_ctx.event.alter is not None:
query.message_chain = platform_message.MessageChain(
[platform_message.Plain(event_ctx.event.alter)]
)
query.message_chain = platform_message.MessageChain([platform_message.Plain(event_ctx.event.alter)])
session = await self.ap.sess_mgr.get_session(query)
async for ret in self.ap.cmd_mgr.execute(
command_text=command_text, query=query, session=session
):
async for ret in self.ap.cmd_mgr.execute(command_text=command_text, query=query, session=session):
if ret.error is not None:
query.resp_messages.append(
llm_entities.Message(
@@ -81,13 +70,9 @@ class CommandHandler(handler.MessageHandler):
)
)
self.ap.logger.info(
f'命令({query.query_id})报错: {self.cut_str(str(ret.error))}'
)
self.ap.logger.info(f'命令({query.query_id})报错: {self.cut_str(str(ret.error))}')
yield entities.StageProcessResult(
result_type=entities.ResultType.CONTINUE, new_query=query
)
yield entities.StageProcessResult(result_type=entities.ResultType.CONTINUE, new_query=query)
elif ret.text is not None or ret.image_url is not None:
content: list[llm_entities.ContentElement] = []
@@ -95,9 +80,7 @@ class CommandHandler(handler.MessageHandler):
content.append(llm_entities.ContentElement.from_text(ret.text))
if ret.image_url is not None:
content.append(
llm_entities.ContentElement.from_image_url(ret.image_url)
)
content.append(llm_entities.ContentElement.from_image_url(ret.image_url))
query.resp_messages.append(
llm_entities.Message(
@@ -108,10 +91,6 @@ class CommandHandler(handler.MessageHandler):
self.ap.logger.info(f'命令返回: {self.cut_str(str(content[0]))}')
yield entities.StageProcessResult(
result_type=entities.ResultType.CONTINUE, new_query=query
)
yield entities.StageProcessResult(result_type=entities.ResultType.CONTINUE, new_query=query)
else:
yield entities.StageProcessResult(
result_type=entities.ResultType.INTERRUPT, new_query=query
)
yield entities.StageProcessResult(result_type=entities.ResultType.INTERRUPT, new_query=query)

View File

@@ -72,9 +72,7 @@ class FixedWindowAlgo(algo.ReteLimitAlgo):
if count >= limitation:
if query.pipeline_config['safety']['rate-limit']['strategy'] == 'drop':
return False
elif (
query.pipeline_config['safety']['rate-limit']['strategy'] == 'wait'
):
elif query.pipeline_config['safety']['rate-limit']['strategy'] == 'wait':
# 等待下一窗口
await asyncio.sleep(window_size - time.time() % window_size)

View File

@@ -15,9 +15,7 @@ from ...core import entities as core_entities
class SendResponseBackStage(stage.PipelineStage):
"""发送响应消息"""
async def process(
self, query: core_entities.Query, stage_inst_name: str
) -> entities.StageProcessResult:
async def process(self, query: core_entities.Query, stage_inst_name: str) -> entities.StageProcessResult:
"""处理"""
random_range = (
@@ -34,9 +32,7 @@ class SendResponseBackStage(stage.PipelineStage):
if query.pipeline_config['output']['misc']['at-sender'] and isinstance(
query.message_event, platform_events.GroupMessage
):
query.resp_message_chain[-1].insert(
0, platform_message.At(query.message_event.sender.id)
)
query.resp_message_chain[-1].insert(0, platform_message.At(query.message_event.sender.id))
quote_origin = query.pipeline_config['output']['misc']['quote-origin']
@@ -46,6 +42,4 @@ class SendResponseBackStage(stage.PipelineStage):
quote_origin=quote_origin,
)
return entities.StageProcessResult(
result_type=entities.ResultType.CONTINUE, new_query=query
)
return entities.StageProcessResult(result_type=entities.ResultType.CONTINUE, new_query=query)

View File

@@ -32,13 +32,9 @@ class GroupRespondRuleCheckStage(stage.PipelineStage):
await rule_inst.initialize()
self.rule_matchers.append(rule_inst)
async def process(
self, query: core_entities.Query, stage_inst_name: str
) -> entities.StageProcessResult:
async def process(self, query: core_entities.Query, stage_inst_name: str) -> entities.StageProcessResult:
if query.launcher_type.value != 'group': # 只处理群消息
return entities.StageProcessResult(
result_type=entities.ResultType.CONTINUE, new_query=query
)
return entities.StageProcessResult(result_type=entities.ResultType.CONTINUE, new_query=query)
rules = query.pipeline_config['trigger']['group-respond-rules']
@@ -49,9 +45,7 @@ class GroupRespondRuleCheckStage(stage.PipelineStage):
# use_rule = rules[str(query.launcher_id)]
for rule_matcher in self.rule_matchers: # 任意一个匹配就放行
res = await rule_matcher.match(
str(query.message_chain), query.message_chain, use_rule, query
)
res = await rule_matcher.match(str(query.message_chain), query.message_chain, use_rule, query)
if res.matching:
query.message_chain = res.replacement
@@ -60,6 +54,4 @@ class GroupRespondRuleCheckStage(stage.PipelineStage):
new_query=query,
)
return entities.StageProcessResult(
result_type=entities.ResultType.INTERRUPT, new_query=query
)
return entities.StageProcessResult(result_type=entities.ResultType.INTERRUPT, new_query=query)

View File

@@ -16,10 +16,7 @@ class AtBotRule(rule_model.GroupRespondRule):
rule_dict: dict,
query: core_entities.Query,
) -> entities.RuleJudgeResult:
if (
message_chain.has(platform_message.At(query.adapter.bot_account_id))
and rule_dict['at']
):
if message_chain.has(platform_message.At(query.adapter.bot_account_id)) and rule_dict['at']:
message_chain.remove(platform_message.At(query.adapter.bot_account_id))
if message_chain.has(

View File

@@ -18,6 +18,4 @@ class RandomRespRule(rule_model.GroupRespondRule):
) -> entities.RuleJudgeResult:
random_rate = rule_dict['random']
return entities.RuleJudgeResult(
matching=random.random() < random_rate, replacement=message_chain
)
return entities.RuleJudgeResult(matching=random.random() < random_rate, replacement=message_chain)

View File

@@ -34,29 +34,19 @@ class ResponseWrapper(stage.PipelineStage):
if isinstance(query.resp_messages[-1], platform_message.MessageChain):
query.resp_message_chain.append(query.resp_messages[-1])
yield entities.StageProcessResult(
result_type=entities.ResultType.CONTINUE, new_query=query
)
yield entities.StageProcessResult(result_type=entities.ResultType.CONTINUE, new_query=query)
else:
if query.resp_messages[-1].role == 'command':
query.resp_message_chain.append(
query.resp_messages[-1].get_content_platform_message_chain(
prefix_text='[bot] '
)
query.resp_messages[-1].get_content_platform_message_chain(prefix_text='[bot] ')
)
yield entities.StageProcessResult(
result_type=entities.ResultType.CONTINUE, new_query=query
)
yield entities.StageProcessResult(result_type=entities.ResultType.CONTINUE, new_query=query)
elif query.resp_messages[-1].role == 'plugin':
query.resp_message_chain.append(
query.resp_messages[-1].get_content_platform_message_chain()
)
query.resp_message_chain.append(query.resp_messages[-1].get_content_platform_message_chain())
yield entities.StageProcessResult(
result_type=entities.ResultType.CONTINUE, new_query=query
)
yield entities.StageProcessResult(result_type=entities.ResultType.CONTINUE, new_query=query)
else:
if query.resp_messages[-1].role == 'assistant':
result = query.resp_messages[-1]
@@ -77,14 +67,55 @@ class ResponseWrapper(stage.PipelineStage):
prefix='',
response_text=reply_text,
finish_reason='stop',
funcs_called=[
fc.function.name for fc in result.tool_calls
]
funcs_called=[fc.function.name for fc in result.tool_calls]
if result.tool_calls is not None
else [],
query=query,
)
)
if event_ctx.is_prevented_default():
yield entities.StageProcessResult(
result_type=entities.ResultType.INTERRUPT,
new_query=query,
)
else:
if event_ctx.event.reply is not None:
query.resp_message_chain.append(platform_message.MessageChain(event_ctx.event.reply))
else:
query.resp_message_chain.append(result.get_content_platform_message_chain())
yield entities.StageProcessResult(
result_type=entities.ResultType.CONTINUE,
new_query=query,
)
if result.tool_calls is not None and len(result.tool_calls) > 0: # 有函数调用
function_names = [tc.function.name for tc in result.tool_calls]
reply_text = f'调用函数 {".".join(function_names)}...'
query.resp_message_chain.append(
platform_message.MessageChain([platform_message.Plain(reply_text)])
)
if query.pipeline_config['output']['misc']['track-function-calls']:
event_ctx = await self.ap.plugin_mgr.emit_event(
event=events.NormalMessageResponded(
launcher_type=query.launcher_type.value,
launcher_id=query.launcher_id,
sender_id=query.sender_id,
session=session,
prefix='',
response_text=reply_text,
finish_reason='stop',
funcs_called=[fc.function.name for fc in result.tool_calls]
if result.tool_calls is not None
else [],
query=query,
)
)
if event_ctx.is_prevented_default():
yield entities.StageProcessResult(
result_type=entities.ResultType.INTERRUPT,
@@ -98,66 +129,7 @@ class ResponseWrapper(stage.PipelineStage):
else:
query.resp_message_chain.append(
result.get_content_platform_message_chain()
)
yield entities.StageProcessResult(
result_type=entities.ResultType.CONTINUE,
new_query=query,
)
if (
result.tool_calls is not None and len(result.tool_calls) > 0
): # 有函数调用
function_names = [tc.function.name for tc in result.tool_calls]
reply_text = f'调用函数 {".".join(function_names)}...'
query.resp_message_chain.append(
platform_message.MessageChain(
[platform_message.Plain(reply_text)]
)
)
if query.pipeline_config['output']['misc'][
'track-function-calls'
]:
event_ctx = await self.ap.plugin_mgr.emit_event(
event=events.NormalMessageResponded(
launcher_type=query.launcher_type.value,
launcher_id=query.launcher_id,
sender_id=query.sender_id,
session=session,
prefix='',
response_text=reply_text,
finish_reason='stop',
funcs_called=[
fc.function.name for fc in result.tool_calls
]
if result.tool_calls is not None
else [],
query=query,
)
)
if event_ctx.is_prevented_default():
yield entities.StageProcessResult(
result_type=entities.ResultType.INTERRUPT,
new_query=query,
)
else:
if event_ctx.event.reply is not None:
query.resp_message_chain.append(
platform_message.MessageChain(
event_ctx.event.reply
)
)
else:
query.resp_message_chain.append(
platform_message.MessageChain(
[platform_message.Plain(reply_text)]
)
platform_message.MessageChain([platform_message.Plain(reply_text)])
)
yield entities.StageProcessResult(

View File

@@ -32,9 +32,7 @@ class MessagePlatformAdapter(metaclass=abc.ABCMeta):
self.config = config
self.ap = ap
async def send_message(
self, target_type: str, target_id: str, message: platform_message.MessageChain
):
async def send_message(self, target_type: str, target_id: str, message: platform_message.MessageChain):
"""主动发送消息
Args:
@@ -66,9 +64,7 @@ class MessagePlatformAdapter(metaclass=abc.ABCMeta):
def register_listener(
self,
event_type: typing.Type[platform_message.Event],
callback: typing.Callable[
[platform_message.Event, MessagePlatformAdapter], None
],
callback: typing.Callable[[platform_message.Event, MessagePlatformAdapter], None],
):
"""注册事件监听器
@@ -81,9 +77,7 @@ class MessagePlatformAdapter(metaclass=abc.ABCMeta):
def unregister_listener(
self,
event_type: typing.Type[platform_message.Event],
callback: typing.Callable[
[platform_message.Event, MessagePlatformAdapter], None
],
callback: typing.Callable[[platform_message.Event, MessagePlatformAdapter], None],
):
"""注销事件监听器

View File

@@ -132,14 +132,10 @@ class PlatformManager:
self.adapter_dict = {}
async def initialize(self):
self.adapter_components = self.ap.discover.get_components_by_kind(
'MessagePlatformAdapter'
)
self.adapter_components = self.ap.discover.get_components_by_kind('MessagePlatformAdapter')
adapter_dict: dict[str, type[msadapter.MessagePlatformAdapter]] = {}
for component in self.adapter_components:
adapter_dict[component.metadata.name] = (
component.get_python_component_class()
)
adapter_dict[component.metadata.name] = component.get_python_component_class()
self.adapter_dict = adapter_dict
await self.load_bots_from_db()
@@ -152,9 +148,7 @@ class PlatformManager:
self.bots = []
result = await self.ap.persistence_mgr.execute_async(
sqlalchemy.select(persistence_bot.Bot)
)
result = await self.ap.persistence_mgr.execute_async(sqlalchemy.select(persistence_bot.Bot))
bots = result.all()
@@ -172,13 +166,9 @@ class PlatformManager:
elif isinstance(bot_entity, dict):
bot_entity = persistence_bot.Bot(**bot_entity)
adapter_inst = self.adapter_dict[bot_entity.adapter](
bot_entity.adapter_config, self.ap
)
adapter_inst = self.adapter_dict[bot_entity.adapter](bot_entity.adapter_config, self.ap)
runtime_bot = RuntimeBot(
ap=self.ap, bot_entity=bot_entity, adapter=adapter_inst
)
runtime_bot = RuntimeBot(ap=self.ap, bot_entity=bot_entity, adapter=adapter_inst)
await runtime_bot.initialize()
@@ -209,9 +199,7 @@ class PlatformManager:
return component.to_plain_dict()
return None
def get_available_adapter_manifest_by_name(
self, name: str
) -> engine.Component | None:
def get_available_adapter_manifest_by_name(self, name: str) -> engine.Component | None:
for component in self.adapter_components:
if component.metadata.name == name:
return component

View File

@@ -58,13 +58,7 @@ class AiocqhttpMessageConverter(adapter.MessageConverter):
msg_list.append(aiocqhttp.MessageSegment.record(msg.path))
elif type(msg) is platform_message.Forward:
for node in msg.node_list:
msg_list.extend(
(
await AiocqhttpMessageConverter.yiri2target(
node.message_chain
)
)[0]
)
msg_list.extend((await AiocqhttpMessageConverter.yiri2target(node.message_chain))[0])
else:
msg_list.append(aiocqhttp.MessageSegment.text(str(msg)))
@@ -77,9 +71,7 @@ class AiocqhttpMessageConverter(adapter.MessageConverter):
yiri_msg_list = []
yiri_msg_list.append(
platform_message.Source(id=message_id, time=datetime.datetime.now())
)
yiri_msg_list.append(platform_message.Source(id=message_id, time=datetime.datetime.now()))
for msg in message:
if msg.type == 'at':
@@ -94,14 +86,8 @@ class AiocqhttpMessageConverter(adapter.MessageConverter):
elif msg.type == 'text':
yiri_msg_list.append(platform_message.Plain(text=msg.data['text']))
elif msg.type == 'image':
image_base64, image_format = await image.qq_image_url_to_base64(
msg.data['url']
)
yiri_msg_list.append(
platform_message.Image(
base64=f'data:image/{image_format};base64,{image_base64}'
)
)
image_base64, image_format = await image.qq_image_url_to_base64(msg.data['url'])
yiri_msg_list.append(platform_message.Image(base64=f'data:image/{image_format};base64,{image_base64}'))
chain = platform_message.MessageChain(yiri_msg_list)
@@ -115,9 +101,7 @@ class AiocqhttpEventConverter(adapter.EventConverter):
@staticmethod
async def target2yiri(event: aiocqhttp.Event):
yiri_chain = await AiocqhttpMessageConverter.target2yiri(
event.message, event.message_id
)
yiri_chain = await AiocqhttpMessageConverter.target2yiri(event.message, event.message_id)
if event.message_type == 'group':
permission = 'MEMBER'
@@ -137,9 +121,7 @@ class AiocqhttpEventConverter(adapter.EventConverter):
name=event.sender['nickname'],
permission=platform_entities.Permission.Member,
),
special_title=event.sender['title']
if 'title' in event.sender
else '',
special_title=event.sender['title'] if 'title' in event.sender else '',
join_timestamp=0,
last_speak_timestamp=0,
mute_time_remaining=0,
@@ -191,9 +173,7 @@ class AiocqhttpAdapter(adapter.MessagePlatformAdapter):
else:
self.bot = aiocqhttp.CQHttp()
async def send_message(
self, target_type: str, target_id: str, message: platform_message.MessageChain
):
async def send_message(self, target_type: str, target_id: str, message: platform_message.MessageChain):
aiocq_msg = (await AiocqhttpMessageConverter.yiri2target(message))[0]
if target_type == 'group':
@@ -207,14 +187,10 @@ class AiocqhttpAdapter(adapter.MessagePlatformAdapter):
message: platform_message.MessageChain,
quote_origin: bool = False,
):
aiocq_event = await AiocqhttpEventConverter.yiri2target(
message_source, self.bot_account_id
)
aiocq_event = await AiocqhttpEventConverter.yiri2target(message_source, self.bot_account_id)
aiocq_msg = (await AiocqhttpMessageConverter.yiri2target(message))[0]
if quote_origin:
aiocq_msg = (
aiocqhttp.MessageSegment.reply(aiocq_event.message_id) + aiocq_msg
)
aiocq_msg = aiocqhttp.MessageSegment.reply(aiocq_event.message_id) + aiocq_msg
return await self.bot.send(aiocq_event, aiocq_msg)
@@ -224,16 +200,12 @@ class AiocqhttpAdapter(adapter.MessagePlatformAdapter):
def register_listener(
self,
event_type: typing.Type[platform_events.Event],
callback: typing.Callable[
[platform_events.Event, adapter.MessagePlatformAdapter], None
],
callback: typing.Callable[[platform_events.Event, adapter.MessagePlatformAdapter], None],
):
async def on_message(event: aiocqhttp.Event):
self.bot_account_id = event.self_id
try:
return await callback(
await self.event_converter.target2yiri(event), self
)
return await callback(await self.event_converter.target2yiri(event), self)
except Exception:
traceback.print_exc()
@@ -245,9 +217,7 @@ class AiocqhttpAdapter(adapter.MessagePlatformAdapter):
def unregister_listener(
self,
event_type: typing.Type[platform_events.Event],
callback: typing.Callable[
[platform_events.Event, adapter.MessagePlatformAdapter], None
],
callback: typing.Callable[[platform_events.Event, adapter.MessagePlatformAdapter], None],
):
return super().unregister_listener(event_type, callback)

View File

@@ -22,9 +22,7 @@ class DingTalkMessageConverter(adapter.MessageConverter):
async def target2yiri(event: DingTalkEvent, bot_name: str):
yiri_msg_list = []
yiri_msg_list.append(
platform_message.Source(
id=event.incoming_message.message_id, time=datetime.datetime.now()
)
platform_message.Source(id=event.incoming_message.message_id, time=datetime.datetime.now())
)
for atUser in event.incoming_message.at_users:
@@ -133,9 +131,7 @@ class DingTalkAdapter(adapter.MessagePlatformAdapter):
content = await DingTalkMessageConverter.yiri2target(message)
await self.bot.send_message(content, incoming_message)
async def send_message(
self, target_type: str, target_id: str, message: platform_message.MessageChain
):
async def send_message(self, target_type: str, target_id: str, message: platform_message.MessageChain):
content = await DingTalkMessageConverter.yiri2target(message)
if target_type == 'person':
await self.bot.send_proactive_message_to_one(target_id, content)
@@ -145,16 +141,12 @@ class DingTalkAdapter(adapter.MessagePlatformAdapter):
def register_listener(
self,
event_type: typing.Type[platform_events.Event],
callback: typing.Callable[
[platform_events.Event, adapter.MessagePlatformAdapter], None
],
callback: typing.Callable[[platform_events.Event, adapter.MessagePlatformAdapter], None],
):
async def on_message(event: DingTalkEvent):
try:
return await callback(
await self.event_converter.target2yiri(
event, self.config['robot_name']
),
await self.event_converter.target2yiri(event, self.config['robot_name']),
self,
)
except Exception:
@@ -174,8 +166,6 @@ class DingTalkAdapter(adapter.MessagePlatformAdapter):
async def unregister_listener(
self,
event_type: type,
callback: typing.Callable[
[platform_events.Event, MessagePlatformAdapter], None
],
callback: typing.Callable[[platform_events.Event, MessagePlatformAdapter], None],
):
return super().unregister_listener(event_type, callback)

View File

@@ -45,9 +45,7 @@ class DiscordMessageConverter(adapter.MessageConverter):
with open(ele.path, 'rb') as f:
image_bytes = f.read()
image_files.append(
discord.File(fp=image_bytes, filename=f'{uuid.uuid4()}.png')
)
image_files.append(discord.File(fp=image_bytes, filename=f'{uuid.uuid4()}.png'))
elif isinstance(ele, platform_message.Plain):
text_string += ele.text
elif isinstance(ele, platform_message.Forward):
@@ -65,9 +63,7 @@ class DiscordMessageConverter(adapter.MessageConverter):
async def target2yiri(message: discord.Message) -> platform_message.MessageChain:
lb_msg_list = []
msg_create_time = datetime.datetime.fromtimestamp(
int(message.created_at.timestamp())
)
msg_create_time = datetime.datetime.fromtimestamp(int(message.created_at.timestamp()))
lb_msg_list.append(platform_message.Source(id=message.id, time=msg_create_time))
@@ -97,11 +93,7 @@ class DiscordMessageConverter(adapter.MessageConverter):
else:
mid_at_component.append(platform_message.At(target=mid_at[2:-1]))
return (
text_element_recur(text_split[0])
+ mid_at_component
+ text_element_recur(text_split[1])
)
return text_element_recur(text_split[0]) + mid_at_component + text_element_recur(text_split[1])
else:
return [platform_message.Plain(text=text_ele)]
@@ -114,11 +106,7 @@ class DiscordMessageConverter(adapter.MessageConverter):
image_data = await response.read()
image_base64 = base64.b64encode(image_data).decode('utf-8')
image_format = response.headers['Content-Type']
element_list.append(
platform_message.Image(
base64=f'data:{image_format};base64,{image_base64}'
)
)
element_list.append(platform_message.Image(base64=f'data:{image_format};base64,{image_base64}'))
return platform_message.MessageChain(element_list)
@@ -208,9 +196,7 @@ class DiscordAdapter(adapter.MessagePlatformAdapter):
self.bot = MyClient(intents=intents, **args)
async def send_message(
self, target_type: str, target_id: str, message: platform_message.MessageChain
):
async def send_message(self, target_type: str, target_id: str, message: platform_message.MessageChain):
pass
async def reply_message(
@@ -243,18 +229,14 @@ class DiscordAdapter(adapter.MessagePlatformAdapter):
def register_listener(
self,
event_type: typing.Type[platform_events.Event],
callback: typing.Callable[
[platform_events.Event, adapter.MessagePlatformAdapter], None
],
callback: typing.Callable[[platform_events.Event, adapter.MessagePlatformAdapter], None],
):
self.listeners[event_type] = callback
def unregister_listener(
self,
event_type: typing.Type[platform_events.Event],
callback: typing.Callable[
[platform_events.Event, adapter.MessagePlatformAdapter], None
],
callback: typing.Callable[[platform_events.Event, adapter.MessagePlatformAdapter], None],
):
self.listeners.pop(event_type)

View File

@@ -40,14 +40,10 @@ class GewechatMessageConverter(adapter.MessageConverter):
content_list.append({'type': 'image', 'image': component.url})
elif isinstance(component, platform_message.Voice):
content_list.append(
{'type': 'voice', 'url': component.url, 'length': component.length}
)
content_list.append({'type': 'voice', 'url': component.url, 'length': component.length})
elif isinstance(component, platform_message.Forward):
for node in component.node_list:
content_list.extend(
await GewechatMessageConverter.yiri2target(node.message_chain)
)
content_list.extend(await GewechatMessageConverter.yiri2target(node.message_chain))
content_list.append({'type': 'image', 'image': component.url})
elif isinstance(component, platform_message.WeChatMiniPrograms):
content_list.append(
@@ -88,44 +84,26 @@ class GewechatMessageConverter(adapter.MessageConverter):
}
)
elif isinstance(component, platform_message.WeChatForwardLink):
content_list.append(
{'type': 'WeChatForwardLink', 'xml_data': component.xml_data}
)
content_list.append({'type': 'WeChatForwardLink', 'xml_data': component.xml_data})
elif isinstance(component, platform_message.Voice):
content_list.append(
{'type': 'voice', 'url': component.url, 'length': component.length}
)
content_list.append({'type': 'voice', 'url': component.url, 'length': component.length})
elif isinstance(component, platform_message.WeChatForwardImage):
content_list.append(
{'type': 'WeChatForwardImage', 'xml_data': component.xml_data}
)
content_list.append({'type': 'WeChatForwardImage', 'xml_data': component.xml_data})
elif isinstance(component, platform_message.WeChatForwardFile):
content_list.append(
{'type': 'WeChatForwardFile', 'xml_data': component.xml_data}
)
content_list.append({'type': 'WeChatForwardFile', 'xml_data': component.xml_data})
elif isinstance(component, platform_message.WeChatAppMsg):
content_list.append(
{'type': 'WeChatAppMsg', 'app_msg': component.app_msg}
)
content_list.append({'type': 'WeChatAppMsg', 'app_msg': component.app_msg})
# 引用消息转发
elif isinstance(component, platform_message.WeChatForwardQuote):
content_list.append(
{'type': 'WeChatAppMsg', 'app_msg': component.app_msg}
)
content_list.append({'type': 'WeChatAppMsg', 'app_msg': component.app_msg})
elif isinstance(component, platform_message.Forward):
for node in component.node_list:
if node.message_chain:
content_list.extend(
await GewechatMessageConverter.yiri2target(
node.message_chain
)
)
content_list.extend(await GewechatMessageConverter.yiri2target(node.message_chain))
return content_list
async def target2yiri(
self, message: dict, bot_account_id: str
) -> platform_message.MessageChain:
async def target2yiri(self, message: dict, bot_account_id: str) -> platform_message.MessageChain:
"""外部消息转平台消息"""
# 数据预处理
message_list = []
@@ -163,28 +141,20 @@ class GewechatMessageConverter(adapter.MessageConverter):
return platform_message.MessageChain(message_list)
async def _handler_text(
self, message: Optional[dict], content_no_preifx: str
) -> platform_message.MessageChain:
async def _handler_text(self, message: Optional[dict], content_no_preifx: str) -> platform_message.MessageChain:
"""处理文本消息 (msg_type=1)"""
if message and self._is_group_message(message):
pattern = r'@\S{1,20}'
content_no_preifx = re.sub(pattern, '', content_no_preifx)
return platform_message.MessageChain(
[platform_message.Plain(content_no_preifx)]
)
return platform_message.MessageChain([platform_message.Plain(content_no_preifx)])
async def _handler_image(
self, message: Optional[dict], content_no_preifx: str
) -> platform_message.MessageChain:
async def _handler_image(self, message: Optional[dict], content_no_preifx: str) -> platform_message.MessageChain:
"""处理图像消息 (msg_type=3)"""
try:
image_xml = content_no_preifx
if not image_xml:
return platform_message.MessageChain(
[platform_message.Unknown('[图片内容为空]')]
)
return platform_message.MessageChain([platform_message.Unknown('[图片内容为空]')])
base64_str, image_format = await image.get_gewechat_image_base64(
gewechat_url=self.config['gewechat_url'],
@@ -196,21 +166,15 @@ class GewechatMessageConverter(adapter.MessageConverter):
)
elements = [
platform_message.Image(
base64=f'data:image/{image_format};base64,{base64_str}'
),
platform_message.Image(base64=f'data:image/{image_format};base64,{base64_str}'),
platform_message.WeChatForwardImage(xml_data=image_xml), # 微信消息转发
]
return platform_message.MessageChain(elements)
except Exception as e:
print(f'处理图片失败: {str(e)}')
return platform_message.MessageChain(
[platform_message.Unknown('[图片处理失败]')]
)
return platform_message.MessageChain([platform_message.Unknown('[图片处理失败]')])
async def _handler_voice(
self, message: Optional[dict], content_no_preifx: str
) -> platform_message.MessageChain:
async def _handler_voice(self, message: Optional[dict], content_no_preifx: str) -> platform_message.MessageChain:
"""处理语音消息 (msg_type=34)"""
message_List = []
try:
@@ -223,9 +187,7 @@ class GewechatMessageConverter(adapter.MessageConverter):
return platform_message.MessageChain(message_List)
# 转换为平台支持的语音格式(如 Silk 格式)
voice_element = platform_message.Voice(
base64=f'data:audio/silk;base64,{audio_base64}'
)
voice_element = platform_message.Voice(base64=f'data:audio/silk;base64,{audio_base64}')
message_List.append(voice_element)
except KeyError as e:
@@ -237,9 +199,7 @@ class GewechatMessageConverter(adapter.MessageConverter):
return platform_message.MessageChain(message_List)
async def _handler_compound(
self, message: Optional[dict], content_no_preifx: str
) -> platform_message.MessageChain:
async def _handler_compound(self, message: Optional[dict], content_no_preifx: str) -> platform_message.MessageChain:
"""处理复合消息 (msg_type=49),根据子类型分派"""
try:
xml_data = ET.fromstring(content_no_preifx)
@@ -254,33 +214,21 @@ class GewechatMessageConverter(adapter.MessageConverter):
'6': self._handler_compound_file,
'33': self._handler_compound_mini_program,
'36': self._handler_compound_mini_program,
'2000': partial(
self._handler_compound_unsupported, text='[转账消息]'
),
'2001': partial(
self._handler_compound_unsupported, text='[红包消息]'
),
'51': partial(
self._handler_compound_unsupported, text='[视频号消息]'
),
'2000': partial(self._handler_compound_unsupported, text='[转账消息]'),
'2001': partial(self._handler_compound_unsupported, text='[红包消息]'),
'51': partial(self._handler_compound_unsupported, text='[视频号消息]'),
}
handler = sub_handler_map.get(
data_type, self._handler_compound_unsupported
)
handler = sub_handler_map.get(data_type, self._handler_compound_unsupported)
return await handler(
message=message, # 原始msg
xml_data=xml_data, # xml数据
)
else:
return platform_message.MessageChain(
[platform_message.Unknown(text=content_no_preifx)]
)
return platform_message.MessageChain([platform_message.Unknown(text=content_no_preifx)])
except Exception as e:
print(f'解析复合消息失败: {str(e)}')
return platform_message.MessageChain(
[platform_message.Unknown(text=content_no_preifx)]
)
return platform_message.MessageChain([platform_message.Unknown(text=content_no_preifx)])
async def _handler_compound_quote(
self, message: Optional[dict], xml_data: ET.Element
@@ -296,9 +244,7 @@ class GewechatMessageConverter(adapter.MessageConverter):
user_data = appmsg_data.findtext('.//title') or ''
quote_data = appmsg_data.find('.//refermsg').findtext('.//content')
message_list.append(
platform_message.WeChatForwardQuote(
app_msg=ET.tostring(appmsg_data, encoding='unicode')
)
platform_message.WeChatForwardQuote(app_msg=ET.tostring(appmsg_data, encoding='unicode'))
)
# quote_data原始的消息
if quote_data:
@@ -311,22 +257,14 @@ class GewechatMessageConverter(adapter.MessageConverter):
# 引用消息展开
quote_data_xml = ET.fromstring(quote_data)
if quote_data_xml.find('img'):
quote_data_message_list.extend(
await self._handler_image(None, quote_data)
)
quote_data_message_list.extend(await self._handler_image(None, quote_data))
elif quote_data_xml.find('voicemsg'):
quote_data_message_list.extend(
await self._handler_voice(None, quote_data)
)
quote_data_message_list.extend(await self._handler_voice(None, quote_data))
elif quote_data_xml.find('videomsg'):
quote_data_message_list.extend(
await self._handler_default(None, quote_data)
) # 先不处理
quote_data_message_list.extend(await self._handler_default(None, quote_data)) # 先不处理
else:
# appmsg
quote_data_message_list.extend(
await self._handler_compound(None, quote_data)
)
quote_data_message_list.extend(await self._handler_compound(None, quote_data))
except Exception as e:
print(f'处理引用消息异常 expcetion:{e}')
quote_data_message_list.append(platform_message.Plain(quote_data))
@@ -351,18 +289,12 @@ class GewechatMessageConverter(adapter.MessageConverter):
# print(f"quote_message_chain plain [msg_type={comp.type}][message={comp.text}]")
return platform_message.MessageChain(message_list)
async def _handler_compound_file(
self, message: dict, xml_data: ET.Element
) -> platform_message.MessageChain:
async def _handler_compound_file(self, message: dict, xml_data: ET.Element) -> platform_message.MessageChain:
"""处理文件消息 (data_type=6)"""
xml_data_str = ET.tostring(xml_data, encoding='unicode')
return platform_message.MessageChain(
[platform_message.WeChatForwardFile(xml_data=xml_data_str)]
)
return platform_message.MessageChain([platform_message.WeChatForwardFile(xml_data=xml_data_str)])
async def _handler_compound_link(
self, message: dict, xml_data: ET.Element
) -> platform_message.MessageChain:
async def _handler_compound_link(self, message: dict, xml_data: ET.Element) -> platform_message.MessageChain:
"""处理链接消息(如公众号文章、外部网页)"""
message_list = []
try:
@@ -381,9 +313,7 @@ class GewechatMessageConverter(adapter.MessageConverter):
# 转发消息
xml_data_str = ET.tostring(xml_data, encoding='unicode')
# print(xml_data_str)
message_list.append(
platform_message.WeChatForwardLink(xml_data=xml_data_str)
)
message_list.append(platform_message.WeChatForwardLink(xml_data=xml_data_str))
except Exception as e:
print(f'解析链接消息失败: {str(e)}')
return platform_message.MessageChain(message_list)
@@ -393,21 +323,15 @@ class GewechatMessageConverter(adapter.MessageConverter):
) -> platform_message.MessageChain:
"""处理小程序消息(如小程序卡片、服务通知)"""
xml_data_str = ET.tostring(xml_data, encoding='unicode')
return platform_message.MessageChain(
[platform_message.WeChatForwardMiniPrograms(xml_data=xml_data_str)]
)
return platform_message.MessageChain([platform_message.WeChatForwardMiniPrograms(xml_data=xml_data_str)])
async def _handler_default(
self, message: Optional[dict], content_no_preifx: str
) -> platform_message.MessageChain:
async def _handler_default(self, message: Optional[dict], content_no_preifx: str) -> platform_message.MessageChain:
"""处理未知消息类型"""
if message:
msg_type = message['Data']['MsgType']
else:
msg_type = ''
return platform_message.MessageChain(
[platform_message.Unknown(text=f'[未知消息类型 msg_type:{msg_type}]')]
)
return platform_message.MessageChain([platform_message.Unknown(text=f'[未知消息类型 msg_type:{msg_type}]')])
def _handler_compound_unsupported(
self, message: dict, xml_data: str, text: Optional[str] = None
@@ -416,11 +340,7 @@ class GewechatMessageConverter(adapter.MessageConverter):
if not text:
text = f'[xml_data={xml_data}]'
content_list = []
content_list.append(
platform_message.Unknown(
text=f'[处理未支持复合消息类型[msg_type=49]|{text}'
)
)
content_list.append(platform_message.Unknown(text=f'[处理未支持复合消息类型[msg_type=49]|{text}'))
return platform_message.MessageChain(content_list)
@@ -448,9 +368,7 @@ class GewechatMessageConverter(adapter.MessageConverter):
appmsg_data = xml_data.find('.//appmsg')
tousername = message['Wxid']
if appmsg_data: # 接收方: 所属微信的wxid
quote_id = appmsg_data.find('.//refermsg').findtext(
'.//chatusr'
) # 引用消息的原发送者
quote_id = appmsg_data.find('.//refermsg').findtext('.//chatusr') # 引用消息的原发送者
ats_bot = ats_bot or (quote_id == tousername)
except Exception as e:
print(f'_ats_bot got except: {e}')
@@ -458,9 +376,7 @@ class GewechatMessageConverter(adapter.MessageConverter):
return ats_bot
# 提取一下content前面的sender_id, 和去掉前缀的内容
def _extract_content_and_sender(
self, raw_content: str
) -> Tuple[str, Optional[str]]:
def _extract_content_and_sender(self, raw_content: str) -> Tuple[str, Optional[str]]:
try:
# 检查消息开头,如果有 wxid_sbitaz0mt65n22:\n 则删掉
# add: 有些用户的wxid不是上述格式。换成user_name:
@@ -490,21 +406,17 @@ class GewechatEventConverter(adapter.EventConverter):
async def yiri2target(event: platform_events.MessageEvent) -> dict:
pass
async def target2yiri(
self, event: dict, bot_account_id: str
) -> platform_events.MessageEvent:
async def target2yiri(self, event: dict, bot_account_id: str) -> platform_events.MessageEvent:
# print(event)
# 排除自己发消息回调回答问题
if event['Wxid'] == event['Data']['FromUserName']['string']:
return None
# 排除公众号以及微信团队消息
if event['Data']['FromUserName']['string'].startswith('gh_') or event['Data'][
'FromUserName'
]['string'].startswith('weixin'):
if event['Data']['FromUserName']['string'].startswith('gh_') or event['Data']['FromUserName'][
'string'
].startswith('weixin'):
return None
message_chain = await self.message_converter.target2yiri(
copy.deepcopy(event), bot_account_id
)
message_chain = await self.message_converter.target2yiri(copy.deepcopy(event), bot_account_id)
if not message_chain:
return None
@@ -589,9 +501,7 @@ class GeWeChatAdapter(adapter.MessagePlatformAdapter):
return 'ok'
elif 'TypeName' in data and data['TypeName'] == 'AddMsg':
try:
event = await self.event_converter.target2yiri(
data.copy(), self.bot_account_id
)
event = await self.event_converter.target2yiri(data.copy(), self.bot_account_id)
except Exception:
traceback.print_exc()
@@ -600,9 +510,7 @@ class GeWeChatAdapter(adapter.MessagePlatformAdapter):
return 'ok'
async def _handle_message(
self, message: platform_message.MessageChain, target_id: str
):
async def _handle_message(self, message: platform_message.MessageChain, target_id: str):
"""统一消息处理核心逻辑"""
content_list = await self.message_converter.yiri2target(message)
at_targets = [item['target'] for item in content_list if item['type'] == 'at']
@@ -611,9 +519,9 @@ class GeWeChatAdapter(adapter.MessagePlatformAdapter):
at_targets = at_targets or []
member_info = []
if at_targets:
member_info = self.bot.get_chatroom_member_detail(
self.config['app_id'], target_id, at_targets[::-1]
)['data']
member_info = self.bot.get_chatroom_member_detail(self.config['app_id'], target_id, at_targets[::-1])[
'data'
]
# 处理消息组件
for msg in content_list:
@@ -694,9 +602,7 @@ class GeWeChatAdapter(adapter.MessagePlatformAdapter):
self.ap.logger.warning(f'未处理的消息类型: {msg["type"]}')
continue
async def send_message(
self, target_type: str, target_id: str, message: platform_message.MessageChain
):
async def send_message(self, target_type: str, target_id: str, message: platform_message.MessageChain):
"""主动发送消息"""
return await self._handle_message(message, target_id)
@@ -708,9 +614,7 @@ class GeWeChatAdapter(adapter.MessagePlatformAdapter):
):
"""回复消息"""
if message_source.source_platform_object:
target_id = message_source.source_platform_object['Data']['FromUserName'][
'string'
]
target_id = message_source.source_platform_object['Data']['FromUserName']['string']
return await self._handle_message(message, target_id)
async def is_muted(self, group_id: int) -> bool:
@@ -719,18 +623,14 @@ class GeWeChatAdapter(adapter.MessagePlatformAdapter):
def register_listener(
self,
event_type: typing.Type[platform_events.Event],
callback: typing.Callable[
[platform_events.Event, adapter.MessagePlatformAdapter], None
],
callback: typing.Callable[[platform_events.Event, adapter.MessagePlatformAdapter], None],
):
self.listeners[event_type] = callback
def unregister_listener(
self,
event_type: typing.Type[platform_events.Event],
callback: typing.Callable[
[platform_events.Event, adapter.MessagePlatformAdapter], None
],
callback: typing.Callable[[platform_events.Event, adapter.MessagePlatformAdapter], None],
):
pass
@@ -742,14 +642,10 @@ class GeWeChatAdapter(adapter.MessagePlatformAdapter):
json={'app_id': self.config['app_id']},
) as response:
if response.status != 200:
raise Exception(
f'获取gewechat token失败: {await response.text()}'
)
raise Exception(f'获取gewechat token失败: {await response.text()}')
self.config['token'] = (await response.json())['data']
self.bot = gewechat_client.GewechatClient(
f'{self.config["gewechat_url"]}/v2/api', self.config['token']
)
self.bot = gewechat_client.GewechatClient(f'{self.config["gewechat_url"]}/v2/api', self.config['token'])
def gewechat_login_process():
app_id, error_msg = self.bot.login(self.config['app_id'])

View File

@@ -71,14 +71,10 @@ class LarkMessageConverter(adapter.MessageConverter):
pending_paragraph.append({'tag': 'md', 'text': text})
except UnicodeError:
# If still fails, replace invalid characters
text = msg.text.encode('utf-8', errors='replace').decode(
'utf-8'
)
text = msg.text.encode('utf-8', errors='replace').decode('utf-8')
pending_paragraph.append({'tag': 'md', 'text': text})
elif isinstance(msg, platform_message.At):
pending_paragraph.append(
{'tag': 'at', 'user_id': msg.target, 'style': []}
)
pending_paragraph.append({'tag': 'at', 'user_id': msg.target, 'style': []})
elif isinstance(msg, platform_message.AtAll):
pending_paragraph.append({'tag': 'at', 'user_id': 'all', 'style': []})
elif isinstance(msg, platform_message.Image):
@@ -166,11 +162,7 @@ class LarkMessageConverter(adapter.MessageConverter):
os.unlink(temp_file.name)
elif isinstance(msg, platform_message.Forward):
for node in msg.node_list:
message_elements.extend(
await LarkMessageConverter.yiri2target(
node.message_chain, api_client
)
)
message_elements.extend(await LarkMessageConverter.yiri2target(node.message_chain, api_client))
if pending_paragraph:
message_elements.append(pending_paragraph)
@@ -186,13 +178,9 @@ class LarkMessageConverter(adapter.MessageConverter):
lb_msg_list = []
msg_create_time = datetime.datetime.fromtimestamp(
int(message.create_time) / 1000
)
msg_create_time = datetime.datetime.fromtimestamp(int(message.create_time) / 1000)
lb_msg_list.append(
platform_message.Source(id=message.message_id, time=msg_create_time)
)
lb_msg_list.append(platform_message.Source(id=message.message_id, time=msg_create_time))
if message.message_type == 'text':
element_list = []
@@ -222,9 +210,7 @@ class LarkMessageConverter(adapter.MessageConverter):
left_text = text_split[0]
right_text = text_split[1]
new_list.extend(
text_element_recur({'tag': 'text', 'text': left_text, 'style': []})
)
new_list.extend(text_element_recur({'tag': 'text', 'text': left_text, 'style': []}))
new_list.append(
{
@@ -235,15 +221,11 @@ class LarkMessageConverter(adapter.MessageConverter):
}
)
new_list.extend(
text_element_recur({'tag': 'text', 'text': right_text, 'style': []})
)
new_list.extend(text_element_recur({'tag': 'text', 'text': right_text, 'style': []}))
return new_list
element_list = text_element_recur(
{'tag': 'text', 'text': message_content['text'], 'style': []}
)
element_list = text_element_recur({'tag': 'text', 'text': message_content['text'], 'style': []})
message_content = {'title': '', 'content': element_list}
@@ -258,9 +240,7 @@ class LarkMessageConverter(adapter.MessageConverter):
message_content['content'] = new_list
elif message.message_type == 'image':
message_content['content'] = [
{'tag': 'img', 'image_key': message_content['image_key'], 'style': []}
]
message_content['content'] = [{'tag': 'img', 'image_key': message_content['image_key'], 'style': []}]
for ele in message_content['content']:
if ele['tag'] == 'text':
@@ -278,9 +258,7 @@ class LarkMessageConverter(adapter.MessageConverter):
.build()
)
response: GetMessageResourceResponse = (
await api_client.im.v1.message_resource.aget(request)
)
response: GetMessageResourceResponse = await api_client.im.v1.message_resource.aget(request)
if not response.success():
raise Exception(
@@ -292,11 +270,7 @@ class LarkMessageConverter(adapter.MessageConverter):
image_format = response.raw.headers['content-type']
lb_msg_list.append(
platform_message.Image(
base64=f'data:{image_format};base64,{image_base64}'
)
)
lb_msg_list.append(platform_message.Image(base64=f'data:{image_format};base64,{image_base64}'))
return platform_message.MessageChain(lb_msg_list)
@@ -312,9 +286,7 @@ class LarkEventConverter(adapter.EventConverter):
async def target2yiri(
event: lark_oapi.im.v1.P2ImMessageReceiveV1, api_client: lark_oapi.Client
) -> platform_events.Event:
message_chain = await LarkMessageConverter.target2yiri(
event.event.message, api_client
)
message_chain = await LarkMessageConverter.target2yiri(event.event.message, api_client)
if event.event.message.chat_type == 'p2p':
return platform_events.FriendMessage(
@@ -402,9 +374,7 @@ class LarkAdapter(adapter.MessagePlatformAdapter):
p2v1.schema = context.schema
if 'im.message.receive_v1' == type:
try:
event = await self.event_converter.target2yiri(
p2v1, self.api_client
)
event = await self.event_converter.target2yiri(p2v1, self.api_client)
except Exception:
traceback.print_exc()
@@ -425,26 +395,15 @@ class LarkAdapter(adapter.MessagePlatformAdapter):
asyncio.create_task(on_message(event))
event_handler = (
lark_oapi.EventDispatcherHandler.builder('', '')
.register_p2_im_message_receive_v1(sync_on_message)
.build()
lark_oapi.EventDispatcherHandler.builder('', '').register_p2_im_message_receive_v1(sync_on_message).build()
)
self.bot_account_id = config['bot_name']
self.bot = lark_oapi.ws.Client(
config['app_id'], config['app_secret'], event_handler=event_handler
)
self.api_client = (
lark_oapi.Client.builder()
.app_id(config['app_id'])
.app_secret(config['app_secret'])
.build()
)
self.bot = lark_oapi.ws.Client(config['app_id'], config['app_secret'], event_handler=event_handler)
self.api_client = lark_oapi.Client.builder().app_id(config['app_id']).app_secret(config['app_secret']).build()
async def send_message(
self, target_type: str, target_id: str, message: platform_message.MessageChain
):
async def send_message(self, target_type: str, target_id: str, message: platform_message.MessageChain):
pass
async def reply_message(
@@ -455,9 +414,7 @@ class LarkAdapter(adapter.MessagePlatformAdapter):
):
# 不再需要了因为message_id已经被包含到message_chain中
# lark_event = await self.event_converter.yiri2target(message_source)
lark_message = await self.message_converter.yiri2target(
message, self.api_client
)
lark_message = await self.message_converter.yiri2target(message, self.api_client)
final_content = {
'zh_cn': {
@@ -480,9 +437,7 @@ class LarkAdapter(adapter.MessagePlatformAdapter):
.build()
)
response: ReplyMessageResponse = await self.api_client.im.v1.message.areply(
request
)
response: ReplyMessageResponse = await self.api_client.im.v1.message.areply(request)
if not response.success():
raise Exception(
@@ -495,18 +450,14 @@ class LarkAdapter(adapter.MessagePlatformAdapter):
def register_listener(
self,
event_type: typing.Type[platform_events.Event],
callback: typing.Callable[
[platform_events.Event, adapter.MessagePlatformAdapter], None
],
callback: typing.Callable[[platform_events.Event, adapter.MessagePlatformAdapter], None],
):
self.listeners[event_type] = callback
def unregister_listener(
self,
event_type: typing.Type[platform_events.Event],
callback: typing.Callable[
[platform_events.Event, adapter.MessagePlatformAdapter], None
],
callback: typing.Callable[[platform_events.Event, adapter.MessagePlatformAdapter], None],
):
self.listeners.pop(event_type)

View File

@@ -29,9 +29,7 @@ class NakuruProjectMessageConverter(adapter_model.MessageConverter):
elif type(message_chain) is str:
msg_list = [platform_message.Plain(message_chain)]
else:
raise Exception(
'Unknown message type: ' + str(message_chain) + str(type(message_chain))
)
raise Exception('Unknown message type: ' + str(message_chain) + str(type(message_chain)))
nakuru_msg_list = []
@@ -63,9 +61,7 @@ class NakuruProjectMessageConverter(adapter_model.MessageConverter):
# 遍历并转换
for yiri_forward_node in yiri_forward_node_list:
try:
content_list = NakuruProjectMessageConverter.yiri2target(
yiri_forward_node.message_chain
)
content_list = NakuruProjectMessageConverter.yiri2target(yiri_forward_node.message_chain)
nakuru_forward_node = nkc.Node(
name=yiri_forward_node.sender_name,
uin=yiri_forward_node.sender_id,
@@ -87,9 +83,7 @@ class NakuruProjectMessageConverter(adapter_model.MessageConverter):
return nakuru_msg_list
@staticmethod
def target2yiri(
message_chain: typing.Any, message_id: int = -1
) -> platform_message.MessageChain:
def target2yiri(message_chain: typing.Any, message_id: int = -1) -> platform_message.MessageChain:
"""将Yiri的消息链转换为YiriMirai的消息链"""
assert type(message_chain) is list
@@ -97,9 +91,7 @@ class NakuruProjectMessageConverter(adapter_model.MessageConverter):
import datetime
# 添加Source组件以标记message_id等信息
yiri_msg_list.append(
platform_message.Source(id=message_id, time=datetime.datetime.now())
)
yiri_msg_list.append(platform_message.Source(id=message_id, time=datetime.datetime.now()))
for component in message_chain:
if type(component) is nkc.Plain:
yiri_msg_list.append(platform_message.Plain(text=component.text))
@@ -130,9 +122,7 @@ class NakuruProjectEventConverter(adapter_model.EventConverter):
@staticmethod
def target2yiri(event: typing.Any) -> platform_events.Event:
yiri_chain = NakuruProjectMessageConverter.target2yiri(
event.message, event.message_id
)
yiri_chain = NakuruProjectMessageConverter.target2yiri(event.message, event.message_id)
if type(event) is nakuru.FriendMessage: # 私聊消息事件
return platform_events.FriendMessage(
sender=platform_entities.Friend(
@@ -206,9 +196,7 @@ class NakuruAdapter(adapter_model.MessagePlatformAdapter):
):
task = None
converted_msg = (
self.message_converter.yiri2target(message) if not converted else message
)
converted_msg = self.message_converter.yiri2target(message) if not converted else message
# 检查是否有转发消息
has_forward = False
@@ -250,13 +238,9 @@ class NakuruAdapter(adapter_model.MessagePlatformAdapter):
),
)
if type(message_source) is platform_events.GroupMessage:
await self.send_message(
'group', message_source.sender.group.id, message, converted=True
)
await self.send_message('group', message_source.sender.group.id, message, converted=True)
elif type(message_source) is platform_events.FriendMessage:
await self.send_message(
'person', message_source.sender.id, message, converted=True
)
await self.send_message('person', message_source.sender.id, message, converted=True)
else:
raise Exception('Unknown message source type: ' + str(type(message_source)))
@@ -264,17 +248,13 @@ class NakuruAdapter(adapter_model.MessagePlatformAdapter):
import time
# 检查是否被禁言
group_member_info = asyncio.run(
self.bot.getGroupMemberInfo(group_id, self.bot_account_id)
)
group_member_info = asyncio.run(self.bot.getGroupMemberInfo(group_id, self.bot_account_id))
return group_member_info.shut_up_timestamp > int(time.time())
def register_listener(
self,
event_type: typing.Type[platform_events.Event],
callback: typing.Callable[
[platform_events.Event, adapter_model.MessagePlatformAdapter], None
],
callback: typing.Callable[[platform_events.Event, adapter_model.MessagePlatformAdapter], None],
):
try:
source_cls = NakuruProjectEventConverter.yiri2target(event_type)
@@ -301,9 +281,7 @@ class NakuruAdapter(adapter_model.MessagePlatformAdapter):
def unregister_listener(
self,
event_type: typing.Type[platform_events.Event],
callback: typing.Callable[
[platform_events.Event, adapter_model.MessagePlatformAdapter], None
],
callback: typing.Callable[[platform_events.Event, adapter_model.MessagePlatformAdapter], None],
):
nakuru_event_name = self.event_converter.yiri2target(event_type).__name__
@@ -312,10 +290,7 @@ class NakuruAdapter(adapter_model.MessagePlatformAdapter):
# 从本对象的监听器列表中查找并删除
target_wrapper = None
for listener in self.listener_list:
if (
listener['event_type'] == event_type
and listener['callable'] == callback
):
if listener['event_type'] == event_type and listener['callable'] == callback:
target_wrapper = listener['wrapper']
self.listener_list.remove(listener)
break
@@ -334,14 +309,8 @@ class NakuruAdapter(adapter_model.MessagePlatformAdapter):
import requests
resp = requests.get(
url='http://{}:{}/get_login_info'.format(
self.cfg['host'], self.cfg['http_port']
),
headers={
'Authorization': 'Bearer ' + self.cfg['token']
if 'token' in self.cfg
else ''
},
url='http://{}:{}/get_login_info'.format(self.cfg['host'], self.cfg['http_port']),
headers={'Authorization': 'Bearer ' + self.cfg['token'] if 'token' in self.cfg else ''},
timeout=5,
proxies=None,
)
@@ -349,9 +318,7 @@ class NakuruAdapter(adapter_model.MessagePlatformAdapter):
raise Exception('go-cqhttp拒绝访问请检查配置文件中nakuru适配器的配置')
self.bot_account_id = int(resp.json()['data']['user_id'])
except Exception:
raise Exception(
'获取go-cqhttp账号信息失败, 请检查是否已启动go-cqhttp并配置正确'
)
raise Exception('获取go-cqhttp账号信息失败, 请检查是否已启动go-cqhttp并配置正确')
await self.bot._run()
self.ap.logger.info('运行 Nakuru 适配器')
while True:

View File

@@ -25,9 +25,7 @@ class OAMessageConverter(adapter.MessageConverter):
@staticmethod
async def target2yiri(message: str, message_id=-1):
yiri_msg_list = []
yiri_msg_list.append(
platform_message.Source(id=message_id, time=datetime.datetime.now())
)
yiri_msg_list.append(platform_message.Source(id=message_id, time=datetime.datetime.now()))
yiri_msg_list.append(platform_message.Plain(text=message))
chain = platform_message.MessageChain(yiri_msg_list)
@@ -39,9 +37,7 @@ class OAEventConverter(adapter.EventConverter):
@staticmethod
async def target2yiri(event: OAEvent):
if event.type == 'text':
yiri_chain = await OAMessageConverter.target2yiri(
event.message, event.message_id
)
yiri_chain = await OAMessageConverter.target2yiri(event.message, event.message_id)
friend = platform_entities.Friend(
id=event.user_id,
@@ -81,9 +77,7 @@ class OfficialAccountAdapter(adapter.MessagePlatformAdapter):
]
missing_keys = [key for key in required_keys if key not in config]
if missing_keys:
raise ParamNotEnoughError(
'微信公众号缺少相关配置项,请查看文档或联系管理员'
)
raise ParamNotEnoughError('微信公众号缺少相关配置项,请查看文档或联系管理员')
if self.config['Mode'] == 'drop':
self.bot = OAClient(
@@ -114,28 +108,20 @@ class OfficialAccountAdapter(adapter.MessagePlatformAdapter):
await self.bot.set_message(message_source.message_chain.message_id, content)
elif isinstance(self.bot, OAClientForLongerResponse):
from_user = message_source.sender.id
await self.bot.set_message(
from_user, message_source.message_chain.message_id, content
)
await self.bot.set_message(from_user, message_source.message_chain.message_id, content)
async def send_message(
self, target_type: str, target_id: str, message: platform_message.MessageChain
):
async def send_message(self, target_type: str, target_id: str, message: platform_message.MessageChain):
pass
def register_listener(
self,
event_type: type,
callback: typing.Callable[
[platform_events.Event, MessagePlatformAdapter], None
],
callback: typing.Callable[[platform_events.Event, MessagePlatformAdapter], None],
):
async def on_message(event: OAEvent):
self.bot_account_id = event.receiver_id
try:
return await callback(
await self.event_converter.target2yiri(event), self
)
return await callback(await self.event_converter.target2yiri(event), self)
except Exception:
traceback.print_exc()
@@ -161,8 +147,6 @@ class OfficialAccountAdapter(adapter.MessagePlatformAdapter):
async def unregister_listener(
self,
event_type: type,
callback: typing.Callable[
[platform_events.Event, MessagePlatformAdapter], None
],
callback: typing.Callable[[platform_events.Event, MessagePlatformAdapter], None],
):
return super().unregister_listener(event_type, callback)

View File

@@ -147,9 +147,7 @@ class OfficialMessageConverter(adapter_model.MessageConverter):
elif type(message_chain) is str:
msg_list = [platform_message.Plain(text=message_chain)]
else:
raise Exception(
'Unknown message type: ' + str(message_chain) + str(type(message_chain))
)
raise Exception('Unknown message type: ' + str(message_chain) + str(type(message_chain)))
offcial_messages: list[dict] = []
"""
@@ -172,19 +170,13 @@ class OfficialMessageConverter(adapter_model.MessageConverter):
if component.url is not None:
offcial_messages.append({'type': 'image', 'content': component.url})
elif component.path is not None:
offcial_messages.append(
{'type': 'file_image', 'content': component.path}
)
offcial_messages.append({'type': 'file_image', 'content': component.path})
elif type(component) is platform_message.At:
offcial_messages.append({'type': 'at', 'content': ''})
elif type(component) is platform_message.AtAll:
print(
'上层组件要求发送 AtAll 消息,但 QQ 官方 API 不支持此消息类型,忽略此消息。'
)
print('上层组件要求发送 AtAll 消息,但 QQ 官方 API 不支持此消息类型,忽略此消息。')
elif type(component) is platform_message.Voice:
print(
'上层组件要求发送 Voice 消息,但 QQ 官方 API 不支持此消息类型,忽略此消息。'
)
print('上层组件要求发送 Voice 消息,但 QQ 官方 API 不支持此消息类型,忽略此消息。')
elif type(component) is forward.Forward:
# 转发消息
yiri_forward_node_list = component.node_list
@@ -195,9 +187,7 @@ class OfficialMessageConverter(adapter_model.MessageConverter):
message_chain = yiri_forward_node.message_chain
# 平铺
offcial_messages.extend(
OfficialMessageConverter.yiri2target(message_chain)
)
offcial_messages.extend(OfficialMessageConverter.yiri2target(message_chain))
except Exception:
import traceback
@@ -219,11 +209,7 @@ class OfficialMessageConverter(adapter_model.MessageConverter):
yiri_msg_list = []
# 存id
yiri_msg_list.append(
platform_message.Source(
id=save_msg_id(message_id), time=datetime.datetime.now()
)
)
yiri_msg_list.append(platform_message.Source(id=save_msg_id(message_id), time=datetime.datetime.now()))
if type(message) not in [botpy_message.DirectMessage, botpy_message.C2CMessage]:
yiri_msg_list.append(platform_message.At(target=bot_account_id))
@@ -239,9 +225,7 @@ class OfficialMessageConverter(adapter_model.MessageConverter):
if attachment.content_type.startswith('image'):
yiri_msg_list.append(platform_message.Image(url=attachment.url))
else:
logging.warning(
'不支持的附件类型:' + attachment.content_type + ',忽略此附件。'
)
logging.warning('不支持的附件类型:' + attachment.content_type + ',忽略此附件。')
content = re.sub(r'<@!\d+>', '', str(message.content))
if content.strip() != '':
@@ -264,9 +248,7 @@ class OfficialEventConverter(adapter_model.EventConverter):
elif event == platform_events.FriendMessage:
return botpy_message.DirectMessage
else:
raise Exception(
'未支持转换的事件类型(YiriMirai -> Official): ' + str(event)
)
raise Exception('未支持转换的事件类型(YiriMirai -> Official): ' + str(event))
def target2yiri(
self,
@@ -297,21 +279,13 @@ class OfficialEventConverter(adapter_model.EventConverter):
),
special_title='',
join_timestamp=int(
datetime.datetime.strptime(
event.member.joined_at, '%Y-%m-%dT%H:%M:%S%z'
).timestamp()
datetime.datetime.strptime(event.member.joined_at, '%Y-%m-%dT%H:%M:%S%z').timestamp()
),
last_speak_timestamp=datetime.datetime.now().timestamp(),
mute_time_remaining=0,
),
message_chain=OfficialMessageConverter.extract_message_chain_from_obj(
event, event.id
),
time=int(
datetime.datetime.strptime(
event.timestamp, '%Y-%m-%dT%H:%M:%S%z'
).timestamp()
),
message_chain=OfficialMessageConverter.extract_message_chain_from_obj(event, event.id),
time=int(datetime.datetime.strptime(event.timestamp, '%Y-%m-%dT%H:%M:%S%z').timestamp()),
)
elif isinstance(event, botpy_message.DirectMessage): # 频道私聊,转私聊事件
return platform_events.FriendMessage(
@@ -320,14 +294,8 @@ class OfficialEventConverter(adapter_model.EventConverter):
nickname=event.author.username,
remark=event.author.username,
),
message_chain=OfficialMessageConverter.extract_message_chain_from_obj(
event, event.id
),
time=int(
datetime.datetime.strptime(
event.timestamp, '%Y-%m-%dT%H:%M:%S%z'
).timestamp()
),
message_chain=OfficialMessageConverter.extract_message_chain_from_obj(event, event.id),
time=int(datetime.datetime.strptime(event.timestamp, '%Y-%m-%dT%H:%M:%S%z').timestamp()),
)
elif isinstance(event, botpy_message.GroupMessage): # 群聊,转群聊事件
author_member_id = event.author.member_openid
@@ -347,14 +315,8 @@ class OfficialEventConverter(adapter_model.EventConverter):
last_speak_timestamp=datetime.datetime.now().timestamp(),
mute_time_remaining=0,
),
message_chain=OfficialMessageConverter.extract_message_chain_from_obj(
event, event.id
),
time=int(
datetime.datetime.strptime(
event.timestamp, '%Y-%m-%dT%H:%M:%S%z'
).timestamp()
),
message_chain=OfficialMessageConverter.extract_message_chain_from_obj(event, event.id),
time=int(datetime.datetime.strptime(event.timestamp, '%Y-%m-%dT%H:%M:%S%z').timestamp()),
)
elif isinstance(event, botpy_message.C2CMessage): # 私聊,转私聊事件
user_id_alter = event.author.user_openid
@@ -365,14 +327,8 @@ class OfficialEventConverter(adapter_model.EventConverter):
nickname=user_id_alter,
remark=user_id_alter,
),
message_chain=OfficialMessageConverter.extract_message_chain_from_obj(
event, event.id
),
time=int(
datetime.datetime.strptime(
event.timestamp, '%Y-%m-%dT%H:%M:%S%z'
).timestamp()
),
message_chain=OfficialMessageConverter.extract_message_chain_from_obj(event, event.id),
time=int(datetime.datetime.strptime(event.timestamp, '%Y-%m-%dT%H:%M:%S%z').timestamp()),
)
@@ -420,9 +376,7 @@ class OfficialAdapter(adapter_model.MessagePlatformAdapter):
self.bot = botpy.Client(intents=intents)
async def send_message(
self, target_type: str, target_id: str, message: platform_message.MessageChain
):
async def send_message(self, target_type: str, target_id: str, message: platform_message.MessageChain):
message_list = self.message_converter.yiri2target(message)
for msg in message_list:
@@ -468,22 +422,16 @@ class OfficialAdapter(adapter_model.MessagePlatformAdapter):
if quote_origin:
args['message_reference'] = botpy_message_type.Reference(
message_id=cached_message_ids[
str(message_source.message_chain.message_id)
]
message_id=cached_message_ids[str(message_source.message_chain.message_id)]
)
if isinstance(message_source, platform_events.GroupMessage):
args['channel_id'] = str(message_source.sender.group.id)
args['msg_id'] = cached_message_ids[
str(message_source.message_chain.message_id)
]
args['msg_id'] = cached_message_ids[str(message_source.message_chain.message_id)]
await self.bot.api.post_message(**args)
elif isinstance(message_source, platform_events.FriendMessage):
args['guild_id'] = str(message_source.sender.id)
args['msg_id'] = cached_message_ids[
str(message_source.message_chain.message_id)
]
args['msg_id'] = cached_message_ids[str(message_source.message_chain.message_id)]
await self.bot.api.post_dms(**args)
elif isinstance(message_source, OfficialGroupMessage):
if 'file_image' in args: # 暂不支持发送文件图片
@@ -502,9 +450,7 @@ class OfficialAdapter(adapter_model.MessagePlatformAdapter):
args['media'] = uploadMedia
args['msg_type'] = 7
args['msg_id'] = cached_message_ids[
str(message_source.message_chain.message_id)
]
args['msg_id'] = cached_message_ids[str(message_source.message_chain.message_id)]
args['msg_seq'] = self.group_msg_seq
self.group_msg_seq += 1
@@ -523,9 +469,7 @@ class OfficialAdapter(adapter_model.MessagePlatformAdapter):
args['media'] = uploadMedia
args['msg_type'] = 7
args['msg_id'] = cached_message_ids[
str(message_source.message_chain.message_id)
]
args['msg_id'] = cached_message_ids[str(message_source.message_chain.message_id)]
args['msg_seq'] = self.c2c_msg_seq
self.c2c_msg_seq += 1
@@ -538,9 +482,7 @@ class OfficialAdapter(adapter_model.MessagePlatformAdapter):
def register_listener(
self,
event_type: typing.Type[platform_events.Event],
callback: typing.Callable[
[platform_events.Event, adapter_model.MessagePlatformAdapter], None
],
callback: typing.Callable[[platform_events.Event, adapter_model.MessagePlatformAdapter], None],
):
try:
@@ -563,9 +505,7 @@ class OfficialAdapter(adapter_model.MessagePlatformAdapter):
def unregister_listener(
self,
event_type: typing.Type[platform_events.Event],
callback: typing.Callable[
[platform_events.Event, adapter_model.MessagePlatformAdapter], None
],
callback: typing.Callable[[platform_events.Event, adapter_model.MessagePlatformAdapter], None],
):
delattr(self.bot, event_handler_mapping[event_type])

View File

@@ -35,13 +35,9 @@ class QQOfficialMessageConverter(adapter.MessageConverter):
@staticmethod
async def target2yiri(message: str, message_id: str, pic_url: str, content_type):
yiri_msg_list = []
yiri_msg_list.append(
platform_message.Source(id=message_id, time=datetime.datetime.now())
)
yiri_msg_list.append(platform_message.Source(id=message_id, time=datetime.datetime.now()))
if pic_url is not None:
base64_url = await image.get_qq_official_image_base64(
pic_url=pic_url, content_type=content_type
)
base64_url = await image.get_qq_official_image_base64(pic_url=pic_url, content_type=content_type)
yiri_msg_list.append(platform_message.Image(base64=base64_url))
yiri_msg_list.append(platform_message.Plain(text=message))
@@ -75,11 +71,7 @@ class QQOfficialEventConverter(adapter.EventConverter):
return platform_events.FriendMessage(
sender=friend,
message_chain=yiri_chain,
time=int(
datetime.datetime.strptime(
event.timestamp, '%Y-%m-%dT%H:%M:%S%z'
).timestamp()
),
time=int(datetime.datetime.strptime(event.timestamp, '%Y-%m-%dT%H:%M:%S%z').timestamp()),
source_platform_object=event,
)
@@ -89,9 +81,7 @@ class QQOfficialEventConverter(adapter.EventConverter):
nickname=event.t,
remark='',
)
return platform_events.FriendMessage(
sender=friend, message_chain=yiri_chain, source_platform_object=event
)
return platform_events.FriendMessage(sender=friend, message_chain=yiri_chain, source_platform_object=event)
if event.t == 'GROUP_AT_MESSAGE_CREATE':
yiri_chain.insert(0, platform_message.At(target='justbot'))
@@ -109,11 +99,7 @@ class QQOfficialEventConverter(adapter.EventConverter):
last_speak_timestamp=0,
mute_time_remaining=0,
)
time = int(
datetime.datetime.strptime(
event.timestamp, '%Y-%m-%dT%H:%M:%S%z'
).timestamp()
)
time = int(datetime.datetime.strptime(event.timestamp, '%Y-%m-%dT%H:%M:%S%z').timestamp())
return platform_events.GroupMessage(
sender=sender,
message_chain=yiri_chain,
@@ -136,11 +122,7 @@ class QQOfficialEventConverter(adapter.EventConverter):
last_speak_timestamp=0,
mute_time_remaining=0,
)
time = int(
datetime.datetime.strptime(
event.timestamp, '%Y-%m-%dT%H:%M:%S%z'
).timestamp()
)
time = int(datetime.datetime.strptime(event.timestamp, '%Y-%m-%dT%H:%M:%S%z').timestamp())
return platform_events.GroupMessage(
sender=sender,
message_chain=yiri_chain,
@@ -167,9 +149,7 @@ class QQOfficialAdapter(adapter.MessagePlatformAdapter):
]
missing_keys = [key for key in required_keys if key not in config]
if missing_keys:
raise ParamNotEnoughError(
'QQ官方机器人缺少相关配置项请查看文档或联系管理员'
)
raise ParamNotEnoughError('QQ官方机器人缺少相关配置项请查看文档或联系管理员')
self.bot = QQOfficialClient(
app_id=config['appid'],
@@ -229,24 +209,18 @@ class QQOfficialAdapter(adapter.MessagePlatformAdapter):
qq_official_event.d_id,
)
async def send_message(
self, target_type: str, target_id: str, message: platform_message.MessageChain
):
async def send_message(self, target_type: str, target_id: str, message: platform_message.MessageChain):
pass
def register_listener(
self,
event_type: typing.Type[platform_events.Event],
callback: typing.Callable[
[platform_events.Event, adapter.MessagePlatformAdapter], None
],
callback: typing.Callable[[platform_events.Event, adapter.MessagePlatformAdapter], None],
):
async def on_message(event: QQOfficialEvent):
self.bot_account_id = 'justbot'
try:
return await callback(
await self.event_converter.target2yiri(event), self
)
return await callback(await self.event_converter.target2yiri(event), self)
except Exception:
traceback.print_exc()
@@ -274,8 +248,6 @@ class QQOfficialAdapter(adapter.MessagePlatformAdapter):
def unregister_listener(
self,
event_type: type,
callback: typing.Callable[
[platform_events.Event, MessagePlatformAdapter], None
],
callback: typing.Callable[[platform_events.Event, MessagePlatformAdapter], None],
):
return super().unregister_listener(event_type, callback)

Some files were not shown because too many files have changed in this diff Show More