style: introduce ruff as linter and formatter (#1356)

* style: remove necessary imports

* style: fix F841

* style: fix F401

* style: fix F811

* style: fix E402

* style: fix E721

* style: fix E722

* style: fix E722

* style: fix F541

* style: ruff format

* style: all passed

* style: add ruff in deps

* style: more ignores in ruff.toml

* style: add pre-commit
This commit is contained in:
Junyan Qin (Chin)
2025-04-29 17:24:07 +08:00
committed by GitHub
parent 09e70d70e9
commit 209f16af76
240 changed files with 5307 additions and 4689 deletions
+1
View File
@@ -29,6 +29,7 @@ qcapi
claude.json claude.json
bard.json bard.json
/*yaml /*yaml
!.pre-commit-config.yaml
!components.yaml !components.yaml
!/docker-compose.yaml !/docker-compose.yaml
data/labels/instance_id.json data/labels/instance_id.json
+9
View File
@@ -0,0 +1,9 @@
repos:
- repo: https://github.com/astral-sh/ruff-pre-commit
# Ruff version.
rev: v0.11.7
hooks:
# Run the linter.
- id: ruff
# Run the formatter.
- id: ruff-format
+4 -2
View File
@@ -1,2 +1,4 @@
from .v1 import client from .v1 import client as client
from .v1 import errors from .v1 import errors as errors
__all__ = ['client', 'errors']
+18 -9
View File
@@ -8,25 +8,33 @@ import json
class TestDifyClient: class TestDifyClient:
async def test_chat_messages(self): async def test_chat_messages(self):
cln = client.AsyncDifyServiceClient(api_key=os.getenv("DIFY_API_KEY"), base_url=os.getenv("DIFY_BASE_URL")) cln = client.AsyncDifyServiceClient(
api_key=os.getenv('DIFY_API_KEY'), base_url=os.getenv('DIFY_BASE_URL')
)
async for chunk in cln.chat_messages(inputs={}, query="调用工具查看现在几点?", user="test"): async for chunk in cln.chat_messages(
inputs={}, query='调用工具查看现在几点?', user='test'
):
print(json.dumps(chunk, ensure_ascii=False, indent=4)) print(json.dumps(chunk, ensure_ascii=False, indent=4))
async def test_upload_file(self): async def test_upload_file(self):
cln = client.AsyncDifyServiceClient(api_key=os.getenv("DIFY_API_KEY"), base_url=os.getenv("DIFY_BASE_URL")) cln = client.AsyncDifyServiceClient(
api_key=os.getenv('DIFY_API_KEY'), base_url=os.getenv('DIFY_BASE_URL')
)
file_bytes = open("img.png", "rb").read() file_bytes = open('img.png', 'rb').read()
print(type(file_bytes)) print(type(file_bytes))
file = ("img2.png", file_bytes, "image/png") file = ('img2.png', file_bytes, 'image/png')
resp = await cln.upload_file(file=file, user="test") resp = await cln.upload_file(file=file, user='test')
print(json.dumps(resp, ensure_ascii=False, indent=4)) print(json.dumps(resp, ensure_ascii=False, indent=4))
async def test_workflow_run(self): async def test_workflow_run(self):
cln = client.AsyncDifyServiceClient(api_key=os.getenv("DIFY_API_KEY"), base_url=os.getenv("DIFY_BASE_URL")) cln = client.AsyncDifyServiceClient(
api_key=os.getenv('DIFY_API_KEY'), base_url=os.getenv('DIFY_BASE_URL')
)
# resp = await cln.workflow_run(inputs={}, user="test") # resp = await cln.workflow_run(inputs={}, user="test")
# # print(json.dumps(resp, ensure_ascii=False, indent=4)) # # print(json.dumps(resp, ensure_ascii=False, indent=4))
@@ -34,11 +42,12 @@ class TestDifyClient:
chunks = [] chunks = []
ignored_events = ['text_chunk'] ignored_events = ['text_chunk']
async for chunk in cln.workflow_run(inputs={}, user="test"): async for chunk in cln.workflow_run(inputs={}, user='test'):
if chunk['event'] in ignored_events: if chunk['event'] in ignored_events:
continue continue
chunks.append(chunk) chunks.append(chunk)
print(json.dumps(chunks, ensure_ascii=False, indent=4)) print(json.dumps(chunks, ensure_ascii=False, indent=4))
if __name__ == "__main__":
if __name__ == '__main__':
asyncio.run(TestDifyClient().test_chat_messages()) asyncio.run(TestDifyClient().test_chat_messages())
+41 -36
View File
@@ -16,7 +16,7 @@ class AsyncDifyServiceClient:
def __init__( def __init__(
self, self,
api_key: str, api_key: str,
base_url: str = "https://api.dify.ai/v1", base_url: str = 'https://api.dify.ai/v1',
) -> None: ) -> None:
self.api_key = api_key self.api_key = api_key
self.base_url = base_url self.base_url = base_url
@@ -26,14 +26,14 @@ class AsyncDifyServiceClient:
inputs: dict[str, typing.Any], inputs: dict[str, typing.Any],
query: str, query: str,
user: str, user: str,
response_mode: str = "streaming", # 当前不支持 blocking response_mode: str = 'streaming', # 当前不支持 blocking
conversation_id: str = "", conversation_id: str = '',
files: list[dict[str, typing.Any]] = [], files: list[dict[str, typing.Any]] = [],
timeout: float = 30.0, timeout: float = 30.0,
) -> typing.AsyncGenerator[dict[str, typing.Any], None]: ) -> typing.AsyncGenerator[dict[str, typing.Any], None]:
"""发送消息""" """发送消息"""
if response_mode != "streaming": if response_mode != 'streaming':
raise DifyAPIError("当前仅支持 streaming 模式") raise DifyAPIError('当前仅支持 streaming 模式')
async with httpx.AsyncClient( async with httpx.AsyncClient(
base_url=self.base_url, base_url=self.base_url,
@@ -41,61 +41,66 @@ class AsyncDifyServiceClient:
timeout=timeout, timeout=timeout,
) as client: ) as client:
async with client.stream( async with client.stream(
"POST", 'POST',
"/chat-messages", '/chat-messages',
headers={"Authorization": f"Bearer {self.api_key}", "Content-Type": "application/json"}, headers={
'Authorization': f'Bearer {self.api_key}',
'Content-Type': 'application/json',
},
json={ json={
"inputs": inputs, 'inputs': inputs,
"query": query, 'query': query,
"user": user, 'user': user,
"response_mode": response_mode, 'response_mode': response_mode,
"conversation_id": conversation_id, 'conversation_id': conversation_id,
"files": files, 'files': files,
}, },
) as r: ) as r:
async for chunk in r.aiter_lines(): async for chunk in r.aiter_lines():
if r.status_code != 200: if r.status_code != 200:
raise DifyAPIError(f"{r.status_code} {chunk}") raise DifyAPIError(f'{r.status_code} {chunk}')
if chunk.strip() == "": if chunk.strip() == '':
continue continue
if chunk.startswith("data:"): if chunk.startswith('data:'):
yield json.loads(chunk[5:]) yield json.loads(chunk[5:])
async def workflow_run( async def workflow_run(
self, self,
inputs: dict[str, typing.Any], inputs: dict[str, typing.Any],
user: str, user: str,
response_mode: str = "streaming", # 当前不支持 blocking response_mode: str = 'streaming', # 当前不支持 blocking
files: list[dict[str, typing.Any]] = [], files: list[dict[str, typing.Any]] = [],
timeout: float = 30.0, timeout: float = 30.0,
) -> typing.AsyncGenerator[dict[str, typing.Any], None]: ) -> typing.AsyncGenerator[dict[str, typing.Any], None]:
"""运行工作流""" """运行工作流"""
if response_mode != "streaming": if response_mode != 'streaming':
raise DifyAPIError("当前仅支持 streaming 模式") raise DifyAPIError('当前仅支持 streaming 模式')
async with httpx.AsyncClient( async with httpx.AsyncClient(
base_url=self.base_url, base_url=self.base_url,
trust_env=True, trust_env=True,
timeout=timeout, timeout=timeout,
) as client: ) as client:
async with client.stream( async with client.stream(
"POST", 'POST',
"/workflows/run", '/workflows/run',
headers={"Authorization": f"Bearer {self.api_key}", "Content-Type": "application/json"}, headers={
'Authorization': f'Bearer {self.api_key}',
'Content-Type': 'application/json',
},
json={ json={
"inputs": inputs, 'inputs': inputs,
"user": user, 'user': user,
"response_mode": response_mode, 'response_mode': response_mode,
"files": files, 'files': files,
}, },
) as r: ) as r:
async for chunk in r.aiter_lines(): async for chunk in r.aiter_lines():
if r.status_code != 200: if r.status_code != 200:
raise DifyAPIError(f"{r.status_code} {chunk}") raise DifyAPIError(f'{r.status_code} {chunk}')
if chunk.strip() == "": if chunk.strip() == '':
continue continue
if chunk.startswith("data:"): if chunk.startswith('data:'):
yield json.loads(chunk[5:]) yield json.loads(chunk[5:])
async def upload_file( async def upload_file(
@@ -112,15 +117,15 @@ class AsyncDifyServiceClient:
) as client: ) as client:
# multipart/form-data # multipart/form-data
response = await client.post( response = await client.post(
"/files/upload", '/files/upload',
headers={"Authorization": f"Bearer {self.api_key}"}, headers={'Authorization': f'Bearer {self.api_key}'},
files={ files={
"file": file, 'file': file,
"user": (None, user), 'user': (None, user),
}, },
) )
if response.status_code != 201: if response.status_code != 201:
raise DifyAPIError(f"{response.status_code} {response.text}") raise DifyAPIError(f'{response.status_code} {response.text}')
return response.json() return response.json()
+3 -3
View File
@@ -7,11 +7,11 @@ import os
class TestDifyClient: class TestDifyClient:
async def test_chat_messages(self): async def test_chat_messages(self):
cln = client.DifyClient(api_key=os.getenv("DIFY_API_KEY")) cln = client.DifyClient(api_key=os.getenv('DIFY_API_KEY'))
resp = await cln.chat_messages(inputs={}, query="Who are you?", user_id="test") resp = await cln.chat_messages(inputs={}, query='Who are you?', user_id='test')
print(resp) print(resp)
if __name__ == "__main__": if __name__ == '__main__':
asyncio.run(TestDifyClient().test_chat_messages()) asyncio.run(TestDifyClient().test_chat_messages())
+4 -1
View File
@@ -1,8 +1,8 @@
import asyncio import asyncio
import json
import dingtalk_stream import dingtalk_stream
from dingtalk_stream import AckMessage from dingtalk_stream import AckMessage
class EchoTextHandler(dingtalk_stream.ChatbotHandler): class EchoTextHandler(dingtalk_stream.ChatbotHandler):
def __init__(self, client): def __init__(self, client):
self.msg_id = '' self.msg_id = ''
@@ -10,6 +10,7 @@ class EchoTextHandler(dingtalk_stream.ChatbotHandler):
self.client = client # 用于更新 DingTalkClient 中的 incoming_message self.client = client # 用于更新 DingTalkClient 中的 incoming_message
"""处理钉钉消息""" """处理钉钉消息"""
async def process(self, callback: dingtalk_stream.CallbackMessage): async def process(self, callback: dingtalk_stream.CallbackMessage):
incoming_message = dingtalk_stream.ChatbotMessage.from_dict(callback.data) incoming_message = dingtalk_stream.ChatbotMessage.from_dict(callback.data)
if incoming_message.message_id != self.msg_id: if incoming_message.message_id != self.msg_id:
@@ -26,6 +27,8 @@ class EchoTextHandler(dingtalk_stream.ChatbotHandler):
return self.incoming_message return self.incoming_message
async def get_dingtalk_client(client_id, client_secret): async def get_dingtalk_client(client_id, client_secret):
from api import DingTalkClient # 延迟导入,避免循环导入 from api import DingTalkClient # 延迟导入,避免循环导入
return DingTalkClient(client_id, client_secret) return DingTalkClient(client_id, client_secret)
+70 -83
View File
@@ -10,7 +10,9 @@ import traceback
class DingTalkClient: class DingTalkClient:
def __init__(self, client_id: str, client_secret: str,robot_name:str,robot_code:str): def __init__(
self, client_id: str, client_secret: str, robot_name: str, robot_code: str
):
"""初始化 WebSocket 连接并自动启动""" """初始化 WebSocket 连接并自动启动"""
self.credential = dingtalk_stream.Credential(client_id, client_secret) self.credential = dingtalk_stream.Credential(client_id, client_secret)
self.client = dingtalk_stream.DingTalkStreamClient(self.credential) self.client = dingtalk_stream.DingTalkStreamClient(self.credential)
@@ -18,38 +20,32 @@ class DingTalkClient:
self.secret = client_secret self.secret = client_secret
# 在 DingTalkClient 中传入自己作为参数,避免循环导入 # 在 DingTalkClient 中传入自己作为参数,避免循环导入
self.EchoTextHandler = EchoTextHandler(self) self.EchoTextHandler = EchoTextHandler(self)
self.client.register_callback_handler(dingtalk_stream.chatbot.ChatbotMessage.TOPIC, self.EchoTextHandler) self.client.register_callback_handler(
dingtalk_stream.chatbot.ChatbotMessage.TOPIC, self.EchoTextHandler
)
self._message_handlers = { self._message_handlers = {
"example":[], 'example': [],
} }
self.access_token = '' self.access_token = ''
self.robot_name = robot_name self.robot_name = robot_name
self.robot_code = robot_code self.robot_code = robot_code
self.access_token_expiry_time = '' self.access_token_expiry_time = ''
async def get_access_token(self): async def get_access_token(self):
url = "https://api.dingtalk.com/v1.0/oauth2/accessToken" url = 'https://api.dingtalk.com/v1.0/oauth2/accessToken'
headers = { headers = {'Content-Type': 'application/json'}
"Content-Type": "application/json" data = {'appKey': self.key, 'appSecret': self.secret}
}
data = {
"appKey": self.key,
"appSecret": self.secret
}
async with httpx.AsyncClient() as client: async with httpx.AsyncClient() as client:
try: try:
response = await client.post(url,json=data,headers=headers) response = await client.post(url, json=data, headers=headers)
if response.status_code == 200: if response.status_code == 200:
response_data = response.json() response_data = response.json()
self.access_token = response_data.get("accessToken") self.access_token = response_data.get('accessToken')
expires_in = int(response_data.get("expireIn",7200)) expires_in = int(response_data.get('expireIn', 7200))
self.access_token_expiry_time = time.time() + expires_in - 60 self.access_token_expiry_time = time.time() + expires_in - 60
except Exception as e: except Exception as e:
raise Exception(e) raise Exception(e)
async def is_token_expired(self): async def is_token_expired(self):
"""检查token是否过期""" """检查token是否过期"""
if self.access_token_expiry_time is None: if self.access_token_expiry_time is None:
@@ -61,62 +57,53 @@ class DingTalkClient:
return False return False
return bool(self.access_token and self.access_token.strip()) return bool(self.access_token and self.access_token.strip())
async def download_image(self,download_code:str): async def download_image(self, download_code: str):
if not await self.check_access_token(): if not await self.check_access_token():
await self.get_access_token() await self.get_access_token()
url = 'https://api.dingtalk.com/v1.0/robot/messageFiles/download' url = 'https://api.dingtalk.com/v1.0/robot/messageFiles/download'
params = { params = {'downloadCode': download_code, 'robotCode': self.robot_code}
"downloadCode":download_code, headers = {'x-acs-dingtalk-access-token': self.access_token}
"robotCode":self.robot_code
}
headers ={
"x-acs-dingtalk-access-token": self.access_token
}
async with httpx.AsyncClient() as client: async with httpx.AsyncClient() as client:
response = await client.post(url, headers=headers, json=params) response = await client.post(url, headers=headers, json=params)
if response.status_code == 200: if response.status_code == 200:
result = response.json() result = response.json()
download_url = result.get("downloadUrl") download_url = result.get('downloadUrl')
else: else:
raise Exception(f"Error: {response.status_code}, {response.text}") raise Exception(f'Error: {response.status_code}, {response.text}')
if download_url: if download_url:
return await self.download_url_to_base64(download_url) return await self.download_url_to_base64(download_url)
async def download_url_to_base64(self,download_url): async def download_url_to_base64(self, download_url):
async with httpx.AsyncClient() as client: async with httpx.AsyncClient() as client:
response = await client.get(download_url) response = await client.get(download_url)
if response.status_code == 200: if response.status_code == 200:
file_bytes = response.content file_bytes = response.content
base64_str = base64.b64encode(file_bytes).decode('utf-8') # 返回字符串格式 base64_str = base64.b64encode(file_bytes).decode(
'utf-8'
) # 返回字符串格式
return base64_str return base64_str
else: else:
raise Exception("获取文件失败") raise Exception('获取文件失败')
async def get_audio_url(self,download_code:str): async def get_audio_url(self, download_code: str):
if not await self.check_access_token(): if not await self.check_access_token():
await self.get_access_token() await self.get_access_token()
url = 'https://api.dingtalk.com/v1.0/robot/messageFiles/download' url = 'https://api.dingtalk.com/v1.0/robot/messageFiles/download'
params = { params = {'downloadCode': download_code, 'robotCode': self.robot_code}
"downloadCode":download_code, headers = {'x-acs-dingtalk-access-token': self.access_token}
"robotCode":self.robot_code
}
headers ={
"x-acs-dingtalk-access-token": self.access_token
}
async with httpx.AsyncClient() as client: async with httpx.AsyncClient() as client:
response = await client.post(url, headers=headers, json=params) response = await client.post(url, headers=headers, json=params)
if response.status_code == 200: if response.status_code == 200:
result = response.json() result = response.json()
download_url = result.get("downloadUrl") download_url = result.get('downloadUrl')
if download_url: if download_url:
return await self.download_url_to_base64(download_url) return await self.download_url_to_base64(download_url)
else: else:
raise Exception("获取音频失败") raise Exception('获取音频失败')
else: else:
raise Exception(f"Error: {response.status_code}, {response.text}") raise Exception(f'Error: {response.status_code}, {response.text}')
async def update_incoming_message(self, message): async def update_incoming_message(self, message):
"""异步更新 DingTalkClient 中的 incoming_message""" """异步更新 DingTalkClient 中的 incoming_message"""
@@ -126,23 +113,20 @@ class DingTalkClient:
if event: if event:
await self._handle_message(event) await self._handle_message(event)
async def send_message(self, content: str, incoming_message):
async def send_message(self,content:str,incoming_message): self.EchoTextHandler.reply_text(content, incoming_message)
self.EchoTextHandler.reply_text(content,incoming_message)
async def get_incoming_message(self): async def get_incoming_message(self):
"""获取收到的消息""" """获取收到的消息"""
return await self.EchoTextHandler.get_incoming_message() return await self.EchoTextHandler.get_incoming_message()
def on_message(self, msg_type: str): def on_message(self, msg_type: str):
def decorator(func: Callable[[DingTalkEvent], None]): def decorator(func: Callable[[DingTalkEvent], None]):
if msg_type not in self._message_handlers: if msg_type not in self._message_handlers:
self._message_handlers[msg_type] = [] self._message_handlers[msg_type] = []
self._message_handlers[msg_type].append(func) self._message_handlers[msg_type].append(func)
return func return func
return decorator return decorator
async def _handle_message(self, event: DingTalkEvent): async def _handle_message(self, event: DingTalkEvent):
@@ -154,40 +138,44 @@ class DingTalkClient:
for handler in self._message_handlers[msg_type]: for handler in self._message_handlers[msg_type]:
await handler(event) await handler(event)
async def get_message(
async def get_message(self,incoming_message:dingtalk_stream.chatbot.ChatbotMessage): self, incoming_message: dingtalk_stream.chatbot.ChatbotMessage
):
try: try:
# print(json.dumps(incoming_message.to_dict(), indent=4, ensure_ascii=False)) # print(json.dumps(incoming_message.to_dict(), indent=4, ensure_ascii=False))
message_data = { message_data = {
"IncomingMessage":incoming_message, 'IncomingMessage': incoming_message,
} }
if str(incoming_message.conversation_type) == '1': if str(incoming_message.conversation_type) == '1':
message_data["conversation_type"] = 'FriendMessage' message_data['conversation_type'] = 'FriendMessage'
elif str(incoming_message.conversation_type) == '2': elif str(incoming_message.conversation_type) == '2':
message_data["conversation_type"] = 'GroupMessage' message_data['conversation_type'] = 'GroupMessage'
if incoming_message.message_type == 'richText': if incoming_message.message_type == 'richText':
data = incoming_message.rich_text_content.to_dict() data = incoming_message.rich_text_content.to_dict()
for item in data['richText']: for item in data['richText']:
if 'text' in item: if 'text' in item:
message_data["Content"] = item['text'] message_data['Content'] = item['text']
if incoming_message.get_image_list()[0]: if incoming_message.get_image_list()[0]:
message_data["Picture"] = await self.download_image(incoming_message.get_image_list()[0]) message_data['Picture'] = await self.download_image(
message_data["Type"] = 'text' incoming_message.get_image_list()[0]
)
message_data['Type'] = 'text'
elif incoming_message.message_type == 'text': elif incoming_message.message_type == 'text':
message_data['Content'] = incoming_message.get_text_list()[0] message_data['Content'] = incoming_message.get_text_list()[0]
message_data["Type"] = 'text' message_data['Type'] = 'text'
elif incoming_message.message_type == 'picture': elif incoming_message.message_type == 'picture':
message_data['Picture'] = await self.download_image(incoming_message.get_image_list()[0]) message_data['Picture'] = await self.download_image(
incoming_message.get_image_list()[0]
)
message_data['Type'] = 'image' message_data['Type'] = 'image'
elif incoming_message.message_type == 'audio': elif incoming_message.message_type == 'audio':
message_data['Audio'] = await self.get_audio_url(incoming_message.to_dict()['content']['downloadCode']) message_data['Audio'] = await self.get_audio_url(
incoming_message.to_dict()['content']['downloadCode']
)
message_data['Type'] = 'audio' message_data['Type'] = 'audio'
@@ -199,50 +187,49 @@ class DingTalkClient:
return message_data return message_data
async def send_proactive_message_to_one(self,target_id:str,content:str): async def send_proactive_message_to_one(self, target_id: str, content: str):
if not await self.check_access_token(): if not await self.check_access_token():
await self.get_access_token() await self.get_access_token()
url = 'https://api.dingtalk.com/v1.0/robot/oToMessages/batchSend' url = 'https://api.dingtalk.com/v1.0/robot/oToMessages/batchSend'
headers ={ headers = {
"x-acs-dingtalk-access-token":self.access_token, 'x-acs-dingtalk-access-token': self.access_token,
"Content-Type":"application/json", 'Content-Type': 'application/json',
} }
data ={ data = {
"robotCode":self.robot_code, 'robotCode': self.robot_code,
"userIds":[target_id], 'userIds': [target_id],
"msgKey": "sampleText", 'msgKey': 'sampleText',
"msgParam": json.dumps({"content":content}), 'msgParam': json.dumps({'content': content}),
} }
try: try:
async with httpx.AsyncClient() as client: async with httpx.AsyncClient() as client:
response = await client.post(url,headers=headers,json=data) await client.post(url, headers=headers, json=data)
except Exception: except Exception:
traceback.print_exc() traceback.print_exc()
async def send_proactive_message_to_group(self, target_id: str, content: str):
async def send_proactive_message_to_group(self,target_id:str,content:str):
if not await self.check_access_token(): if not await self.check_access_token():
await self.get_access_token() await self.get_access_token()
url = 'https://api.dingtalk.com/v1.0/robot/groupMessages/send' url = 'https://api.dingtalk.com/v1.0/robot/groupMessages/send'
headers ={ headers = {
"x-acs-dingtalk-access-token":self.access_token, 'x-acs-dingtalk-access-token': self.access_token,
"Content-Type":"application/json", 'Content-Type': 'application/json',
} }
data ={ data = {
"robotCode":self.robot_code, 'robotCode': self.robot_code,
"openConversationId":target_id, 'openConversationId': target_id,
"msgKey": "sampleText", 'msgKey': 'sampleText',
"msgParam": json.dumps({"content":content}), 'msgParam': json.dumps({'content': content}),
} }
try: try:
async with httpx.AsyncClient() as client: async with httpx.AsyncClient() as client:
response = await client.post(url,headers=headers,json=data) await client.post(url, headers=headers, json=data)
except Exception: except Exception:
traceback.print_exc() traceback.print_exc()
+10 -12
View File
@@ -1,41 +1,39 @@
from typing import Dict, Any, Optional from typing import Dict, Any, Optional
import dingtalk_stream import dingtalk_stream
class DingTalkEvent(dict): class DingTalkEvent(dict):
@staticmethod @staticmethod
def from_payload(payload: Dict[str, Any]) -> Optional["DingTalkEvent"]: def from_payload(payload: Dict[str, Any]) -> Optional['DingTalkEvent']:
try: try:
event = DingTalkEvent(payload) event = DingTalkEvent(payload)
return event return event
except KeyError: except KeyError:
return None return None
@property @property
def content(self): def content(self):
return self.get("Content","") return self.get('Content', '')
@property @property
def incoming_message(self) -> Optional["dingtalk_stream.chatbot.ChatbotMessage"]: def incoming_message(self) -> Optional['dingtalk_stream.chatbot.ChatbotMessage']:
return self.get("IncomingMessage") return self.get('IncomingMessage')
@property @property
def type(self): def type(self):
return self.get("Type","") return self.get('Type', '')
@property @property
def picture(self): def picture(self):
return self.get("Picture","") return self.get('Picture', '')
@property @property
def audio(self): def audio(self):
return self.get("Audio","") return self.get('Audio', '')
@property @property
def conversation(self): def conversation(self):
return self.get("conversation_type","") return self.get('conversation_type', '')
def __getattr__(self, key: str) -> Optional[Any]: def __getattr__(self, key: str) -> Optional[Any]:
""" """
@@ -66,4 +64,4 @@ class DingTalkEvent(dict):
Returns: Returns:
str: 字符串表示。 str: 字符串表示。
""" """
return f"<DingTalkEvent {super().__repr__()}>" return f'<DingTalkEvent {super().__repr__()}>'
+115 -105
View File
@@ -1,20 +1,14 @@
# 微信公众号的加解密算法与企业微信一样,所以直接使用企业微信的加解密算法文件 # 微信公众号的加解密算法与企业微信一样,所以直接使用企业微信的加解密算法文件
from collections import deque
import time import time
import traceback import traceback
from ..wecom_api.WXBizMsgCrypt3 import WXBizMsgCrypt from ..wecom_api.WXBizMsgCrypt3 import WXBizMsgCrypt
import xml.etree.ElementTree as ET import xml.etree.ElementTree as ET
from quart import Quart,request from quart import Quart, request
import hashlib import hashlib
from typing import Callable, Dict, Any from typing import Callable
from .oaevent import OAEvent from .oaevent import OAEvent
import httpx
import asyncio import asyncio
import time
import xml.etree.ElementTree as ET
from pkg.platform.sources import officialaccount as oa
xml_template = """ xml_template = """
@@ -28,9 +22,8 @@ xml_template = """
""" """
class OAClient(): class OAClient:
def __init__(self, token: str, EncodingAESKey: str, AppID: str, Appsecret: str):
def __init__(self,token:str,EncodingAESKey:str,AppID:str,Appsecret:str):
self.token = token self.token = token
self.aes = EncodingAESKey self.aes = EncodingAESKey
self.appid = AppID self.appid = AppID
@@ -38,65 +31,71 @@ class OAClient():
self.base_url = 'https://api.weixin.qq.com' self.base_url = 'https://api.weixin.qq.com'
self.access_token = '' self.access_token = ''
self.app = Quart(__name__) self.app = Quart(__name__)
self.app.add_url_rule('/callback/command', 'handle_callback', self.handle_callback_request, methods=['GET', 'POST']) self.app.add_url_rule(
'/callback/command',
'handle_callback',
self.handle_callback_request,
methods=['GET', 'POST'],
)
self._message_handlers = { self._message_handlers = {
"example":[], 'example': [],
} }
self.access_token_expiry_time = None self.access_token_expiry_time = None
self.msg_id_map = {} self.msg_id_map = {}
self.generated_content = {} self.generated_content = {}
async def handle_callback_request(self): async def handle_callback_request(self):
try: try:
# 每隔100毫秒查询是否生成ai回答 # 每隔100毫秒查询是否生成ai回答
start_time = time.time() start_time = time.time()
signature = request.args.get("signature", "") signature = request.args.get('signature', '')
timestamp = request.args.get("timestamp", "") timestamp = request.args.get('timestamp', '')
nonce = request.args.get("nonce", "") nonce = request.args.get('nonce', '')
echostr = request.args.get("echostr", "") echostr = request.args.get('echostr', '')
msg_signature = request.args.get("msg_signature","") msg_signature = request.args.get('msg_signature', '')
if msg_signature is None: if msg_signature is None:
raise Exception("msg_signature不在请求体中") raise Exception('msg_signature不在请求体中')
if request.method == 'GET': if request.method == 'GET':
# 校验签名 # 校验签名
check_str = "".join(sorted([self.token, timestamp, nonce])) check_str = ''.join(sorted([self.token, timestamp, nonce]))
check_signature = hashlib.sha1(check_str.encode("utf-8")).hexdigest() check_signature = hashlib.sha1(check_str.encode('utf-8')).hexdigest()
if check_signature == signature: if check_signature == signature:
return echostr # 验证成功返回echostr return echostr # 验证成功返回echostr
else: else:
raise Exception("拒绝请求") raise Exception('拒绝请求')
elif request.method == "POST": elif request.method == 'POST':
encryt_msg = await request.data encryt_msg = await request.data
wxcpt = WXBizMsgCrypt(self.token,self.aes,self.appid) wxcpt = WXBizMsgCrypt(self.token, self.aes, self.appid)
ret,xml_msg = wxcpt.DecryptMsg(encryt_msg,msg_signature,timestamp,nonce) ret, xml_msg = wxcpt.DecryptMsg(
encryt_msg, msg_signature, timestamp, nonce
)
xml_msg = xml_msg.decode('utf-8') xml_msg = xml_msg.decode('utf-8')
if ret != 0: if ret != 0:
raise Exception("消息解密失败") raise Exception('消息解密失败')
message_data = await self.get_message(xml_msg) message_data = await self.get_message(xml_msg)
if message_data : if message_data:
event = OAEvent.from_payload(message_data) event = OAEvent.from_payload(message_data)
if event: if event:
await self._handle_message(event) await self._handle_message(event)
root = ET.fromstring(xml_msg) root = ET.fromstring(xml_msg)
from_user = root.find("FromUserName").text # 发送者 from_user = root.find('FromUserName').text # 发送者
to_user = root.find("ToUserName").text # 机器人 to_user = root.find('ToUserName').text # 机器人
timeout = 4.80 timeout = 4.80
interval = 0.1 interval = 0.1
while True: while True:
content = self.generated_content.pop(message_data["MsgId"], None) content = self.generated_content.pop(message_data['MsgId'], None)
if content: if content:
response_xml = xml_template.format( response_xml = xml_template.format(
to_user=from_user, to_user=from_user,
from_user=to_user, from_user=to_user,
create_time=int(time.time()), create_time=int(time.time()),
content = content content=content,
) )
return response_xml return response_xml
@@ -106,53 +105,56 @@ class OAClient():
await asyncio.sleep(interval) await asyncio.sleep(interval)
if self.msg_id_map.get(message_data["MsgId"], 1) == 3: if self.msg_id_map.get(message_data['MsgId'], 1) == 3:
# response_xml = xml_template.format( # response_xml = xml_template.format(
# to_user=from_user, # to_user=from_user,
# from_user=to_user, # from_user=to_user,
# create_time=int(time.time()), # create_time=int(time.time()),
# content = "请求失效:暂不支持公众号超过15秒的请求,如有需求,请联系 LangBot 团队。" # content = "请求失效:暂不支持公众号超过15秒的请求,如有需求,请联系 LangBot 团队。"
# ) # )
print("请求失效:暂不支持公众号超过15秒的请求,如有需求,请联系 LangBot 团队。") print(
'请求失效:暂不支持公众号超过15秒的请求,如有需求,请联系 LangBot 团队。'
)
return '' return ''
except Exception as e: except Exception:
traceback.print_exc() traceback.print_exc()
async def get_message(self, xml_msg: str): async def get_message(self, xml_msg: str):
root = ET.fromstring(xml_msg) root = ET.fromstring(xml_msg)
message_data = { message_data = {
"ToUserName": root.find("ToUserName").text, 'ToUserName': root.find('ToUserName').text,
"FromUserName": root.find("FromUserName").text, 'FromUserName': root.find('FromUserName').text,
"CreateTime": int(root.find("CreateTime").text), 'CreateTime': int(root.find('CreateTime').text),
"MsgType": root.find("MsgType").text, 'MsgType': root.find('MsgType').text,
"Content": root.find("Content").text if root.find("Content") is not None else None, 'Content': root.find('Content').text
"MsgId": int(root.find("MsgId").text) if root.find("MsgId") is not None else None, if root.find('Content') is not None
else None,
'MsgId': int(root.find('MsgId').text)
if root.find('MsgId') is not None
else None,
} }
return message_data return message_data
async def run_task(self, host: str, port: int, *args, **kwargs): async def run_task(self, host: str, port: int, *args, **kwargs):
""" """
启动 Quart 应用。 启动 Quart 应用。
""" """
await self.app.run_task(host=host, port=port, *args, **kwargs) await self.app.run_task(host=host, port=port, *args, **kwargs)
def on_message(self, msg_type: str): def on_message(self, msg_type: str):
""" """
注册消息类型处理器。 注册消息类型处理器。
""" """
def decorator(func: Callable[[OAEvent], None]): def decorator(func: Callable[[OAEvent], None]):
if msg_type not in self._message_handlers: if msg_type not in self._message_handlers:
self._message_handlers[msg_type] = [] self._message_handlers[msg_type] = []
self._message_handlers[msg_type].append(func) self._message_handlers[msg_type].append(func)
return func return func
return decorator return decorator
async def _handle_message(self, event: OAEvent): async def _handle_message(self, event: OAEvent):
@@ -170,14 +172,19 @@ class OAClient():
for handler in self._message_handlers[msg_type]: for handler in self._message_handlers[msg_type]:
await handler(event) await handler(event)
async def set_message(self,msg_id:int,content:str): async def set_message(self, msg_id: int, content: str):
self.generated_content[msg_id] = content self.generated_content[msg_id] = content
class OAClientForLongerResponse:
class OAClientForLongerResponse(): def __init__(
self,
def __init__(self,token:str,EncodingAESKey:str,AppID:str,Appsecret:str,LoadingMessage:str): token: str,
EncodingAESKey: str,
AppID: str,
Appsecret: str,
LoadingMessage: str,
):
self.token = token self.token = token
self.aes = EncodingAESKey self.aes = EncodingAESKey
self.appid = AppID self.appid = AppID
@@ -185,9 +192,14 @@ class OAClientForLongerResponse():
self.base_url = 'https://api.weixin.qq.com' self.base_url = 'https://api.weixin.qq.com'
self.access_token = '' self.access_token = ''
self.app = Quart(__name__) self.app = Quart(__name__)
self.app.add_url_rule('/callback/command', 'handle_callback', self.handle_callback_request, methods=['GET', 'POST']) self.app.add_url_rule(
'/callback/command',
'handle_callback',
self.handle_callback_request,
methods=['GET', 'POST'],
)
self._message_handlers = { self._message_handlers = {
"example":[], 'example': [],
} }
self.access_token_expiry_time = None self.access_token_expiry_time = None
self.loading_message = LoadingMessage self.loading_message = LoadingMessage
@@ -196,50 +208,55 @@ class OAClientForLongerResponse():
async def handle_callback_request(self): async def handle_callback_request(self):
try: try:
start_time = time.time() signature = request.args.get('signature', '')
signature = request.args.get("signature", "") timestamp = request.args.get('timestamp', '')
timestamp = request.args.get("timestamp", "") nonce = request.args.get('nonce', '')
nonce = request.args.get("nonce", "") echostr = request.args.get('echostr', '')
echostr = request.args.get("echostr", "") msg_signature = request.args.get('msg_signature', '')
msg_signature = request.args.get("msg_signature", "")
if msg_signature is None: if msg_signature is None:
raise Exception("msg_signature不在请求体中") raise Exception('msg_signature不在请求体中')
if request.method == 'GET': if request.method == 'GET':
check_str = "".join(sorted([self.token, timestamp, nonce])) check_str = ''.join(sorted([self.token, timestamp, nonce]))
check_signature = hashlib.sha1(check_str.encode("utf-8")).hexdigest() check_signature = hashlib.sha1(check_str.encode('utf-8')).hexdigest()
return echostr if check_signature == signature else "拒绝请求" return echostr if check_signature == signature else '拒绝请求'
elif request.method == "POST": elif request.method == 'POST':
encryt_msg = await request.data encryt_msg = await request.data
wxcpt = WXBizMsgCrypt(self.token, self.aes, self.appid) wxcpt = WXBizMsgCrypt(self.token, self.aes, self.appid)
ret, xml_msg = wxcpt.DecryptMsg(encryt_msg, msg_signature, timestamp, nonce) ret, xml_msg = wxcpt.DecryptMsg(
encryt_msg, msg_signature, timestamp, nonce
)
xml_msg = xml_msg.decode('utf-8') xml_msg = xml_msg.decode('utf-8')
if ret != 0: if ret != 0:
raise Exception("消息解密失败") raise Exception('消息解密失败')
# 解析 XML # 解析 XML
root = ET.fromstring(xml_msg) root = ET.fromstring(xml_msg)
from_user = root.find("FromUserName").text from_user = root.find('FromUserName').text
to_user = root.find("ToUserName").text to_user = root.find('ToUserName').text
if self.msg_queue.get(from_user) and self.msg_queue[from_user][0]["content"]:
if (
self.msg_queue.get(from_user)
and self.msg_queue[from_user][0]['content']
):
queue_top = self.msg_queue[from_user].pop(0) queue_top = self.msg_queue[from_user].pop(0)
queue_content = queue_top["content"] queue_content = queue_top['content']
# 弹出用户消息 # 弹出用户消息
if self.user_msg_queue.get(from_user) and self.user_msg_queue[from_user]: if (
self.user_msg_queue.get(from_user)
and self.user_msg_queue[from_user]
):
self.user_msg_queue[from_user].pop(0) self.user_msg_queue[from_user].pop(0)
response_xml = xml_template.format( response_xml = xml_template.format(
to_user=from_user, to_user=from_user,
from_user=to_user, from_user=to_user,
create_time=int(time.time()), create_time=int(time.time()),
content=queue_content content=queue_content,
) )
return response_xml return response_xml
@@ -248,10 +265,13 @@ class OAClientForLongerResponse():
to_user=from_user, to_user=from_user,
from_user=to_user, from_user=to_user,
create_time=int(time.time()), create_time=int(time.time()),
content=self.loading_message content=self.loading_message,
) )
if self.user_msg_queue.get(from_user) and self.user_msg_queue[from_user][0]["content"]: if (
self.user_msg_queue.get(from_user)
and self.user_msg_queue[from_user][0]['content']
):
return response_xml return response_xml
else: else:
message_data = await self.get_message(xml_msg) message_data = await self.get_message(xml_msg)
@@ -259,54 +279,53 @@ class OAClientForLongerResponse():
if message_data: if message_data:
event = OAEvent.from_payload(message_data) event = OAEvent.from_payload(message_data)
if event: if event:
self.user_msg_queue.setdefault(from_user,[]).append( self.user_msg_queue.setdefault(from_user, []).append(
{ {
"content":event.message, 'content': event.message,
} }
) )
await self._handle_message(event) await self._handle_message(event)
return response_xml return response_xml
except Exception as e: except Exception:
traceback.print_exc() traceback.print_exc()
async def get_message(self, xml_msg: str): async def get_message(self, xml_msg: str):
root = ET.fromstring(xml_msg) root = ET.fromstring(xml_msg)
message_data = { message_data = {
"ToUserName": root.find("ToUserName").text, 'ToUserName': root.find('ToUserName').text,
"FromUserName": root.find("FromUserName").text, 'FromUserName': root.find('FromUserName').text,
"CreateTime": int(root.find("CreateTime").text), 'CreateTime': int(root.find('CreateTime').text),
"MsgType": root.find("MsgType").text, 'MsgType': root.find('MsgType').text,
"Content": root.find("Content").text if root.find("Content") is not None else None, 'Content': root.find('Content').text
"MsgId": int(root.find("MsgId").text) if root.find("MsgId") is not None else None, if root.find('Content') is not None
else None,
'MsgId': int(root.find('MsgId').text)
if root.find('MsgId') is not None
else None,
} }
return message_data return message_data
async def run_task(self, host: str, port: int, *args, **kwargs): async def run_task(self, host: str, port: int, *args, **kwargs):
""" """
启动 Quart 应用。 启动 Quart 应用。
""" """
await self.app.run_task(host=host, port=port, *args, **kwargs) await self.app.run_task(host=host, port=port, *args, **kwargs)
def on_message(self, msg_type: str): def on_message(self, msg_type: str):
""" """
注册消息类型处理器。 注册消息类型处理器。
""" """
def decorator(func: Callable[[OAEvent], None]): def decorator(func: Callable[[OAEvent], None]):
if msg_type not in self._message_handlers: if msg_type not in self._message_handlers:
self._message_handlers[msg_type] = [] self._message_handlers[msg_type] = []
self._message_handlers[msg_type].append(func) self._message_handlers[msg_type].append(func)
return func return func
return decorator return decorator
async def _handle_message(self, event: OAEvent): async def _handle_message(self, event: OAEvent):
@@ -319,22 +338,13 @@ class OAClientForLongerResponse():
for handler in self._message_handlers[msg_type]: for handler in self._message_handlers[msg_type]:
await handler(event) await handler(event)
async def set_message(self,from_user:int,message_id:int,content:str): async def set_message(self, from_user: int, message_id: int, content: str):
if from_user not in self.msg_queue: if from_user not in self.msg_queue:
self.msg_queue[from_user] = [] self.msg_queue[from_user] = []
self.msg_queue[from_user].append( self.msg_queue[from_user].append(
{ {
"msg_id":message_id, 'msg_id': message_id,
"content":content, 'content': content,
} }
) )
+14 -15
View File
@@ -9,7 +9,7 @@ class OAEvent(dict):
""" """
@staticmethod @staticmethod
def from_payload(payload: Dict[str, Any]) -> Optional["OAEvent"]: def from_payload(payload: Dict[str, Any]) -> Optional['OAEvent']:
""" """
从微信公众号事件数据构造 `WecomEvent` 对象。 从微信公众号事件数据构造 `WecomEvent` 对象。
@@ -34,14 +34,14 @@ class OAEvent(dict):
Returns: Returns:
str: 事件类型。 str: 事件类型。
""" """
return self.get("MsgType", "") return self.get('MsgType', '')
@property @property
def picurl(self) -> str: def picurl(self) -> str:
""" """
图片链接 图片链接
""" """
return self.get("PicUrl","") return self.get('PicUrl', '')
@property @property
def detail_type(self) -> str: def detail_type(self) -> str:
@@ -53,8 +53,8 @@ class OAEvent(dict):
Returns: Returns:
str: 事件详细类型。 str: 事件详细类型。
""" """
if self.type == "event": if self.type == 'event':
return self.get("Event", "") return self.get('Event', '')
return self.type return self.type
@property @property
@@ -65,15 +65,14 @@ class OAEvent(dict):
Returns: Returns:
str: 事件名。 str: 事件名。
""" """
return f"{self.type}.{self.detail_type}" return f'{self.type}.{self.detail_type}'
@property @property
def user_id(self) -> Optional[str]: def user_id(self) -> Optional[str]:
""" """
发送方账号 发送方账号
""" """
return self.get("FromUserName") return self.get('FromUserName')
@property @property
def receiver_id(self) -> Optional[str]: def receiver_id(self) -> Optional[str]:
@@ -83,7 +82,7 @@ class OAEvent(dict):
Returns: Returns:
Optional[str]: 接收者 ID。 Optional[str]: 接收者 ID。
""" """
return self.get("ToUserName") return self.get('ToUserName')
@property @property
def message_id(self) -> Optional[str]: def message_id(self) -> Optional[str]:
@@ -93,7 +92,7 @@ class OAEvent(dict):
Returns: Returns:
Optional[str]: 消息 ID。 Optional[str]: 消息 ID。
""" """
return self.get("MsgId") return self.get('MsgId')
@property @property
def message(self) -> Optional[str]: def message(self) -> Optional[str]:
@@ -103,7 +102,7 @@ class OAEvent(dict):
Returns: Returns:
Optional[str]: 消息内容。 Optional[str]: 消息内容。
""" """
return self.get("Content") return self.get('Content')
@property @property
def media_id(self) -> Optional[str]: def media_id(self) -> Optional[str]:
@@ -113,7 +112,7 @@ class OAEvent(dict):
Returns: Returns:
Optional[str]: 媒体文件 ID。 Optional[str]: 媒体文件 ID。
""" """
return self.get("MediaId") return self.get('MediaId')
@property @property
def timestamp(self) -> Optional[int]: def timestamp(self) -> Optional[int]:
@@ -123,7 +122,7 @@ class OAEvent(dict):
Returns: Returns:
Optional[int]: 时间戳。 Optional[int]: 时间戳。
""" """
return self.get("CreateTime") return self.get('CreateTime')
@property @property
def event_key(self) -> Optional[str]: def event_key(self) -> Optional[str]:
@@ -133,7 +132,7 @@ class OAEvent(dict):
Returns: Returns:
Optional[str]: 事件 Key。 Optional[str]: 事件 Key。
""" """
return self.get("EventKey") return self.get('EventKey')
def __getattr__(self, key: str) -> Optional[Any]: def __getattr__(self, key: str) -> Optional[Any]:
""" """
@@ -164,4 +163,4 @@ class OAEvent(dict):
Returns: Returns:
str: 字符串表示。 str: 字符串表示。
""" """
return f"<WecomEvent {super().__repr__()}>" return f'<WecomEvent {super().__repr__()}>'
+96 -104
View File
@@ -1,24 +1,16 @@
import time import time
from quart import request from quart import request
import base64
import binascii
import httpx import httpx
from quart import Quart from quart import Quart
import xml.etree.ElementTree as ET
from typing import Callable, Dict, Any from typing import Callable, Dict, Any
from pkg.platform.types import events as platform_events, message as platform_message from pkg.platform.types import events as platform_events
import aiofiles
from .qqofficialevent import QQOfficialEvent from .qqofficialevent import QQOfficialEvent
import json import json
import hmac
import base64
import hashlib
import traceback import traceback
from cryptography.hazmat.primitives.asymmetric import ed25519 from cryptography.hazmat.primitives.asymmetric import ed25519
from .qqofficialevent import QQOfficialEvent
def handle_validation(body: dict, bot_secret: str): def handle_validation(body: dict, bot_secret: str):
# bot正确的secert是32位的,此处仅为了适配演示demo # bot正确的secert是32位的,此处仅为了适配演示demo
while len(bot_secret) < 32: while len(bot_secret) < 32:
bot_secret = bot_secret * 2 bot_secret = bot_secret * 2
@@ -36,29 +28,26 @@ def handle_validation(body: dict, bot_secret: str):
signature_hex = signature.hex() signature_hex = signature.hex()
response = { response = {'plain_token': body['d']['plain_token'], 'signature': signature_hex}
"plain_token": body['d']['plain_token'],
"signature": signature_hex
}
return response return response
class QQOfficialClient: class QQOfficialClient:
def __init__(self, secret: str, token: str, app_id: str): def __init__(self, secret: str, token: str, app_id: str):
self.app = Quart(__name__) self.app = Quart(__name__)
self.app.add_url_rule( self.app.add_url_rule(
"/callback/command", '/callback/command',
"handle_callback", 'handle_callback',
self.handle_callback_request, self.handle_callback_request,
methods=["GET", "POST"], methods=['GET', 'POST'],
) )
self.secret = secret self.secret = secret
self.token = token self.token = token
self.app_id = app_id self.app_id = app_id
self._message_handlers = { self._message_handlers = {}
} self.base_url = 'https://api.sgroup.qq.com'
self.base_url = "https://api.sgroup.qq.com" self.access_token = ''
self.access_token = ""
self.access_token_expiry_time = None self.access_token_expiry_time = None
async def check_access_token(self): async def check_access_token(self):
@@ -69,27 +58,26 @@ class QQOfficialClient:
async def get_access_token(self): async def get_access_token(self):
"""获取access_token""" """获取access_token"""
url = "https://bots.qq.com/app/getAppAccessToken" url = 'https://bots.qq.com/app/getAppAccessToken'
async with httpx.AsyncClient() as client: async with httpx.AsyncClient() as client:
params = { params = {
"appId":self.app_id, 'appId': self.app_id,
"clientSecret":self.secret, 'clientSecret': self.secret,
} }
headers = { headers = {
"content-type":"application/json", 'content-type': 'application/json',
} }
try: try:
response = await client.post(url,json=params,headers=headers) response = await client.post(url, json=params, headers=headers)
if response.status_code == 200: if response.status_code == 200:
response_data = response.json() response_data = response.json()
access_token = response_data.get("access_token") access_token = response_data.get('access_token')
expires_in = int(response_data.get("expires_in",7200)) expires_in = int(response_data.get('expires_in', 7200))
self.access_token_expiry_time = time.time() + expires_in - 60 self.access_token_expiry_time = time.time() + expires_in - 60
if access_token: if access_token:
self.access_token = access_token self.access_token = access_token
except Exception as e: except Exception as e:
raise Exception(f"获取access_token失败: {e}") raise Exception(f'获取access_token失败: {e}')
async def handle_callback_request(self): async def handle_callback_request(self):
"""处理回调请求""" """处理回调请求"""
@@ -98,27 +86,24 @@ class QQOfficialClient:
body = await request.get_data() body = await request.get_data()
payload = json.loads(body) payload = json.loads(body)
# 验证是否为回调验证请求 # 验证是否为回调验证请求
if payload.get("op") == 13: if payload.get('op') == 13:
# 生成签名 # 生成签名
response = handle_validation(payload, self.secret) response = handle_validation(payload, self.secret)
return response return response
if payload.get("op") == 0: if payload.get('op') == 0:
message_data = await self.get_message(payload) message_data = await self.get_message(payload)
if message_data: if message_data:
event = QQOfficialEvent.from_payload(message_data) event = QQOfficialEvent.from_payload(message_data)
await self._handle_message(event) await self._handle_message(event)
return {"code": 0, "message": "success"} return {'code': 0, 'message': 'success'}
except Exception as e: except Exception as e:
traceback.print_exc() traceback.print_exc()
return {"error": str(e)}, 400 return {'error': str(e)}, 400
async def run_task(self, host: str, port: int, *args, **kwargs): async def run_task(self, host: str, port: int, *args, **kwargs):
"""启动 Quart 应用""" """启动 Quart 应用"""
@@ -135,133 +120,140 @@ class QQOfficialClient:
return decorator return decorator
async def _handle_message(self, event:QQOfficialEvent): async def _handle_message(self, event: QQOfficialEvent):
"""处理消息事件""" """处理消息事件"""
msg_type = event.t msg_type = event.t
if msg_type in self._message_handlers: if msg_type in self._message_handlers:
for handler in self._message_handlers[msg_type]: for handler in self._message_handlers[msg_type]:
await handler(event) await handler(event)
async def get_message(self, msg: dict) -> Dict[str, Any]:
async def get_message(self,msg:dict) -> Dict[str,Any]:
"""获取消息""" """获取消息"""
message_data = { message_data = {
"t": msg.get("t",{}), 't': msg.get('t', {}),
"user_openid": msg.get("d",{}).get("author",{}).get("user_openid",{}), 'user_openid': msg.get('d', {}).get('author', {}).get('user_openid', {}),
"timestamp": msg.get("d",{}).get("timestamp",{}), 'timestamp': msg.get('d', {}).get('timestamp', {}),
"d_author_id": msg.get("d",{}).get("author",{}).get("id",{}), 'd_author_id': msg.get('d', {}).get('author', {}).get('id', {}),
"content": msg.get("d",{}).get("content",{}), 'content': msg.get('d', {}).get('content', {}),
"d_id": msg.get("d",{}).get("id",{}), 'd_id': msg.get('d', {}).get('id', {}),
"id": msg.get("id",{}), 'id': msg.get('id', {}),
"channel_id": msg.get("d",{}).get("channel_id",{}), 'channel_id': msg.get('d', {}).get('channel_id', {}),
"username": msg.get("d",{}).get("author",{}).get("username",{}), 'username': msg.get('d', {}).get('author', {}).get('username', {}),
"guild_id": msg.get("d",{}).get("guild_id",{}), 'guild_id': msg.get('d', {}).get('guild_id', {}),
"member_openid": msg.get("d",{}).get("author",{}).get("openid",{}), 'member_openid': msg.get('d', {}).get('author', {}).get('openid', {}),
"group_openid": msg.get("d",{}).get("group_openid",{}) 'group_openid': msg.get('d', {}).get('group_openid', {}),
} }
attachments = msg.get("d", {}).get("attachments", []) attachments = msg.get('d', {}).get('attachments', [])
image_attachments = [attachment['url'] for attachment in attachments if await self.is_image(attachment)] image_attachments = [
image_attachments_type = [attachment['content_type'] for attachment in attachments if await self.is_image(attachment)] attachment['url']
for attachment in attachments
if await self.is_image(attachment)
]
image_attachments_type = [
attachment['content_type']
for attachment in attachments
if await self.is_image(attachment)
]
if image_attachments: if image_attachments:
message_data["image_attachments"] = image_attachments[0] message_data['image_attachments'] = image_attachments[0]
message_data["content_type"] = image_attachments_type[0] message_data['content_type'] = image_attachments_type[0]
else: else:
message_data['image_attachments'] = None
message_data["image_attachments"] = None
return message_data return message_data
async def is_image(self, attachment: dict) -> bool:
async def is_image(self,attachment:dict) -> bool:
"""判断是否为图片附件""" """判断是否为图片附件"""
content_type = attachment.get("content_type","") content_type = attachment.get('content_type', '')
return content_type.startswith("image/") return content_type.startswith('image/')
async def send_private_text_msg(self, user_openid: str, content: str, msg_id: str):
async def send_private_text_msg(self,user_openid:str,content:str,msg_id:str):
"""发送私聊消息""" """发送私聊消息"""
if not await self.check_access_token(): if not await self.check_access_token():
await self.get_access_token() await self.get_access_token()
url = self.base_url + "/v2/users/" + user_openid + "/messages" url = self.base_url + '/v2/users/' + user_openid + '/messages'
async with httpx.AsyncClient() as client: async with httpx.AsyncClient() as client:
headers = { headers = {
"Authorization": f"QQBot {self.access_token}", 'Authorization': f'QQBot {self.access_token}',
"Content-Type": "application/json", 'Content-Type': 'application/json',
} }
data = { data = {
"content": content, 'content': content,
"msg_type": 0, 'msg_type': 0,
"msg_id": msg_id, 'msg_id': msg_id,
} }
response = await client.post(url,headers=headers,json=data) response = await client.post(url, headers=headers, json=data)
if response.status_code == 200: if response.status_code == 200:
return return
else: else:
raise ValueError(response) raise ValueError(response)
async def send_group_text_msg(self, group_openid: str, content: str, msg_id: str):
async def send_group_text_msg(self,group_openid:str,content:str,msg_id:str):
"""发送群聊消息""" """发送群聊消息"""
if not await self.check_access_token(): if not await self.check_access_token():
await self.get_access_token() await self.get_access_token()
url = self.base_url + "/v2/groups/" + group_openid + "/messages" url = self.base_url + '/v2/groups/' + group_openid + '/messages'
async with httpx.AsyncClient() as client: async with httpx.AsyncClient() as client:
headers = { headers = {
"Authorization": f"QQBot {self.access_token}", 'Authorization': f'QQBot {self.access_token}',
"Content-Type": "application/json", 'Content-Type': 'application/json',
} }
data = { data = {
"content": content, 'content': content,
"msg_type": 0, 'msg_type': 0,
"msg_id": msg_id, 'msg_id': msg_id,
} }
response = await client.post(url,headers=headers,json=data) response = await client.post(url, headers=headers, json=data)
if response.status_code == 200: if response.status_code == 200:
return return
else: else:
raise Exception(response.read().decode()) raise Exception(response.read().decode())
async def send_channle_group_text_msg(self,channel_id:str,content:str,msg_id:str): async def send_channle_group_text_msg(
self, channel_id: str, content: str, msg_id: str
):
"""发送频道群聊消息""" """发送频道群聊消息"""
if not await self.check_access_token(): if not await self.check_access_token():
await self.get_access_token() await self.get_access_token()
url = self.base_url + "/channels/" + channel_id + "/messages" url = self.base_url + '/channels/' + channel_id + '/messages'
async with httpx.AsyncClient() as client: async with httpx.AsyncClient() as client:
headers = { headers = {
"Authorization": f"QQBot {self.access_token}", 'Authorization': f'QQBot {self.access_token}',
"Content-Type": "application/json", 'Content-Type': 'application/json',
} }
params = { params = {
"content": content, 'content': content,
"msg_type": 0, 'msg_type': 0,
"msg_id": msg_id, 'msg_id': msg_id,
} }
response = await client.post(url,headers=headers,json=params) response = await client.post(url, headers=headers, json=params)
if response.status_code == 200: if response.status_code == 200:
return True return True
else: else:
raise Exception(response) raise Exception(response)
async def send_channle_private_text_msg(self,guild_id:str,content:str,msg_id:str): async def send_channle_private_text_msg(
self, guild_id: str, content: str, msg_id: str
):
"""发送频道私聊消息""" """发送频道私聊消息"""
if not await self.check_access_token(): if not await self.check_access_token():
await self.get_access_token() await self.get_access_token()
url = self.base_url + "/dms/" + guild_id + "/messages" url = self.base_url + '/dms/' + guild_id + '/messages'
async with httpx.AsyncClient() as client: async with httpx.AsyncClient() as client:
headers = { headers = {
"Authorization": f"QQBot {self.access_token}", 'Authorization': f'QQBot {self.access_token}',
"Content-Type": "application/json", 'Content-Type': 'application/json',
} }
params = { params = {
"content": content, 'content': content,
"msg_type": 0, 'msg_type': 0,
"msg_id": msg_id, 'msg_id': msg_id,
} }
response = await client.post(url,headers=headers,json=params) response = await client.post(url, headers=headers, json=params)
if response.status_code == 200: if response.status_code == 200:
return True return True
else: else:
+18 -20
View File
@@ -1,101 +1,100 @@
from typing import Dict, Any, Optional from typing import Dict, Any, Optional
class QQOfficialEvent(dict): class QQOfficialEvent(dict):
@staticmethod @staticmethod
def from_payload(payload: Dict[str, Any]) -> Optional["QQOfficialEvent"]: def from_payload(payload: Dict[str, Any]) -> Optional['QQOfficialEvent']:
try: try:
event = QQOfficialEvent(payload) event = QQOfficialEvent(payload)
return event return event
except KeyError: except KeyError:
return None return None
@property @property
def t(self) -> str: def t(self) -> str:
""" """
事件类型 事件类型
""" """
return self.get("t", "") return self.get('t', '')
@property @property
def user_openid(self) -> str: def user_openid(self) -> str:
""" """
用户openid 用户openid
""" """
return self.get("user_openid",{}) return self.get('user_openid', {})
@property @property
def timestamp(self) -> str: def timestamp(self) -> str:
""" """
时间戳 时间戳
""" """
return self.get("timestamp",{}) return self.get('timestamp', {})
@property @property
def d_author_id(self) -> str: def d_author_id(self) -> str:
""" """
作者id 作者id
""" """
return self.get("id",{}) return self.get('id', {})
@property @property
def content(self) -> str: def content(self) -> str:
""" """
内容 内容
""" """
return self.get("content",'') return self.get('content', '')
@property @property
def d_id(self) -> str: def d_id(self) -> str:
""" """
d_id d_id
""" """
return self.get("d_id",{}) return self.get('d_id', {})
@property @property
def id(self) -> str: def id(self) -> str:
""" """
消息idmsg_id 消息idmsg_id
""" """
return self.get("id",{}) return self.get('id', {})
@property @property
def channel_id(self) -> str: def channel_id(self) -> str:
""" """
频道id 频道id
""" """
return self.get("channel_id",{}) return self.get('channel_id', {})
@property @property
def username(self) -> str: def username(self) -> str:
""" """
用户名 用户名
""" """
return self.get("username",{}) return self.get('username', {})
@property @property
def guild_id(self) -> str: def guild_id(self) -> str:
""" """
频道id 频道id
""" """
return self.get("guild_id",{}) return self.get('guild_id', {})
@property @property
def member_openid(self) -> str: def member_openid(self) -> str:
""" """
成员openid 成员openid
""" """
return self.get("openid",{}) return self.get('openid', {})
@property @property
def attachments(self) -> str: def attachments(self) -> str:
""" """
附件url 附件url
""" """
url = self.get("image_attachments", "") url = self.get('image_attachments', '')
if url and not url.startswith("https://"): if url and not url.startswith('https://'):
url = "https://" + url url = 'https://' + url
return url return url
@property @property
@@ -103,12 +102,11 @@ class QQOfficialEvent(dict):
""" """
群组id 群组id
""" """
return self.get("group_openid",{}) return self.get('group_openid', {})
@property @property
def content_type(self) -> str: def content_type(self) -> str:
""" """
文件类型 文件类型
""" """
return self.get("content_type","") return self.get('content_type', '')
+19 -14
View File
@@ -1,10 +1,11 @@
#!/usr/bin/env python #!/usr/bin/env python
# -*- encoding:utf-8 -*- # -*- encoding:utf-8 -*-
""" 对企业微信发送给企业后台的消息加解密示例代码. """对企业微信发送给企业后台的消息加解密示例代码.
@copyright: Copyright (c) 1998-2014 Tencent Inc. @copyright: Copyright (c) 1998-2014 Tencent Inc.
""" """
# ------------------------------------------------------------------------ # ------------------------------------------------------------------------
import logging import logging
import base64 import base64
@@ -49,7 +50,7 @@ class SHA1:
sortlist = [token, timestamp, nonce, encrypt] sortlist = [token, timestamp, nonce, encrypt]
sortlist.sort() sortlist.sort()
sha = hashlib.sha1() sha = hashlib.sha1()
sha.update("".join(sortlist).encode()) sha.update(''.join(sortlist).encode())
return ierror.WXBizMsgCrypt_OK, sha.hexdigest() return ierror.WXBizMsgCrypt_OK, sha.hexdigest()
except Exception as e: except Exception as e:
logger = logging.getLogger() logger = logging.getLogger()
@@ -75,7 +76,7 @@ class XMLParse:
""" """
try: try:
xml_tree = ET.fromstring(xmltext) xml_tree = ET.fromstring(xmltext)
encrypt = xml_tree.find("Encrypt") encrypt = xml_tree.find('Encrypt')
return ierror.WXBizMsgCrypt_OK, encrypt.text return ierror.WXBizMsgCrypt_OK, encrypt.text
except Exception as e: except Exception as e:
logger = logging.getLogger() logger = logging.getLogger()
@@ -100,13 +101,13 @@ class XMLParse:
return resp_xml return resp_xml
class PKCS7Encoder(): class PKCS7Encoder:
"""提供基于PKCS7算法的加解密接口""" """提供基于PKCS7算法的加解密接口"""
block_size = 32 block_size = 32
def encode(self, text): def encode(self, text):
""" 对需要加密的明文进行填充补位 """对需要加密的明文进行填充补位
@param text: 需要进行填充补位操作的明文 @param text: 需要进行填充补位操作的明文
@return: 补齐明文字符串 @return: 补齐明文字符串
""" """
@@ -134,7 +135,6 @@ class Prpcrypt(object):
"""提供接收和推送给企业微信消息的加解密接口""" """提供接收和推送给企业微信消息的加解密接口"""
def __init__(self, key): def __init__(self, key):
# self.key = base64.b64decode(key+"=") # self.key = base64.b64decode(key+"=")
self.key = key self.key = key
# 设置加解密模式为AES的CBC模式 # 设置加解密模式为AES的CBC模式
@@ -147,7 +147,12 @@ class Prpcrypt(object):
""" """
# 16位随机字符串添加到明文开头 # 16位随机字符串添加到明文开头
text = text.encode() text = text.encode()
text = self.get_random_str() + struct.pack("I", socket.htonl(len(text))) + text + receiveid.encode() text = (
self.get_random_str()
+ struct.pack('I', socket.htonl(len(text)))
+ text
+ receiveid.encode()
)
# 使用自定义的填充方式对明文进行补位填充 # 使用自定义的填充方式对明文进行补位填充
pkcs7 = PKCS7Encoder() pkcs7 = PKCS7Encoder()
@@ -183,9 +188,9 @@ class Prpcrypt(object):
# plain_text = pkcs7.encode(plain_text) # plain_text = pkcs7.encode(plain_text)
# 去除16位随机字符串 # 去除16位随机字符串
content = plain_text[16:-pad] content = plain_text[16:-pad]
xml_len = socket.ntohl(struct.unpack("I", content[: 4])[0]) xml_len = socket.ntohl(struct.unpack('I', content[:4])[0])
xml_content = content[4: xml_len + 4] xml_content = content[4 : xml_len + 4]
from_receiveid = content[xml_len + 4:] from_receiveid = content[xml_len + 4 :]
except Exception as e: except Exception as e:
logger = logging.getLogger() logger = logging.getLogger()
logger.error(e) logger.error(e)
@@ -196,7 +201,7 @@ class Prpcrypt(object):
return 0, xml_content return 0, xml_content
def get_random_str(self): def get_random_str(self):
""" 随机生成16位字符串 """随机生成16位字符串
@return: 16位字符串 @return: 16位字符串
""" """
return str(random.randint(1000000000000000, 9999999999999999)).encode() return str(random.randint(1000000000000000, 9999999999999999)).encode()
@@ -206,10 +211,10 @@ class WXBizMsgCrypt(object):
# 构造函数 # 构造函数
def __init__(self, sToken, sEncodingAESKey, sReceiveId): def __init__(self, sToken, sEncodingAESKey, sReceiveId):
try: try:
self.key = base64.b64decode(sEncodingAESKey + "=") self.key = base64.b64decode(sEncodingAESKey + '=')
assert len(self.key) == 32 assert len(self.key) == 32
except: except Exception:
throw_exception("[error]: EncodingAESKey unvalid !", FormatException) throw_exception('[error]: EncodingAESKey unvalid !', FormatException)
# return ierror.WXBizMsgCrypt_IllegalAesKey,None # return ierror.WXBizMsgCrypt_IllegalAesKey,None
self.m_sToken = sToken self.m_sToken = sToken
self.m_sReceiveId = sReceiveId self.m_sReceiveId = sReceiveId
+158 -110
View File
@@ -7,15 +7,22 @@ from quart import Quart
import xml.etree.ElementTree as ET import xml.etree.ElementTree as ET
from typing import Callable, Dict, Any from typing import Callable, Dict, Any
from .wecomevent import WecomEvent from .wecomevent import WecomEvent
from pkg.platform.types import events as platform_events, message as platform_message from pkg.platform.types import message as platform_message
import aiofiles import aiofiles
class WecomClient(): class WecomClient:
def __init__(self,corpid:str,secret:str,token:str,EncodingAESKey:str,contacts_secret:str): def __init__(
self,
corpid: str,
secret: str,
token: str,
EncodingAESKey: str,
contacts_secret: str,
):
self.corpid = corpid self.corpid = corpid
self.secret = secret self.secret = secret
self.access_token_for_contacts ='' self.access_token_for_contacts = ''
self.token = token self.token = token
self.aes = EncodingAESKey self.aes = EncodingAESKey
self.base_url = 'https://qyapi.weixin.qq.com/cgi-bin' self.base_url = 'https://qyapi.weixin.qq.com/cgi-bin'
@@ -23,19 +30,26 @@ class WecomClient():
self.secret_for_contacts = contacts_secret self.secret_for_contacts = contacts_secret
self.app = Quart(__name__) self.app = Quart(__name__)
self.wxcpt = WXBizMsgCrypt(self.token, self.aes, self.corpid) self.wxcpt = WXBizMsgCrypt(self.token, self.aes, self.corpid)
self.app.add_url_rule('/callback/command', 'handle_callback', self.handle_callback_request, methods=['GET', 'POST']) self.app.add_url_rule(
'/callback/command',
'handle_callback',
self.handle_callback_request,
methods=['GET', 'POST'],
)
self._message_handlers = { self._message_handlers = {
"example":[], 'example': [],
} }
#access——token操作 # access——token操作
async def check_access_token(self): async def check_access_token(self):
return bool(self.access_token and self.access_token.strip()) return bool(self.access_token and self.access_token.strip())
async def check_access_token_for_contacts(self): async def check_access_token_for_contacts(self):
return bool(self.access_token_for_contacts and self.access_token_for_contacts.strip()) return bool(
self.access_token_for_contacts and self.access_token_for_contacts.strip()
)
async def get_access_token(self,secret): async def get_access_token(self, secret):
url = f'https://qyapi.weixin.qq.com/cgi-bin/gettoken?corpid={self.corpid}&corpsecret={secret}' url = f'https://qyapi.weixin.qq.com/cgi-bin/gettoken?corpid={self.corpid}&corpsecret={secret}'
async with httpx.AsyncClient() as client: async with httpx.AsyncClient() as client:
response = await client.get(url) response = await client.get(url)
@@ -43,146 +57,163 @@ class WecomClient():
if 'access_token' in data: if 'access_token' in data:
return data['access_token'] return data['access_token']
else: else:
raise Exception(f"未获取access token: {data}") raise Exception(f'未获取access token: {data}')
async def get_users(self): async def get_users(self):
if not self.check_access_token_for_contacts(): if not self.check_access_token_for_contacts():
self.access_token_for_contacts = await self.get_access_token(self.secret_for_contacts) self.access_token_for_contacts = await self.get_access_token(
self.secret_for_contacts
)
url = self.base_url+'/user/list_id?access_token='+self.access_token_for_contacts url = (
self.base_url
+ '/user/list_id?access_token='
+ self.access_token_for_contacts
)
async with httpx.AsyncClient() as client: async with httpx.AsyncClient() as client:
params = { params = {
"cursor":"", 'cursor': '',
"limit":10000, 'limit': 10000,
} }
response = await client.post(url,json=params) response = await client.post(url, json=params)
data = response.json() data = response.json()
if data['errcode'] == 0: if data['errcode'] == 0:
dept_users = data['dept_user'] dept_users = data['dept_user']
userid = [] userid = []
for user in dept_users: for user in dept_users:
userid.append(user["userid"]) userid.append(user['userid'])
return userid return userid
else: else:
raise Exception("未获取用户") raise Exception('未获取用户')
async def send_to_all(self,content:str,agent_id:int): async def send_to_all(self, content: str, agent_id: int):
if not self.check_access_token_for_contacts(): if not self.check_access_token_for_contacts():
self.access_token_for_contacts = await self.get_access_token(self.secret_for_contacts) self.access_token_for_contacts = await self.get_access_token(
self.secret_for_contacts
)
url = self.base_url+'/message/send?access_token='+self.access_token_for_contacts url = (
self.base_url
+ '/message/send?access_token='
+ self.access_token_for_contacts
)
user_ids = await self.get_users() user_ids = await self.get_users()
user_ids_string = "|".join(user_ids) user_ids_string = '|'.join(user_ids)
async with httpx.AsyncClient() as client: async with httpx.AsyncClient() as client:
params = { params = {
"touser" : user_ids_string, 'touser': user_ids_string,
"msgtype" : "text", 'msgtype': 'text',
"agentid" : agent_id, 'agentid': agent_id,
"text" : { 'text': {
"content" : content, 'content': content,
}, },
"safe":0, 'safe': 0,
"enable_id_trans": 0, 'enable_id_trans': 0,
"enable_duplicate_check": 0, 'enable_duplicate_check': 0,
"duplicate_check_interval": 1800 'duplicate_check_interval': 1800,
} }
response = await client.post(url,json=params) response = await client.post(url, json=params)
data = response.json() data = response.json()
if data['errcode'] != 0: if data['errcode'] != 0:
raise Exception("Failed to send message: "+str(data)) raise Exception('Failed to send message: ' + str(data))
async def send_image(self,user_id:str,agent_id:int,media_id:str): async def send_image(self, user_id: str, agent_id: int, media_id: str):
if not await self.check_access_token(): if not await self.check_access_token():
self.access_token = await self.get_access_token(self.secret) self.access_token = await self.get_access_token(self.secret)
url = self.base_url+'/media/upload?access_token='+self.access_token url = self.base_url + '/media/upload?access_token=' + self.access_token
async with httpx.AsyncClient() as client: async with httpx.AsyncClient() as client:
params = { params = {
"touser" : user_id, 'touser': user_id,
"toparty" : "", 'toparty': '',
"totag":"", 'totag': '',
"agentid" : agent_id, 'agentid': agent_id,
"msgtype" : "image", 'msgtype': 'image',
"image" : { 'image': {
"media_id" : media_id, 'media_id': media_id,
}, },
"safe":0, 'safe': 0,
"enable_id_trans": 0, 'enable_id_trans': 0,
"enable_duplicate_check": 0, 'enable_duplicate_check': 0,
"duplicate_check_interval": 1800 'duplicate_check_interval': 1800,
} }
try: try:
response = await client.post(url,json=params) response = await client.post(url, json=params)
data = response.json() data = response.json()
except Exception as e: except Exception as e:
raise Exception("Failed to send image: "+str(e)) raise Exception('Failed to send image: ' + str(e))
# 企业微信错误码40014和42001,代表accesstoken问题 # 企业微信错误码40014和42001,代表accesstoken问题
if data['errcode'] == 40014 or data['errcode'] == 42001: if data['errcode'] == 40014 or data['errcode'] == 42001:
self.access_token = await self.get_access_token(self.secret) self.access_token = await self.get_access_token(self.secret)
return await self.send_image(user_id,agent_id,media_id) return await self.send_image(user_id, agent_id, media_id)
if data['errcode'] != 0: if data['errcode'] != 0:
raise Exception("Failed to send image: "+str(data)) raise Exception('Failed to send image: ' + str(data))
async def send_private_msg(self,user_id:str, agent_id:int,content:str): async def send_private_msg(self, user_id: str, agent_id: int, content: str):
if not await self.check_access_token(): if not await self.check_access_token():
self.access_token = await self.get_access_token(self.secret) self.access_token = await self.get_access_token(self.secret)
url = self.base_url+'/message/send?access_token='+self.access_token url = self.base_url + '/message/send?access_token=' + self.access_token
async with httpx.AsyncClient() as client: async with httpx.AsyncClient() as client:
params={ params = {
"touser" : user_id, 'touser': user_id,
"msgtype" : "text", 'msgtype': 'text',
"agentid" : agent_id, 'agentid': agent_id,
"text" : { 'text': {
"content" : content, 'content': content,
}, },
"safe":0, 'safe': 0,
"enable_id_trans": 0, 'enable_id_trans': 0,
"enable_duplicate_check": 0, 'enable_duplicate_check': 0,
"duplicate_check_interval": 1800 'duplicate_check_interval': 1800,
} }
response = await client.post(url,json=params) response = await client.post(url, json=params)
data = response.json() data = response.json()
if data['errcode'] == 40014 or data['errcode'] == 42001: if data['errcode'] == 40014 or data['errcode'] == 42001:
self.access_token = await self.get_access_token(self.secret) self.access_token = await self.get_access_token(self.secret)
return await self.send_private_msg(user_id,agent_id,content) return await self.send_private_msg(user_id, agent_id, content)
if data['errcode'] != 0: if data['errcode'] != 0:
raise Exception("Failed to send message: "+str(data)) raise Exception('Failed to send message: ' + str(data))
async def handle_callback_request(self): async def handle_callback_request(self):
""" """
处理回调请求,包括 GET 验证和 POST 消息接收。 处理回调请求,包括 GET 验证和 POST 消息接收。
""" """
try: try:
msg_signature = request.args.get('msg_signature')
timestamp = request.args.get('timestamp')
nonce = request.args.get('nonce')
msg_signature = request.args.get("msg_signature") if request.method == 'GET':
timestamp = request.args.get("timestamp") echostr = request.args.get('echostr')
nonce = request.args.get("nonce") ret, reply_echo_str = self.wxcpt.VerifyURL(
msg_signature, timestamp, nonce, echostr
if request.method == "GET": )
echostr = request.args.get("echostr")
ret, reply_echo_str = self.wxcpt.VerifyURL(msg_signature, timestamp, nonce, echostr)
if ret != 0: if ret != 0:
raise Exception(f"验证失败,错误码: {ret}") raise Exception(f'验证失败,错误码: {ret}')
return reply_echo_str return reply_echo_str
elif request.method == "POST": elif request.method == 'POST':
encrypt_msg = await request.data encrypt_msg = await request.data
ret, xml_msg = self.wxcpt.DecryptMsg(encrypt_msg, msg_signature, timestamp, nonce) ret, xml_msg = self.wxcpt.DecryptMsg(
encrypt_msg, msg_signature, timestamp, nonce
)
if ret != 0: if ret != 0:
raise Exception(f"消息解密失败,错误码: {ret}") raise Exception(f'消息解密失败,错误码: {ret}')
# 解析消息并处理 # 解析消息并处理
message_data = await self.get_message(xml_msg) message_data = await self.get_message(xml_msg)
if message_data: if message_data:
event = WecomEvent.from_payload(message_data) # 转换为 WecomEvent 对象 event = WecomEvent.from_payload(
message_data
) # 转换为 WecomEvent 对象
if event: if event:
await self._handle_message(event) await self._handle_message(event)
return "success" return 'success'
except Exception as e: except Exception as e:
return f"Error processing request: {str(e)}", 400 return f'Error processing request: {str(e)}', 400
async def run_task(self, host: str, port: int, *args, **kwargs): async def run_task(self, host: str, port: int, *args, **kwargs):
""" """
@@ -194,11 +225,13 @@ class WecomClient():
""" """
注册消息类型处理器。 注册消息类型处理器。
""" """
def decorator(func: Callable[[WecomEvent], None]): def decorator(func: Callable[[WecomEvent], None]):
if msg_type not in self._message_handlers: if msg_type not in self._message_handlers:
self._message_handlers[msg_type] = [] self._message_handlers[msg_type] = []
self._message_handlers[msg_type].append(func) self._message_handlers[msg_type].append(func)
return func return func
return decorator return decorator
async def _handle_message(self, event: WecomEvent): async def _handle_message(self, event: WecomEvent):
@@ -216,17 +249,27 @@ class WecomClient():
""" """
root = ET.fromstring(xml_msg) root = ET.fromstring(xml_msg)
message_data = { message_data = {
"ToUserName": root.find("ToUserName").text, 'ToUserName': root.find('ToUserName').text,
"FromUserName": root.find("FromUserName").text, 'FromUserName': root.find('FromUserName').text,
"CreateTime": int(root.find("CreateTime").text), 'CreateTime': int(root.find('CreateTime').text),
"MsgType": root.find("MsgType").text, 'MsgType': root.find('MsgType').text,
"Content": root.find("Content").text if root.find("Content") is not None else None, 'Content': root.find('Content').text
"MsgId": int(root.find("MsgId").text) if root.find("MsgId") is not None else None, if root.find('Content') is not None
"AgentID": int(root.find("AgentID").text) if root.find("AgentID") is not None else None, else None,
'MsgId': int(root.find('MsgId').text)
if root.find('MsgId') is not None
else None,
'AgentID': int(root.find('AgentID').text)
if root.find('AgentID') is not None
else None,
} }
if message_data["MsgType"] == "image": if message_data['MsgType'] == 'image':
message_data["MediaId"] = root.find("MediaId").text if root.find("MediaId") is not None else None message_data['MediaId'] = (
message_data["PicUrl"] = root.find("PicUrl").text if root.find("PicUrl") is not None else None root.find('MediaId').text if root.find('MediaId') is not None else None
)
message_data['PicUrl'] = (
root.find('PicUrl').text if root.find('PicUrl') is not None else None
)
return message_data return message_data
@@ -236,11 +279,11 @@ class WecomClient():
通过图片的magic numbers判断图片类型 通过图片的magic numbers判断图片类型
""" """
magic_numbers = { magic_numbers = {
b'\xFF\xD8\xFF': 'jpg', b'\xff\xd8\xff': 'jpg',
b'\x89\x50\x4E\x47': 'png', b'\x89\x50\x4e\x47': 'png',
b'\x47\x49\x46': 'gif', b'\x47\x49\x46': 'gif',
b'\x42\x4D': 'bmp', b'\x42\x4d': 'bmp',
b'\x00\x00\x01\x00': 'ico' b'\x00\x00\x01\x00': 'ico',
} }
for magic, ext in magic_numbers.items(): for magic, ext in magic_numbers.items():
@@ -248,7 +291,6 @@ class WecomClient():
return ext return ext
return 'jpg' # 默认返回jpg return 'jpg' # 默认返回jpg
async def upload_to_work(self, image: platform_message.Image): async def upload_to_work(self, image: platform_message.Image):
""" """
获取 media_id 获取 media_id
@@ -256,9 +298,14 @@ class WecomClient():
if not await self.check_access_token(): if not await self.check_access_token():
self.access_token = await self.get_access_token(self.secret) self.access_token = await self.get_access_token(self.secret)
url = self.base_url + '/media/upload?access_token=' + self.access_token + '&type=file' url = (
self.base_url
+ '/media/upload?access_token='
+ self.access_token
+ '&type=file'
)
file_bytes = None file_bytes = None
file_name = "uploaded_file.txt" file_name = 'uploaded_file.txt'
# 获取文件的二进制数据 # 获取文件的二进制数据
if image.path: if image.path:
@@ -277,20 +324,22 @@ class WecomClient():
padded_base64 = base64_data + '=' * padding padded_base64 = base64_data + '=' * padding
file_bytes = base64.b64decode(padded_base64) file_bytes = base64.b64decode(padded_base64)
except binascii.Error as e: except binascii.Error as e:
raise ValueError(f"Invalid base64 string: {str(e)}") raise ValueError(f'Invalid base64 string: {str(e)}')
else: else:
raise ValueError("image对象出错") raise ValueError('image对象出错')
# 设置 multipart/form-data 格式的文件 # 设置 multipart/form-data 格式的文件
boundary = "-------------------------acebdf13572468" boundary = '-------------------------acebdf13572468'
headers = { headers = {'Content-Type': f'multipart/form-data; boundary={boundary}'}
'Content-Type': f'multipart/form-data; boundary={boundary}'
}
body = ( body = (
f"--{boundary}\r\n" (
f"Content-Disposition: form-data; name=\"media\"; filename=\"{file_name}\"; filelength={len(file_bytes)}\r\n" f'--{boundary}\r\n'
f"Content-Type: application/octet-stream\r\n\r\n" f'Content-Disposition: form-data; name="media"; filename="{file_name}"; filelength={len(file_bytes)}\r\n'
).encode('utf-8') + file_bytes + f"\r\n--{boundary}--\r\n".encode('utf-8') f'Content-Type: application/octet-stream\r\n\r\n'
).encode('utf-8')
+ file_bytes
+ f'\r\n--{boundary}--\r\n'.encode('utf-8')
)
# 上传文件 # 上传文件
async with httpx.AsyncClient() as client: async with httpx.AsyncClient() as client:
@@ -300,19 +349,18 @@ class WecomClient():
self.access_token = await self.get_access_token(self.secret) self.access_token = await self.get_access_token(self.secret)
media_id = await self.upload_to_work(image) media_id = await self.upload_to_work(image)
if data.get('errcode', 0) != 0: if data.get('errcode', 0) != 0:
raise Exception("failed to upload file") raise Exception('failed to upload file')
media_id = data.get('media_id') media_id = data.get('media_id')
return media_id return media_id
async def download_image_to_bytes(self,url:str) -> bytes: async def download_image_to_bytes(self, url: str) -> bytes:
async with httpx.AsyncClient() as client: async with httpx.AsyncClient() as client:
response = await client.get(url) response = await client.get(url)
response.raise_for_status() response.raise_for_status()
return response.content return response.content
#进行media_id的获取 # 进行media_id的获取
async def get_media_id(self, image: platform_message.Image): async def get_media_id(self, image: platform_message.Image):
media_id = await self.upload_to_work(image=image) media_id = await self.upload_to_work(image=image)
return media_id return media_id
+15 -15
View File
@@ -9,7 +9,7 @@ class WecomEvent(dict):
""" """
@staticmethod @staticmethod
def from_payload(payload: Dict[str, Any]) -> Optional["WecomEvent"]: def from_payload(payload: Dict[str, Any]) -> Optional['WecomEvent']:
""" """
从企业微信事件数据构造 `WecomEvent` 对象 从企业微信事件数据构造 `WecomEvent` 对象
@@ -34,14 +34,14 @@ class WecomEvent(dict):
Returns: Returns:
str: 事件类型 str: 事件类型
""" """
return self.get("MsgType", "") return self.get('MsgType', '')
@property @property
def picurl(self) -> str: def picurl(self) -> str:
""" """
图片链接 图片链接
""" """
return self.get("PicUrl") return self.get('PicUrl')
@property @property
def detail_type(self) -> str: def detail_type(self) -> str:
@@ -53,8 +53,8 @@ class WecomEvent(dict):
Returns: Returns:
str: 事件详细类型 str: 事件详细类型
""" """
if self.type == "event": if self.type == 'event':
return self.get("Event", "") return self.get('Event', '')
return self.type return self.type
@property @property
@@ -65,7 +65,7 @@ class WecomEvent(dict):
Returns: Returns:
str: 事件名 str: 事件名
""" """
return f"{self.type}.{self.detail_type}" return f'{self.type}.{self.detail_type}'
@property @property
def user_id(self) -> Optional[str]: def user_id(self) -> Optional[str]:
@@ -75,7 +75,7 @@ class WecomEvent(dict):
Returns: Returns:
Optional[str]: 用户 ID Optional[str]: 用户 ID
""" """
return self.get("FromUserName") return self.get('FromUserName')
@property @property
def agent_id(self) -> Optional[int]: def agent_id(self) -> Optional[int]:
@@ -85,7 +85,7 @@ class WecomEvent(dict):
Returns: Returns:
Optional[int]: 机器人 ID Optional[int]: 机器人 ID
""" """
return self.get("AgentID") return self.get('AgentID')
@property @property
def receiver_id(self) -> Optional[str]: def receiver_id(self) -> Optional[str]:
@@ -95,7 +95,7 @@ class WecomEvent(dict):
Returns: Returns:
Optional[str]: 接收者 ID Optional[str]: 接收者 ID
""" """
return self.get("ToUserName") return self.get('ToUserName')
@property @property
def message_id(self) -> Optional[str]: def message_id(self) -> Optional[str]:
@@ -105,7 +105,7 @@ class WecomEvent(dict):
Returns: Returns:
Optional[str]: 消息 ID Optional[str]: 消息 ID
""" """
return self.get("MsgId") return self.get('MsgId')
@property @property
def message(self) -> Optional[str]: def message(self) -> Optional[str]:
@@ -115,7 +115,7 @@ class WecomEvent(dict):
Returns: Returns:
Optional[str]: 消息内容 Optional[str]: 消息内容
""" """
return self.get("Content") return self.get('Content')
@property @property
def media_id(self) -> Optional[str]: def media_id(self) -> Optional[str]:
@@ -125,7 +125,7 @@ class WecomEvent(dict):
Returns: Returns:
Optional[str]: 媒体文件 ID Optional[str]: 媒体文件 ID
""" """
return self.get("MediaId") return self.get('MediaId')
@property @property
def timestamp(self) -> Optional[int]: def timestamp(self) -> Optional[int]:
@@ -135,7 +135,7 @@ class WecomEvent(dict):
Returns: Returns:
Optional[int]: 时间戳 Optional[int]: 时间戳
""" """
return self.get("CreateTime") return self.get('CreateTime')
@property @property
def event_key(self) -> Optional[str]: def event_key(self) -> Optional[str]:
@@ -145,7 +145,7 @@ class WecomEvent(dict):
Returns: Returns:
Optional[str]: 事件 Key Optional[str]: 事件 Key
""" """
return self.get("EventKey") return self.get('EventKey')
def __getattr__(self, key: str) -> Optional[Any]: def __getattr__(self, key: str) -> Optional[Any]:
""" """
@@ -176,4 +176,4 @@ class WecomEvent(dict):
Returns: Returns:
str: 字符串表示 str: 字符串表示
""" """
return f"<WecomEvent {super().__repr__()}>" return f'<WecomEvent {super().__repr__()}>'
+14 -13
View File
@@ -1,3 +1,4 @@
import asyncio
# LangBot 终端启动入口 # LangBot 终端启动入口
# 在此层级解决依赖项检查。 # 在此层级解决依赖项检查。
# LangBot/main.py # LangBot/main.py
@@ -14,9 +15,6 @@ asciiart = r"""
""" """
import asyncio
async def main_entry(loop: asyncio.AbstractEventLoop): async def main_entry(loop: asyncio.AbstractEventLoop):
print(asciiart) print(asciiart)
@@ -29,11 +27,11 @@ async def main_entry(loop: asyncio.AbstractEventLoop):
missing_deps = await deps.check_deps() missing_deps = await deps.check_deps()
if missing_deps: if missing_deps:
print("以下依赖包未安装,将自动安装,请完成后重启程序:") print('以下依赖包未安装,将自动安装,请完成后重启程序:')
for dep in missing_deps: for dep in missing_deps:
print("-", dep) print('-', dep)
await deps.install_deps(missing_deps) await deps.install_deps(missing_deps)
print("已自动安装缺失的依赖包,请重启程序。") print('已自动安装缺失的依赖包,请重启程序。')
sys.exit(0) sys.exit(0)
# check plugin deps # check plugin deps
@@ -41,8 +39,10 @@ async def main_entry(loop: asyncio.AbstractEventLoop):
# 检查pydantic版本,如果没有 pydantic.v1,则把 pydantic 映射为 v1 # 检查pydantic版本,如果没有 pydantic.v1,则把 pydantic 映射为 v1
import pydantic.version import pydantic.version
if pydantic.version.VERSION < '2.0': if pydantic.version.VERSION < '2.0':
import pydantic import pydantic
sys.modules['pydantic.v1'] = pydantic sys.modules['pydantic.v1'] = pydantic
# 检查配置文件 # 检查配置文件
@@ -52,11 +52,12 @@ async def main_entry(loop: asyncio.AbstractEventLoop):
generated_files = await files.generate_files() generated_files = await files.generate_files()
if generated_files: if generated_files:
print("以下文件不存在,已自动生成:") print('以下文件不存在,已自动生成:')
for file in generated_files: for file in generated_files:
print("-", file) print('-', file)
from pkg.core import boot from pkg.core import boot
await boot.main(loop) await boot.main(loop)
@@ -66,8 +67,8 @@ if __name__ == '__main__':
# 必须大于 3.10.1 # 必须大于 3.10.1
if sys.version_info < (3, 10, 1): if sys.version_info < (3, 10, 1):
print("需要 Python 3.10.1 及以上版本,当前 Python 版本为:", sys.version) print('需要 Python 3.10.1 及以上版本,当前 Python 版本为:', sys.version)
input("按任意键退出...") input('按任意键退出...')
exit(1) exit(1)
# 检查本目录是否有main.py,且包含LangBot字符串 # 检查本目录是否有main.py,且包含LangBot字符串
@@ -78,11 +79,11 @@ if __name__ == '__main__':
else: else:
with open('main.py', 'r', encoding='utf-8') as f: with open('main.py', 'r', encoding='utf-8') as f:
content = f.read() content = f.read()
if "LangBot/main.py" not in content: if 'LangBot/main.py' not in content:
invalid_pwd = True invalid_pwd = True
if invalid_pwd: if invalid_pwd:
print("请在 LangBot 项目根目录下以命令形式运行此程序。") print('请在 LangBot 项目根目录下以命令形式运行此程序。')
input("按任意键退出...") input('按任意键退出...')
exit(1) exit(1)
loop = asyncio.new_event_loop() loop = asyncio.new_event_loop()
+26 -14
View File
@@ -13,6 +13,7 @@ from ....core import app
preregistered_groups: list[type[RouterGroup]] = [] preregistered_groups: list[type[RouterGroup]] = []
"""RouterGroup 的预注册列表""" """RouterGroup 的预注册列表"""
def group_class(name: str, path: str) -> None: def group_class(name: str, path: str) -> None:
"""注册一个 RouterGroup""" """注册一个 RouterGroup"""
@@ -27,12 +28,12 @@ def group_class(name: str, path: str) -> None:
class AuthType(enum.Enum): class AuthType(enum.Enum):
"""认证类型""" """认证类型"""
NONE = 'none' NONE = 'none'
USER_TOKEN = 'user-token' USER_TOKEN = 'user-token'
class RouterGroup(abc.ABC): class RouterGroup(abc.ABC):
name: str name: str
path: str path: str
@@ -49,17 +50,24 @@ class RouterGroup(abc.ABC):
async def initialize(self) -> None: async def initialize(self) -> None:
pass pass
def route(self, rule: str, auth_type: AuthType = AuthType.USER_TOKEN, **options: typing.Any) -> typing.Callable[[RouteCallable], RouteCallable]: # decorator def route(
self,
rule: str,
auth_type: AuthType = AuthType.USER_TOKEN,
**options: typing.Any,
) -> typing.Callable[[RouteCallable], RouteCallable]: # decorator
"""注册一个路由""" """注册一个路由"""
def decorator(f: RouteCallable) -> RouteCallable: def decorator(f: RouteCallable) -> RouteCallable:
nonlocal rule nonlocal rule
rule = self.path + rule rule = self.path + rule
async def handler_error(*args, **kwargs): async def handler_error(*args, **kwargs):
if auth_type == AuthType.USER_TOKEN: if auth_type == AuthType.USER_TOKEN:
# 从Authorization头中获取token # 从Authorization头中获取token
token = quart.request.headers.get('Authorization', '').replace('Bearer ', '') token = quart.request.headers.get('Authorization', '').replace(
'Bearer ', ''
)
if not token: if not token:
return self.http_status(401, -1, '未提供有效的用户令牌') return self.http_status(401, -1, '未提供有效的用户令牌')
@@ -75,7 +83,7 @@ class RouterGroup(abc.ABC):
try: try:
return await f(*args, **kwargs) return await f(*args, **kwargs)
except Exception as e: # 自动 500 except Exception: # 自动 500
traceback.print_exc() traceback.print_exc()
# return self.http_status(500, -2, str(e)) # return self.http_status(500, -2, str(e))
return self.http_status(500, -2, 'internal server error') return self.http_status(500, -2, 'internal server error')
@@ -91,19 +99,23 @@ class RouterGroup(abc.ABC):
def success(self, data: typing.Any = None) -> quart.Response: def success(self, data: typing.Any = None) -> quart.Response:
"""返回一个 200 响应""" """返回一个 200 响应"""
return quart.jsonify({ return quart.jsonify(
'code': 0, {
'msg': 'ok', 'code': 0,
'data': data, 'msg': 'ok',
}) 'data': data,
}
)
def fail(self, code: int, msg: str) -> quart.Response: def fail(self, code: int, msg: str) -> quart.Response:
"""返回一个异常响应""" """返回一个异常响应"""
return quart.jsonify({ return quart.jsonify(
'code': code, {
'msg': msg, 'code': code,
}) 'msg': msg,
}
)
def http_status(self, status: int, code: int, msg: str) -> quart.Response: def http_status(self, status: int, code: int, msg: str) -> quart.Response:
"""返回一个指定状态码的响应""" """返回一个指定状态码的响应"""
+7 -10
View File
@@ -1,32 +1,29 @@
from __future__ import annotations from __future__ import annotations
import traceback
import quart import quart
from .....core import app
from .. import group from .. import group
@group.group_class('logs', '/api/v1/logs') @group.group_class('logs', '/api/v1/logs')
class LogsRouterGroup(group.RouterGroup): class LogsRouterGroup(group.RouterGroup):
async def initialize(self) -> None: async def initialize(self) -> None:
@self.route('', methods=['GET'], auth_type=group.AuthType.USER_TOKEN) @self.route('', methods=['GET'], auth_type=group.AuthType.USER_TOKEN)
async def _() -> str: async def _() -> str:
start_page_number = int(quart.request.args.get('start_page_number', 0)) start_page_number = int(quart.request.args.get('start_page_number', 0))
start_offset = int(quart.request.args.get('start_offset', 0)) start_offset = int(quart.request.args.get('start_offset', 0))
logs_str, end_page_number, end_offset = self.ap.log_cache.get_log_by_pointer( logs_str, end_page_number, end_offset = (
start_page_number=start_page_number, self.ap.log_cache.get_log_by_pointer(
start_offset=start_offset start_page_number=start_page_number, start_offset=start_offset
)
) )
return self.success( return self.success(
data={ data={
"logs": logs_str, 'logs': logs_str,
"end_page_number": end_page_number, 'end_page_number': end_page_number,
"end_offset": end_offset 'end_offset': end_offset,
} }
) )
+11 -17
View File
@@ -3,34 +3,31 @@ from __future__ import annotations
import quart import quart
from .. import group from .. import group
from .....entity.persistence import pipeline
@group.group_class('pipelines', '/api/v1/pipelines') @group.group_class('pipelines', '/api/v1/pipelines')
class PipelinesRouterGroup(group.RouterGroup): class PipelinesRouterGroup(group.RouterGroup):
async def initialize(self) -> None: async def initialize(self) -> None:
@self.route('', methods=['GET', 'POST']) @self.route('', methods=['GET', 'POST'])
async def _() -> str: async def _() -> str:
if quart.request.method == 'GET': if quart.request.method == 'GET':
return self.success(data={ return self.success(
'pipelines': await self.ap.pipeline_service.get_pipelines() data={'pipelines': await self.ap.pipeline_service.get_pipelines()}
}) )
elif quart.request.method == 'POST': elif quart.request.method == 'POST':
json_data = await quart.request.json json_data = await quart.request.json
pipeline_uuid = await self.ap.pipeline_service.create_pipeline(json_data) pipeline_uuid = await self.ap.pipeline_service.create_pipeline(
json_data
)
return self.success(data={ return self.success(data={'uuid': pipeline_uuid})
'uuid': pipeline_uuid
})
@self.route('/_/metadata', methods=['GET']) @self.route('/_/metadata', methods=['GET'])
async def _() -> str: async def _() -> str:
return self.success(data={ return self.success(
'configs': await self.ap.pipeline_service.get_pipeline_metadata() data={'configs': await self.ap.pipeline_service.get_pipeline_metadata()}
}) )
@self.route('/<pipeline_uuid>', methods=['GET', 'PUT', 'DELETE']) @self.route('/<pipeline_uuid>', methods=['GET', 'PUT', 'DELETE'])
async def _(pipeline_uuid: str) -> str: async def _(pipeline_uuid: str) -> str:
@@ -40,9 +37,7 @@ class PipelinesRouterGroup(group.RouterGroup):
if pipeline is None: if pipeline is None:
return self.http_status(404, -1, 'pipeline not found') return self.http_status(404, -1, 'pipeline not found')
return self.success(data={ return self.success(data={'pipeline': pipeline})
'pipeline': pipeline
})
elif quart.request.method == 'PUT': elif quart.request.method == 'PUT':
json_data = await quart.request.json json_data = await quart.request.json
@@ -53,4 +48,3 @@ class PipelinesRouterGroup(group.RouterGroup):
await self.ap.pipeline_service.delete_pipeline(pipeline_uuid) await self.ap.pipeline_service.delete_pipeline(pipeline_uuid)
return self.success() return self.success()
@@ -5,29 +5,31 @@ from ... import group
@group.group_class('adapters', '/api/v1/platform/adapters') @group.group_class('adapters', '/api/v1/platform/adapters')
class AdaptersRouterGroup(group.RouterGroup): class AdaptersRouterGroup(group.RouterGroup):
async def initialize(self) -> None: async def initialize(self) -> None:
@self.route('', methods=['GET']) @self.route('', methods=['GET'])
async def _() -> str: async def _() -> str:
return self.success(data={ return self.success(
'adapters': self.ap.platform_mgr.get_available_adapters_info() data={'adapters': self.ap.platform_mgr.get_available_adapters_info()}
}) )
@self.route('/<adapter_name>', methods=['GET']) @self.route('/<adapter_name>', methods=['GET'])
async def _(adapter_name: str) -> str: async def _(adapter_name: str) -> str:
adapter_info = self.ap.platform_mgr.get_available_adapter_info_by_name(adapter_name) adapter_info = self.ap.platform_mgr.get_available_adapter_info_by_name(
adapter_name
)
if adapter_info is None: if adapter_info is None:
return self.http_status(404, -1, 'adapter not found') return self.http_status(404, -1, 'adapter not found')
return self.success(data={ return self.success(data={'adapter': adapter_info})
'adapter': adapter_info
})
@self.route('/<adapter_name>/icon', methods=['GET']) @self.route('/<adapter_name>/icon', methods=['GET'])
async def _(adapter_name: str) -> quart.Response: async def _(adapter_name: str) -> quart.Response:
adapter_manifest = (
adapter_manifest = self.ap.platform_mgr.get_available_adapter_manifest_by_name(adapter_name) self.ap.platform_mgr.get_available_adapter_manifest_by_name(
adapter_name
)
)
if adapter_manifest is None: if adapter_manifest is None:
return self.http_status(404, -1, 'adapter not found') return self.http_status(404, -1, 'adapter not found')
@@ -5,20 +5,15 @@ from ... import group
@group.group_class('bots', '/api/v1/platform/bots') @group.group_class('bots', '/api/v1/platform/bots')
class BotsRouterGroup(group.RouterGroup): class BotsRouterGroup(group.RouterGroup):
async def initialize(self) -> None: async def initialize(self) -> None:
@self.route('', methods=['GET', 'POST']) @self.route('', methods=['GET', 'POST'])
async def _() -> str: async def _() -> str:
if quart.request.method == 'GET': if quart.request.method == 'GET':
return self.success(data={ return self.success(data={'bots': await self.ap.bot_service.get_bots()})
'bots': await self.ap.bot_service.get_bots()
})
elif quart.request.method == 'POST': elif quart.request.method == 'POST':
json_data = await quart.request.json json_data = await quart.request.json
bot_uuid = await self.ap.bot_service.create_bot(json_data) bot_uuid = await self.ap.bot_service.create_bot(json_data)
return self.success(data={ return self.success(data={'uuid': bot_uuid})
'uuid': bot_uuid
})
@self.route('/<bot_uuid>', methods=['GET', 'PUT', 'DELETE']) @self.route('/<bot_uuid>', methods=['GET', 'PUT', 'DELETE'])
async def _(bot_uuid: str) -> str: async def _(bot_uuid: str) -> str:
@@ -26,9 +21,7 @@ class BotsRouterGroup(group.RouterGroup):
bot = await self.ap.bot_service.get_bot(bot_uuid) bot = await self.ap.bot_service.get_bot(bot_uuid)
if bot is None: if bot is None:
return self.http_status(404, -1, 'bot not found') return self.http_status(404, -1, 'bot not found')
return self.success(data={ return self.success(data={'bot': bot})
'bot': bot
})
elif quart.request.method == 'PUT': elif quart.request.method == 'PUT':
json_data = await quart.request.json json_data = await quart.request.json
await self.ap.bot_service.update_bot(bot_uuid, json_data) await self.ap.bot_service.update_bot(bot_uuid, json_data)
+39 -36
View File
@@ -1,17 +1,14 @@
from __future__ import annotations from __future__ import annotations
import traceback
import quart import quart
from .....core import app, taskmgr from .....core import taskmgr
from .. import group from .. import group
@group.group_class('plugins', '/api/v1/plugins') @group.group_class('plugins', '/api/v1/plugins')
class PluginsRouterGroup(group.RouterGroup): class PluginsRouterGroup(group.RouterGroup):
async def initialize(self) -> None: async def initialize(self) -> None:
@self.route('', methods=['GET'], auth_type=group.AuthType.USER_TOKEN) @self.route('', methods=['GET'], auth_type=group.AuthType.USER_TOKEN)
async def _() -> str: async def _() -> str:
@@ -19,63 +16,69 @@ class PluginsRouterGroup(group.RouterGroup):
plugins_data = [plugin.model_dump() for plugin in plugins] plugins_data = [plugin.model_dump() for plugin in plugins]
return self.success(data={ return self.success(data={'plugins': plugins_data})
'plugins': plugins_data
})
@self.route('/<author>/<plugin_name>/toggle', methods=['PUT'], auth_type=group.AuthType.USER_TOKEN) @self.route(
'/<author>/<plugin_name>/toggle',
methods=['PUT'],
auth_type=group.AuthType.USER_TOKEN,
)
async def _(author: str, plugin_name: str) -> str: async def _(author: str, plugin_name: str) -> str:
data = await quart.request.json data = await quart.request.json
target_enabled = data.get('target_enabled') target_enabled = data.get('target_enabled')
await self.ap.plugin_mgr.update_plugin_switch(plugin_name, target_enabled) await self.ap.plugin_mgr.update_plugin_switch(plugin_name, target_enabled)
return self.success() return self.success()
@self.route('/<author>/<plugin_name>/update', methods=['POST'], auth_type=group.AuthType.USER_TOKEN) @self.route(
'/<author>/<plugin_name>/update',
methods=['POST'],
auth_type=group.AuthType.USER_TOKEN,
)
async def _(author: str, plugin_name: str) -> str: async def _(author: str, plugin_name: str) -> str:
ctx = taskmgr.TaskContext.new() ctx = taskmgr.TaskContext.new()
wrapper = self.ap.task_mgr.create_user_task( wrapper = self.ap.task_mgr.create_user_task(
self.ap.plugin_mgr.update_plugin(plugin_name, task_context=ctx), self.ap.plugin_mgr.update_plugin(plugin_name, task_context=ctx),
kind="plugin-operation", kind='plugin-operation',
name=f"plugin-update-{plugin_name}", name=f'plugin-update-{plugin_name}',
label=f"更新插件 {plugin_name}", label=f'更新插件 {plugin_name}',
context=ctx context=ctx,
) )
return self.success(data={ return self.success(data={'task_id': wrapper.id})
'task_id': wrapper.id
})
@self.route('/<author>/<plugin_name>', methods=['GET', 'DELETE'], auth_type=group.AuthType.USER_TOKEN) @self.route(
'/<author>/<plugin_name>',
methods=['GET', 'DELETE'],
auth_type=group.AuthType.USER_TOKEN,
)
async def _(author: str, plugin_name: str) -> str: async def _(author: str, plugin_name: str) -> str:
if quart.request.method == 'GET': if quart.request.method == 'GET':
plugin = self.ap.plugin_mgr.get_plugin(author, plugin_name) plugin = self.ap.plugin_mgr.get_plugin(author, plugin_name)
if plugin is None: if plugin is None:
return self.http_status(404, -1, 'plugin not found') return self.http_status(404, -1, 'plugin not found')
return self.success(data={ return self.success(data={'plugin': plugin.model_dump()})
'plugin': plugin.model_dump()
})
elif quart.request.method == 'DELETE': elif quart.request.method == 'DELETE':
ctx = taskmgr.TaskContext.new() ctx = taskmgr.TaskContext.new()
wrapper = self.ap.task_mgr.create_user_task( wrapper = self.ap.task_mgr.create_user_task(
self.ap.plugin_mgr.uninstall_plugin(plugin_name, task_context=ctx), self.ap.plugin_mgr.uninstall_plugin(plugin_name, task_context=ctx),
kind="plugin-operation", kind='plugin-operation',
name=f'plugin-remove-{plugin_name}', name=f'plugin-remove-{plugin_name}',
label=f'删除插件 {plugin_name}', label=f'删除插件 {plugin_name}',
context=ctx context=ctx,
) )
return self.success(data={ return self.success(data={'task_id': wrapper.id})
'task_id': wrapper.id
})
@self.route('/<author>/<plugin_name>/config', methods=['GET', 'PUT'], auth_type=group.AuthType.USER_TOKEN) @self.route(
'/<author>/<plugin_name>/config',
methods=['GET', 'PUT'],
auth_type=group.AuthType.USER_TOKEN,
)
async def _(author: str, plugin_name: str) -> quart.Response: async def _(author: str, plugin_name: str) -> quart.Response:
plugin = self.ap.plugin_mgr.get_plugin(author, plugin_name) plugin = self.ap.plugin_mgr.get_plugin(author, plugin_name)
if plugin is None: if plugin is None:
return self.http_status(404, -1, 'plugin not found') return self.http_status(404, -1, 'plugin not found')
if quart.request.method == 'GET': if quart.request.method == 'GET':
return self.success(data={ return self.success(data={'config': plugin.plugin_config})
'config': plugin.plugin_config
})
elif quart.request.method == 'PUT': elif quart.request.method == 'PUT':
data = await quart.request.json data = await quart.request.json
@@ -89,7 +92,9 @@ class PluginsRouterGroup(group.RouterGroup):
await self.ap.plugin_mgr.reorder_plugins(data.get('plugins')) await self.ap.plugin_mgr.reorder_plugins(data.get('plugins'))
return self.success() return self.success()
@self.route('/install/github', methods=['POST'], auth_type=group.AuthType.USER_TOKEN) @self.route(
'/install/github', methods=['POST'], auth_type=group.AuthType.USER_TOKEN
)
async def _() -> str: async def _() -> str:
data = await quart.request.json data = await quart.request.json
@@ -97,12 +102,10 @@ class PluginsRouterGroup(group.RouterGroup):
short_source_str = data['source'][-8:] short_source_str = data['source'][-8:]
wrapper = self.ap.task_mgr.create_user_task( wrapper = self.ap.task_mgr.create_user_task(
self.ap.plugin_mgr.install_plugin(data['source'], task_context=ctx), self.ap.plugin_mgr.install_plugin(data['source'], task_context=ctx),
kind="plugin-operation", kind='plugin-operation',
name=f'plugin-install-github', name='plugin-install-github',
label=f'安装插件 ...{short_source_str}', label=f'安装插件 ...{short_source_str}',
context=ctx context=ctx,
) )
return self.success(data={ return self.success(data={'task_id': wrapper.id})
'task_id': wrapper.id
})
@@ -1,28 +1,23 @@
import quart import quart
import uuid
from ... import group from ... import group
from ......entity.persistence import model
@group.group_class('models/llm', '/api/v1/provider/models/llm') @group.group_class('models/llm', '/api/v1/provider/models/llm')
class LLMModelsRouterGroup(group.RouterGroup): class LLMModelsRouterGroup(group.RouterGroup):
async def initialize(self) -> None: async def initialize(self) -> None:
@self.route('', methods=['GET', 'POST']) @self.route('', methods=['GET', 'POST'])
async def _() -> str: async def _() -> str:
if quart.request.method == 'GET': if quart.request.method == 'GET':
return self.success(data={ return self.success(
'models': await self.ap.model_service.get_llm_models() data={'models': await self.ap.model_service.get_llm_models()}
}) )
elif quart.request.method == 'POST': elif quart.request.method == 'POST':
json_data = await quart.request.json json_data = await quart.request.json
model_uuid = await self.ap.model_service.create_llm_model(json_data) model_uuid = await self.ap.model_service.create_llm_model(json_data)
return self.success(data={ return self.success(data={'uuid': model_uuid})
'uuid': model_uuid
})
@self.route('/<model_uuid>', methods=['GET', 'DELETE']) @self.route('/<model_uuid>', methods=['GET', 'DELETE'])
async def _(model_uuid: str) -> str: async def _(model_uuid: str) -> str:
@@ -32,9 +27,7 @@ class LLMModelsRouterGroup(group.RouterGroup):
if model is None: if model is None:
return self.http_status(404, -1, 'model not found') return self.http_status(404, -1, 'model not found')
return self.success(data={ return self.success(data={'model': model})
'model': model
})
# elif quart.request.method == 'PUT': # elif quart.request.method == 'PUT':
# json_data = await quart.request.json # json_data = await quart.request.json
@@ -5,29 +5,31 @@ from ... import group
@group.group_class('provider/requesters', '/api/v1/provider/requesters') @group.group_class('provider/requesters', '/api/v1/provider/requesters')
class RequestersRouterGroup(group.RouterGroup): class RequestersRouterGroup(group.RouterGroup):
async def initialize(self) -> None: async def initialize(self) -> None:
@self.route('', methods=['GET']) @self.route('', methods=['GET'])
async def _() -> quart.Response: async def _() -> quart.Response:
return self.success(data={ return self.success(
'requesters': self.ap.model_mgr.get_available_requesters_info() data={'requesters': self.ap.model_mgr.get_available_requesters_info()}
}) )
@self.route('/<requester_name>', methods=['GET']) @self.route('/<requester_name>', methods=['GET'])
async def _(requester_name: str) -> quart.Response: async def _(requester_name: str) -> quart.Response:
requester_info = self.ap.model_mgr.get_available_requester_info_by_name(
requester_info = self.ap.model_mgr.get_available_requester_info_by_name(requester_name) requester_name
)
if requester_info is None: if requester_info is None:
return self.http_status(404, -1, 'requester not found') return self.http_status(404, -1, 'requester not found')
return self.success(data={ return self.success(data={'requester': requester_info})
'requester': requester_info
})
@self.route('/<requester_name>/icon', methods=['GET']) @self.route('/<requester_name>/icon', methods=['GET'])
async def _(requester_name: str) -> quart.Response: async def _(requester_name: str) -> quart.Response:
requester_manifest = self.ap.model_mgr.get_available_requester_manifest_by_name(requester_name) requester_manifest = (
self.ap.model_mgr.get_available_requester_manifest_by_name(
requester_name
)
)
if requester_manifest is None: if requester_manifest is None:
return self.http_status(404, -1, 'requester not found') return self.http_status(404, -1, 'requester not found')
+10 -12
View File
@@ -1,23 +1,21 @@
import quart
import asyncio
from .....core import app, taskmgr
from .. import group from .. import group
@group.group_class('stats', '/api/v1/stats') @group.group_class('stats', '/api/v1/stats')
class StatsRouterGroup(group.RouterGroup): class StatsRouterGroup(group.RouterGroup):
async def initialize(self) -> None: async def initialize(self) -> None:
@self.route('/basic', methods=['GET'], auth_type=group.AuthType.USER_TOKEN) @self.route('/basic', methods=['GET'], auth_type=group.AuthType.USER_TOKEN)
async def _() -> str: async def _() -> str:
conv_count = 0 conv_count = 0
for session in self.ap.sess_mgr.session_list: for session in self.ap.sess_mgr.session_list:
conv_count += len(session.conversations if session.conversations is not None else []) conv_count += len(
session.conversations if session.conversations is not None else []
)
return self.success(data={ return self.success(
'active_session_count': len(self.ap.sess_mgr.session_list), data={
'conversation_count': conv_count, 'active_session_count': len(self.ap.sess_mgr.session_list),
'query_count': self.ap.query_pool.query_id_counter, 'conversation_count': conv_count,
}) 'query_count': self.ap.query_pool.query_id_counter,
}
)
+18 -19
View File
@@ -1,42 +1,41 @@
import quart import quart
import asyncio
from .....core import app, taskmgr
from .. import group from .. import group
from .....utils import constants from .....utils import constants
@group.group_class('system', '/api/v1/system') @group.group_class('system', '/api/v1/system')
class SystemRouterGroup(group.RouterGroup): class SystemRouterGroup(group.RouterGroup):
async def initialize(self) -> None: async def initialize(self) -> None:
@self.route('/info', methods=['GET'], auth_type=group.AuthType.NONE) @self.route('/info', methods=['GET'], auth_type=group.AuthType.NONE)
async def _() -> str: async def _() -> str:
return self.success( return self.success(
data={ data={
"version": constants.semantic_version, 'version': constants.semantic_version,
"debug": constants.debug_mode, 'debug': constants.debug_mode,
"enabled_platform_count": len(self.ap.platform_mgr.get_running_adapters()) 'enabled_platform_count': len(
self.ap.platform_mgr.get_running_adapters()
),
} }
) )
@self.route('/tasks', methods=['GET'], auth_type=group.AuthType.USER_TOKEN) @self.route('/tasks', methods=['GET'], auth_type=group.AuthType.USER_TOKEN)
async def _() -> str: async def _() -> str:
task_type = quart.request.args.get("type") task_type = quart.request.args.get('type')
if task_type == '': if task_type == '':
task_type = None task_type = None
return self.success( return self.success(data=self.ap.task_mgr.get_tasks_dict(task_type))
data=self.ap.task_mgr.get_tasks_dict(task_type)
)
@self.route('/tasks/<task_id>', methods=['GET'], auth_type=group.AuthType.USER_TOKEN) @self.route(
'/tasks/<task_id>', methods=['GET'], auth_type=group.AuthType.USER_TOKEN
)
async def _(task_id: str) -> str: async def _(task_id: str) -> str:
task = self.ap.task_mgr.get_task_by_id(int(task_id)) task = self.ap.task_mgr.get_task_by_id(int(task_id))
if task is None: if task is None:
return self.http_status(404, 404, "Task not found") return self.http_status(404, 404, 'Task not found')
return self.success(data=task.to_dict()) return self.success(data=task.to_dict())
@@ -44,20 +43,20 @@ class SystemRouterGroup(group.RouterGroup):
async def _() -> str: async def _() -> str:
json_data = await quart.request.json json_data = await quart.request.json
scope = json_data.get("scope") scope = json_data.get('scope')
await self.ap.reload( await self.ap.reload(scope=scope)
scope=scope
)
return self.success() return self.success()
@self.route('/_debug/exec', methods=['POST'], auth_type=group.AuthType.USER_TOKEN) @self.route(
'/_debug/exec', methods=['POST'], auth_type=group.AuthType.USER_TOKEN
)
async def _() -> str: async def _() -> str:
if not constants.debug_mode: if not constants.debug_mode:
return self.http_status(403, 403, "Forbidden") return self.http_status(403, 403, 'Forbidden')
py_code = await quart.request.data py_code = await quart.request.data
ap = self.ap ap = self.ap
return self.success(data=exec(py_code, {"ap": ap})) return self.success(data=exec(py_code, {'ap': ap}))
+11 -14
View File
@@ -1,21 +1,18 @@
import quart import quart
import jwt
import argon2 import argon2
from .. import group from .. import group
from .....entity.persistence import user
@group.group_class('user', '/api/v1/user') @group.group_class('user', '/api/v1/user')
class UserRouterGroup(group.RouterGroup): class UserRouterGroup(group.RouterGroup):
async def initialize(self) -> None: async def initialize(self) -> None:
@self.route('/init', methods=['GET', 'POST'], auth_type=group.AuthType.NONE) @self.route('/init', methods=['GET', 'POST'], auth_type=group.AuthType.NONE)
async def _() -> str: async def _() -> str:
if quart.request.method == 'GET': if quart.request.method == 'GET':
return self.success(data={ return self.success(
'initialized': await self.ap.user_service.is_initialized() data={'initialized': await self.ap.user_service.is_initialized()}
}) )
if await self.ap.user_service.is_initialized(): if await self.ap.user_service.is_initialized():
return self.fail(1, '系统已初始化') return self.fail(1, '系统已初始化')
@@ -34,18 +31,18 @@ class UserRouterGroup(group.RouterGroup):
json_data = await quart.request.json json_data = await quart.request.json
try: try:
token = await self.ap.user_service.authenticate(json_data['user'], json_data['password']) token = await self.ap.user_service.authenticate(
json_data['user'], json_data['password']
)
except argon2.exceptions.VerifyMismatchError: except argon2.exceptions.VerifyMismatchError:
return self.fail(1, '用户名或密码错误') return self.fail(1, '用户名或密码错误')
return self.success(data={ return self.success(data={'token': token})
'token': token
})
@self.route('/check-token', methods=['GET'], auth_type=group.AuthType.USER_TOKEN) @self.route(
'/check-token', methods=['GET'], auth_type=group.AuthType.USER_TOKEN
)
async def _(user_email: str) -> str: async def _(user_email: str) -> str:
token = await self.ap.user_service.generate_jwt_token(user_email) token = await self.ap.user_service.generate_jwt_token(user_email)
return self.success(data={ return self.success(data={'token': token})
'token': token
})
+44 -43
View File
@@ -7,15 +7,19 @@ import quart
import quart_cors import quart_cors
from ....core import app, entities as core_entities from ....core import app, entities as core_entities
from ....utils import importutil
from .groups import logs, system, plugins, stats, user, pipelines from . import groups
from .groups.provider import models, requesters
from .groups.platform import bots, adapters
from . import group from . import group
from .groups import provider as groups_provider
from .groups import platform as groups_platform
importutil.import_modules_in_pkg(groups)
importutil.import_modules_in_pkg(groups_provider)
importutil.import_modules_in_pkg(groups_platform)
class HTTPController: class HTTPController:
ap: app.Application ap: app.Application
quart_app: quart.Quart quart_app: quart.Quart
@@ -23,7 +27,7 @@ class HTTPController:
def __init__(self, ap: app.Application) -> None: def __init__(self, ap: app.Application) -> None:
self.ap = ap self.ap = ap
self.quart_app = quart.Quart(__name__) self.quart_app = quart.Quart(__name__)
quart_cors.cors(self.quart_app, allow_origin="*") quart_cors.cors(self.quart_app, allow_origin='*')
async def initialize(self) -> None: async def initialize(self) -> None:
await self.register_routes() await self.register_routes()
@@ -37,11 +41,9 @@ class HTTPController:
async def exception_handler(*args, **kwargs): async def exception_handler(*args, **kwargs):
try: try:
await self.quart_app.run_task( await self.quart_app.run_task(*args, **kwargs)
*args, **kwargs
)
except Exception as e: except Exception as e:
self.ap.logger.error(f"启动 HTTP 服务失败: {e}") self.ap.logger.error(f'启动 HTTP 服务失败: {e}')
self.ap.task_mgr.create_task( self.ap.task_mgr.create_task(
exception_handler( exception_handler(
@@ -49,63 +51,62 @@ class HTTPController:
port=self.ap.instance_config.data['api']['port'], port=self.ap.instance_config.data['api']['port'],
shutdown_trigger=shutdown_trigger_placeholder, shutdown_trigger=shutdown_trigger_placeholder,
), ),
name="http-api-quart", name='http-api-quart',
scopes=[core_entities.LifecycleControlScope.APPLICATION], scopes=[core_entities.LifecycleControlScope.APPLICATION],
) )
# await asyncio.sleep(5) # await asyncio.sleep(5)
async def register_routes(self) -> None: async def register_routes(self) -> None:
@self.quart_app.route('/healthz')
@self.quart_app.route("/healthz")
async def healthz(): async def healthz():
return {"code": 0, "msg": "ok"} return {'code': 0, 'msg': 'ok'}
for g in group.preregistered_groups: for g in group.preregistered_groups:
ginst = g(self.ap, self.quart_app) ginst = g(self.ap, self.quart_app)
await ginst.initialize() await ginst.initialize()
frontend_path = "web/out" frontend_path = 'web/out'
@self.quart_app.route("/") @self.quart_app.route('/')
async def index(): async def index():
return await quart.send_from_directory(frontend_path, "index.html", mimetype="text/html") return await quart.send_from_directory(
frontend_path, 'index.html', mimetype='text/html'
)
@self.quart_app.route("/<path:path>") @self.quart_app.route('/<path:path>')
async def static_file(path: str): async def static_file(path: str):
if not os.path.exists(os.path.join(frontend_path, path)): if not os.path.exists(os.path.join(frontend_path, path)):
if os.path.exists(os.path.join(frontend_path, path+".html")): if os.path.exists(os.path.join(frontend_path, path + '.html')):
path += '.html' path += '.html'
else: else:
return await quart.send_from_directory(frontend_path, '404.html') return await quart.send_from_directory(frontend_path, '404.html')
mimetype = None mimetype = None
if path.endswith(".html"): if path.endswith('.html'):
mimetype = "text/html" mimetype = 'text/html'
elif path.endswith(".js"): elif path.endswith('.js'):
mimetype = "application/javascript" mimetype = 'application/javascript'
elif path.endswith(".css"): elif path.endswith('.css'):
mimetype = "text/css" mimetype = 'text/css'
elif path.endswith(".png"): elif path.endswith('.png'):
mimetype = "image/png" mimetype = 'image/png'
elif path.endswith(".jpg"): elif path.endswith('.jpg'):
mimetype = "image/jpeg" mimetype = 'image/jpeg'
elif path.endswith(".jpeg"): elif path.endswith('.jpeg'):
mimetype = "image/jpeg" mimetype = 'image/jpeg'
elif path.endswith(".gif"): elif path.endswith('.gif'):
mimetype = "image/gif" mimetype = 'image/gif'
elif path.endswith(".svg"): elif path.endswith('.svg'):
mimetype = "image/svg+xml" mimetype = 'image/svg+xml'
elif path.endswith(".ico"): elif path.endswith('.ico'):
mimetype = "image/x-icon" mimetype = 'image/x-icon'
elif path.endswith(".json"): elif path.endswith('.json'):
mimetype = "application/json" mimetype = 'application/json'
elif path.endswith(".txt"): elif path.endswith('.txt'):
mimetype = "text/plain" mimetype = 'text/plain'
return await quart.send_from_directory( return await quart.send_from_directory(
frontend_path, frontend_path, path, mimetype=mimetype
path,
mimetype=mimetype
) )
+17 -9
View File
@@ -1,7 +1,6 @@
from __future__ import annotations from __future__ import annotations
import uuid import uuid
import datetime
import sqlalchemy import sqlalchemy
from ....core import app from ....core import app
@@ -33,7 +32,9 @@ class BotService:
async def get_bot(self, bot_uuid: str) -> dict | None: async def get_bot(self, bot_uuid: str) -> dict | None:
"""获取机器人""" """获取机器人"""
result = await self.ap.persistence_mgr.execute_async( result = await self.ap.persistence_mgr.execute_async(
sqlalchemy.select(persistence_bot.Bot).where(persistence_bot.Bot.uuid == bot_uuid) sqlalchemy.select(persistence_bot.Bot).where(
persistence_bot.Bot.uuid == bot_uuid
)
) )
bot = result.first() bot = result.first()
@@ -50,7 +51,9 @@ class BotService:
# checkout the default pipeline # checkout the default pipeline
result = await self.ap.persistence_mgr.execute_async( result = await self.ap.persistence_mgr.execute_async(
sqlalchemy.select(persistence_pipeline.LegacyPipeline).where(persistence_pipeline.LegacyPipeline.is_default == True) sqlalchemy.select(persistence_pipeline.LegacyPipeline).where(
persistence_pipeline.LegacyPipeline.is_default == True
)
) )
pipeline = result.first() pipeline = result.first()
if pipeline is not None: if pipeline is not None:
@@ -75,16 +78,21 @@ class BotService:
# set use_pipeline_name # set use_pipeline_name
if 'use_pipeline_uuid' in bot_data: if 'use_pipeline_uuid' in bot_data:
result = await self.ap.persistence_mgr.execute_async( result = await self.ap.persistence_mgr.execute_async(
sqlalchemy.select(persistence_pipeline.LegacyPipeline).where(persistence_pipeline.LegacyPipeline.uuid == bot_data['use_pipeline_uuid']) sqlalchemy.select(persistence_pipeline.LegacyPipeline).where(
persistence_pipeline.LegacyPipeline.uuid
== bot_data['use_pipeline_uuid']
)
) )
pipeline = result.first() pipeline = result.first()
if pipeline is not None: if pipeline is not None:
bot_data['use_pipeline_name'] = pipeline.name bot_data['use_pipeline_name'] = pipeline.name
else: else:
raise Exception("Pipeline not found") raise Exception('Pipeline not found')
await self.ap.persistence_mgr.execute_async( await self.ap.persistence_mgr.execute_async(
sqlalchemy.update(persistence_bot.Bot).values(bot_data).where(persistence_bot.Bot.uuid == bot_uuid) sqlalchemy.update(persistence_bot.Bot)
.values(bot_data)
.where(persistence_bot.Bot.uuid == bot_uuid)
) )
await self.ap.platform_mgr.remove_bot(bot_uuid) await self.ap.platform_mgr.remove_bot(bot_uuid)
@@ -100,7 +108,7 @@ class BotService:
"""删除机器人""" """删除机器人"""
await self.ap.platform_mgr.remove_bot(bot_uuid) await self.ap.platform_mgr.remove_bot(bot_uuid)
await self.ap.persistence_mgr.execute_async( await self.ap.persistence_mgr.execute_async(
sqlalchemy.delete(persistence_bot.Bot).where(persistence_bot.Bot.uuid == bot_uuid) sqlalchemy.delete(persistence_bot.Bot).where(
persistence_bot.Bot.uuid == bot_uuid
)
) )
+17 -14
View File
@@ -1,7 +1,6 @@
from __future__ import annotations from __future__ import annotations
import uuid import uuid
import datetime
import sqlalchemy import sqlalchemy
from ....core import app from ....core import app
@@ -10,7 +9,6 @@ from ....entity.persistence import pipeline as persistence_pipeline
class ModelsService: class ModelsService:
ap: app.Application ap: app.Application
def __init__(self, ap: app.Application) -> None: def __init__(self, ap: app.Application) -> None:
@@ -28,13 +26,10 @@ class ModelsService:
] ]
async def create_llm_model(self, model_data: dict) -> str: async def create_llm_model(self, model_data: dict) -> str:
model_data['uuid'] = str(uuid.uuid4()) model_data['uuid'] = str(uuid.uuid4())
await self.ap.persistence_mgr.execute_async( await self.ap.persistence_mgr.execute_async(
sqlalchemy.insert(persistence_model.LLMModel).values( sqlalchemy.insert(persistence_model.LLMModel).values(**model_data)
**model_data
)
) )
llm_model = await self.get_llm_model(model_data['uuid']) llm_model = await self.get_llm_model(model_data['uuid'])
@@ -43,22 +38,24 @@ class ModelsService:
# check if default pipeline has no model bound # check if default pipeline has no model bound
result = await self.ap.persistence_mgr.execute_async( result = await self.ap.persistence_mgr.execute_async(
sqlalchemy.select(persistence_pipeline.LegacyPipeline).where(persistence_pipeline.LegacyPipeline.is_default == True) sqlalchemy.select(persistence_pipeline.LegacyPipeline).where(
persistence_pipeline.LegacyPipeline.is_default == True
)
) )
pipeline = result.first() pipeline = result.first()
if pipeline is not None and pipeline.config['ai']['local-agent']['model'] == '': if pipeline is not None and pipeline.config['ai']['local-agent']['model'] == '':
pipeline_config = pipeline.config pipeline_config = pipeline.config
pipeline_config['ai']['local-agent']['model'] = model_data['uuid'] pipeline_config['ai']['local-agent']['model'] = model_data['uuid']
pipeline_data = { pipeline_data = {'config': pipeline_config}
"config": pipeline_config
}
await self.ap.pipeline_service.update_pipeline(pipeline.uuid, pipeline_data) await self.ap.pipeline_service.update_pipeline(pipeline.uuid, pipeline_data)
return model_data['uuid'] return model_data['uuid']
async def get_llm_model(self, model_uuid: str) -> dict | None: async def get_llm_model(self, model_uuid: str) -> dict | None:
result = await self.ap.persistence_mgr.execute_async( result = await self.ap.persistence_mgr.execute_async(
sqlalchemy.select(persistence_model.LLMModel).where(persistence_model.LLMModel.uuid == model_uuid) sqlalchemy.select(persistence_model.LLMModel).where(
persistence_model.LLMModel.uuid == model_uuid
)
) )
model = result.first() model = result.first()
@@ -66,14 +63,18 @@ class ModelsService:
if model is None: if model is None:
return None return None
return self.ap.persistence_mgr.serialize_model(persistence_model.LLMModel, model) return self.ap.persistence_mgr.serialize_model(
persistence_model.LLMModel, model
)
async def update_llm_model(self, model_uuid: str, model_data: dict) -> None: async def update_llm_model(self, model_uuid: str, model_data: dict) -> None:
if 'uuid' in model_data: if 'uuid' in model_data:
del model_data['uuid'] del model_data['uuid']
await self.ap.persistence_mgr.execute_async( await self.ap.persistence_mgr.execute_async(
sqlalchemy.update(persistence_model.LLMModel).where(persistence_model.LLMModel.uuid == model_uuid).values(**model_data) sqlalchemy.update(persistence_model.LLMModel)
.where(persistence_model.LLMModel.uuid == model_uuid)
.values(**model_data)
) )
await self.ap.model_mgr.remove_llm_model(model_uuid) await self.ap.model_mgr.remove_llm_model(model_uuid)
@@ -84,7 +85,9 @@ class ModelsService:
async def delete_llm_model(self, model_uuid: str) -> None: async def delete_llm_model(self, model_uuid: str) -> None:
await self.ap.persistence_mgr.execute_async( await self.ap.persistence_mgr.execute_async(
sqlalchemy.delete(persistence_model.LLMModel).where(persistence_model.LLMModel.uuid == model_uuid) sqlalchemy.delete(persistence_model.LLMModel).where(
persistence_model.LLMModel.uuid == model_uuid
)
) )
await self.ap.model_mgr.remove_llm_model(model_uuid) await self.ap.model_mgr.remove_llm_model(model_uuid)
+34 -21
View File
@@ -2,7 +2,6 @@ from __future__ import annotations
import uuid import uuid
import json import json
import datetime
import sqlalchemy import sqlalchemy
from ....core import app from ....core import app
@@ -10,18 +9,18 @@ from ....entity.persistence import pipeline as persistence_pipeline
default_stage_order = [ default_stage_order = [
"GroupRespondRuleCheckStage", # 群响应规则检查 'GroupRespondRuleCheckStage', # 群响应规则检查
"BanSessionCheckStage", # 封禁会话检查 'BanSessionCheckStage', # 封禁会话检查
"PreContentFilterStage", # 内容过滤前置阶段 'PreContentFilterStage', # 内容过滤前置阶段
"PreProcessor", # 预处理器 'PreProcessor', # 预处理器
"ConversationMessageTruncator", # 会话消息截断器 'ConversationMessageTruncator', # 会话消息截断器
"RequireRateLimitOccupancy", # 请求速率限制占用 'RequireRateLimitOccupancy', # 请求速率限制占用
"MessageProcessor", # 处理器 'MessageProcessor', # 处理器
"ReleaseRateLimitOccupancy", # 释放速率限制占用 'ReleaseRateLimitOccupancy', # 释放速率限制占用
"PostContentFilterStage", # 内容过滤后置阶段 'PostContentFilterStage', # 内容过滤后置阶段
"ResponseWrapper", # 响应包装器 'ResponseWrapper', # 响应包装器
"LongTextProcessStage", # 长文本处理 'LongTextProcessStage', # 长文本处理
"SendResponseBackStage", # 发送响应 'SendResponseBackStage', # 发送响应
] ]
@@ -36,7 +35,7 @@ class PipelineService:
self.ap.pipeline_config_meta_trigger.data, self.ap.pipeline_config_meta_trigger.data,
self.ap.pipeline_config_meta_safety.data, self.ap.pipeline_config_meta_safety.data,
self.ap.pipeline_config_meta_ai.data, self.ap.pipeline_config_meta_ai.data,
self.ap.pipeline_config_meta_output.data self.ap.pipeline_config_meta_output.data,
] ]
async def get_pipelines(self) -> list[dict]: async def get_pipelines(self) -> list[dict]:
@@ -46,13 +45,17 @@ class PipelineService:
pipelines = result.all() pipelines = result.all()
return [ return [
self.ap.persistence_mgr.serialize_model(persistence_pipeline.LegacyPipeline, pipeline) self.ap.persistence_mgr.serialize_model(
persistence_pipeline.LegacyPipeline, pipeline
)
for pipeline in pipelines for pipeline in pipelines
] ]
async def get_pipeline(self, pipeline_uuid: str) -> dict | None: async def get_pipeline(self, pipeline_uuid: str) -> dict | None:
result = await self.ap.persistence_mgr.execute_async( result = await self.ap.persistence_mgr.execute_async(
sqlalchemy.select(persistence_pipeline.LegacyPipeline).where(persistence_pipeline.LegacyPipeline.uuid == pipeline_uuid) sqlalchemy.select(persistence_pipeline.LegacyPipeline).where(
persistence_pipeline.LegacyPipeline.uuid == pipeline_uuid
)
) )
pipeline = result.first() pipeline = result.first()
@@ -60,17 +63,23 @@ class PipelineService:
if pipeline is None: if pipeline is None:
return None return None
return self.ap.persistence_mgr.serialize_model(persistence_pipeline.LegacyPipeline, pipeline) return self.ap.persistence_mgr.serialize_model(
persistence_pipeline.LegacyPipeline, pipeline
)
async def create_pipeline(self, pipeline_data: dict, default: bool = False) -> str: async def create_pipeline(self, pipeline_data: dict, default: bool = False) -> str:
pipeline_data['uuid'] = str(uuid.uuid4()) pipeline_data['uuid'] = str(uuid.uuid4())
pipeline_data['for_version'] = self.ap.ver_mgr.get_current_version() pipeline_data['for_version'] = self.ap.ver_mgr.get_current_version()
pipeline_data['stages'] = default_stage_order.copy() pipeline_data['stages'] = default_stage_order.copy()
pipeline_data['is_default'] = default pipeline_data['is_default'] = default
pipeline_data['config'] = json.load(open('templates/default-pipeline-config.json', 'r', encoding='utf-8')) pipeline_data['config'] = json.load(
open('templates/default-pipeline-config.json', 'r', encoding='utf-8')
)
await self.ap.persistence_mgr.execute_async( await self.ap.persistence_mgr.execute_async(
sqlalchemy.insert(persistence_pipeline.LegacyPipeline).values(**pipeline_data) sqlalchemy.insert(persistence_pipeline.LegacyPipeline).values(
**pipeline_data
)
) )
pipeline = await self.get_pipeline(pipeline_data['uuid']) pipeline = await self.get_pipeline(pipeline_data['uuid'])
@@ -90,7 +99,9 @@ class PipelineService:
del pipeline_data['is_default'] del pipeline_data['is_default']
await self.ap.persistence_mgr.execute_async( await self.ap.persistence_mgr.execute_async(
sqlalchemy.update(persistence_pipeline.LegacyPipeline).where(persistence_pipeline.LegacyPipeline.uuid == pipeline_uuid).values(**pipeline_data) sqlalchemy.update(persistence_pipeline.LegacyPipeline)
.where(persistence_pipeline.LegacyPipeline.uuid == pipeline_uuid)
.values(**pipeline_data)
) )
await self.ap.pipeline_mgr.remove_pipeline(pipeline_uuid) await self.ap.pipeline_mgr.remove_pipeline(pipeline_uuid)
@@ -101,6 +112,8 @@ class PipelineService:
async def delete_pipeline(self, pipeline_uuid: str) -> None: async def delete_pipeline(self, pipeline_uuid: str) -> None:
await self.ap.persistence_mgr.execute_async( await self.ap.persistence_mgr.execute_async(
sqlalchemy.delete(persistence_pipeline.LegacyPipeline).where(persistence_pipeline.LegacyPipeline.uuid == pipeline_uuid) sqlalchemy.delete(persistence_pipeline.LegacyPipeline).where(
persistence_pipeline.LegacyPipeline.uuid == pipeline_uuid
)
) )
await self.ap.pipeline_mgr.remove_pipeline(pipeline_uuid) await self.ap.pipeline_mgr.remove_pipeline(pipeline_uuid)
+3 -5
View File
@@ -11,7 +11,6 @@ from ....utils import constants
class UserService: class UserService:
ap: app.Application ap: app.Application
def __init__(self, ap: app.Application) -> None: def __init__(self, ap: app.Application) -> None:
@@ -32,8 +31,7 @@ class UserService:
await self.ap.persistence_mgr.execute_async( await self.ap.persistence_mgr.execute_async(
sqlalchemy.insert(user.User).values( sqlalchemy.insert(user.User).values(
user=user_email, user=user_email, password=hashed_password
password=hashed_password
) )
) )
@@ -61,8 +59,8 @@ class UserService:
payload = { payload = {
'user': user_email, 'user': user_email,
'iss': 'LangBot-'+constants.edition, 'iss': 'LangBot-' + constants.edition,
'exp': datetime.datetime.now() + datetime.timedelta(seconds=jwt_expire) 'exp': datetime.datetime.now() + datetime.timedelta(seconds=jwt_expire),
} }
return jwt.encode(payload, jwt_secret, algorithm='HS256') return jwt.encode(payload, jwt_secret, algorithm='HS256')
+8 -10
View File
@@ -3,11 +3,9 @@ from __future__ import annotations
import abc import abc
import uuid import uuid
import json import json
import logging
import asyncio import asyncio
import aiohttp import aiohttp
import requests
from ...core import app, entities as core_entities from ...core import app, entities as core_entities
@@ -38,22 +36,22 @@ class APIGroup(metaclass=abc.ABCMeta):
""" """
执行请求 执行请求
""" """
self._runtime_info["account_id"] = "-1" self._runtime_info['account_id'] = '-1'
url = self.prefix + path url = self.prefix + path
data = json.dumps(data) data = json.dumps(data)
headers["Content-Type"] = "application/json" headers['Content-Type'] = 'application/json'
try: try:
async with aiohttp.ClientSession() as session: async with aiohttp.ClientSession() as session:
async with session.request( async with session.request(
method, url, data=data, params=params, headers=headers, **kwargs method, url, data=data, params=params, headers=headers, **kwargs
) as resp: ) as resp:
self.ap.logger.debug("data: %s", data) self.ap.logger.debug('data: %s', data)
self.ap.logger.debug("ret: %s", await resp.text()) self.ap.logger.debug('ret: %s', await resp.text())
except Exception as e: except Exception as e:
self.ap.logger.debug(f"上报失败: {e}") self.ap.logger.debug(f'上报失败: {e}')
async def do( async def do(
self, self,
@@ -68,8 +66,8 @@ class APIGroup(metaclass=abc.ABCMeta):
return self.ap.task_mgr.create_task( return self.ap.task_mgr.create_task(
self._do(method, path, data, params, headers, **kwargs), self._do(method, path, data, params, headers, **kwargs),
kind="telemetry-operation", kind='telemetry-operation',
name=f"{method} {path}", name=f'{method} {path}',
scopes=[core_entities.LifecycleControlScope.APPLICATION], scopes=[core_entities.LifecycleControlScope.APPLICATION],
).task ).task
@@ -80,7 +78,7 @@ class APIGroup(metaclass=abc.ABCMeta):
def basic_info(self): def basic_info(self):
"""获取基本信息""" """获取基本信息"""
basic_info = APIGroup._basic_info.copy() basic_info = APIGroup._basic_info.copy()
basic_info["rid"] = self.gen_rid() basic_info['rid'] = self.gen_rid()
return basic_info return basic_info
def runtime_info(self): def runtime_info(self):
+18 -18
View File
@@ -9,7 +9,7 @@ class V2MainDataAPI(apigroup.APIGroup):
def __init__(self, prefix: str, ap: app.Application): def __init__(self, prefix: str, ap: app.Application):
self.ap = ap self.ap = ap
super().__init__(prefix+"/main", ap) super().__init__(prefix + '/main', ap)
async def do(self, *args, **kwargs): async def do(self, *args, **kwargs):
if not self.ap.instance_config.data['telemetry']['report']: if not self.ap.instance_config.data['telemetry']['report']:
@@ -25,17 +25,17 @@ class V2MainDataAPI(apigroup.APIGroup):
): ):
"""提交更新记录""" """提交更新记录"""
return await self.do( return await self.do(
"POST", 'POST',
"/update", '/update',
data={ data={
"basic": self.basic_info(), 'basic': self.basic_info(),
"update_info": { 'update_info': {
"spent_seconds": spent_seconds, 'spent_seconds': spent_seconds,
"infer_reason": infer_reason, 'infer_reason': infer_reason,
"old_version": old_version, 'old_version': old_version,
"new_version": new_version, 'new_version': new_version,
} },
} },
) )
async def post_announcement_showed( async def post_announcement_showed(
@@ -44,12 +44,12 @@ class V2MainDataAPI(apigroup.APIGroup):
): ):
"""提交公告已阅""" """提交公告已阅"""
return await self.do( return await self.do(
"POST", 'POST',
"/announcement", '/announcement',
data={ data={
"basic": self.basic_info(), 'basic': self.basic_info(),
"announcement_info": { 'announcement_info': {
"ids": ids, 'ids': ids,
} },
} },
) )
+22 -28
View File
@@ -9,39 +9,33 @@ class V2PluginDataAPI(apigroup.APIGroup):
def __init__(self, prefix: str, ap: app.Application): def __init__(self, prefix: str, ap: app.Application):
self.ap = ap self.ap = ap
super().__init__(prefix+"/plugin", ap) super().__init__(prefix + '/plugin', ap)
async def do(self, *args, **kwargs): async def do(self, *args, **kwargs):
if not self.ap.instance_config.data['telemetry']['report']: if not self.ap.instance_config.data['telemetry']['report']:
return None return None
return await super().do(*args, **kwargs) return await super().do(*args, **kwargs)
async def post_install_record( async def post_install_record(self, plugin: dict):
self,
plugin: dict
):
"""提交插件安装记录""" """提交插件安装记录"""
return await self.do( return await self.do(
"POST", 'POST',
"/install", '/install',
data={ data={
"basic": self.basic_info(), 'basic': self.basic_info(),
"plugin": plugin, 'plugin': plugin,
} },
) )
async def post_remove_record( async def post_remove_record(self, plugin: dict):
self,
plugin: dict
):
"""提交插件卸载记录""" """提交插件卸载记录"""
return await self.do( return await self.do(
"POST", 'POST',
"/remove", '/remove',
data={ data={
"basic": self.basic_info(), 'basic': self.basic_info(),
"plugin": plugin, 'plugin': plugin,
} },
) )
async def post_update_record( async def post_update_record(
@@ -52,14 +46,14 @@ class V2PluginDataAPI(apigroup.APIGroup):
): ):
"""提交插件更新记录""" """提交插件更新记录"""
return await self.do( return await self.do(
"POST", 'POST',
"/update", '/update',
data={ data={
"basic": self.basic_info(), 'basic': self.basic_info(),
"plugin": plugin, 'plugin': plugin,
"update_info": { 'update_info': {
"old_version": old_version, 'old_version': old_version,
"new_version": new_version, 'new_version': new_version,
} },
} },
) )
+34 -35
View File
@@ -9,7 +9,7 @@ class V2UsageDataAPI(apigroup.APIGroup):
def __init__(self, prefix: str, ap: app.Application): def __init__(self, prefix: str, ap: app.Application):
self.ap = ap self.ap = ap
super().__init__(prefix+"/usage", ap) super().__init__(prefix + '/usage', ap)
async def do(self, *args, **kwargs): async def do(self, *args, **kwargs):
if not self.ap.instance_config.data['telemetry']['report']: if not self.ap.instance_config.data['telemetry']['report']:
@@ -28,23 +28,23 @@ class V2UsageDataAPI(apigroup.APIGroup):
): ):
"""提交请求记录""" """提交请求记录"""
return await self.do( return await self.do(
"POST", 'POST',
"/query", '/query',
data={ data={
"basic": self.basic_info(), 'basic': self.basic_info(),
"runtime": self.runtime_info(), 'runtime': self.runtime_info(),
"session_info": { 'session_info': {
"type": session_type, 'type': session_type,
"id": session_id, 'id': session_id,
}, },
"query_info": { 'query_info': {
"ability_provider": query_ability_provider, 'ability_provider': query_ability_provider,
"usage": usage, 'usage': usage,
"model_name": model_name, 'model_name': model_name,
"response_seconds": response_seconds, 'response_seconds': response_seconds,
"retry_times": retry_times, 'retry_times': retry_times,
} },
} },
) )
async def post_event_record( async def post_event_record(
@@ -54,16 +54,16 @@ class V2UsageDataAPI(apigroup.APIGroup):
): ):
"""提交事件触发记录""" """提交事件触发记录"""
return await self.do( return await self.do(
"POST", 'POST',
"/event", '/event',
data={ data={
"basic": self.basic_info(), 'basic': self.basic_info(),
"runtime": self.runtime_info(), 'runtime': self.runtime_info(),
"plugins": plugins, 'plugins': plugins,
"event_info": { 'event_info': {
"name": event_name, 'name': event_name,
} },
} },
) )
async def post_function_record( async def post_function_record(
@@ -74,15 +74,14 @@ class V2UsageDataAPI(apigroup.APIGroup):
): ):
"""提交内容函数使用记录""" """提交内容函数使用记录"""
return await self.do( return await self.do(
"POST", 'POST',
"/function", '/function',
data={ data={
"basic": self.basic_info(), 'basic': self.basic_info(),
"plugin": plugin, 'plugin': plugin,
"function_info": { 'function_info': {
"name": function_name, 'name': function_name,
"description": function_description, 'description': function_description,
} },
} },
) )
+8 -3
View File
@@ -21,10 +21,16 @@ class V2CenterAPI:
plugin: plugin.V2PluginDataAPI = None plugin: plugin.V2PluginDataAPI = None
"""插件 API 组""" """插件 API 组"""
def __init__(self, ap: app.Application, backend_url: str, basic_info: dict = None, runtime_info: dict = None): def __init__(
self,
ap: app.Application,
backend_url: str,
basic_info: dict = None,
runtime_info: dict = None,
):
"""初始化""" """初始化"""
logging.debug("basic_info: %s, runtime_info: %s", basic_info, runtime_info) logging.debug('basic_info: %s, runtime_info: %s', basic_info, runtime_info)
apigroup.APIGroup._basic_info = basic_info apigroup.APIGroup._basic_info = basic_info
apigroup.APIGroup._runtime_info = runtime_info apigroup.APIGroup._runtime_info = runtime_info
@@ -32,4 +38,3 @@ class V2CenterAPI:
self.main = main.V2MainDataAPI(backend_url, ap) self.main = main.V2MainDataAPI(backend_url, ap)
self.usage = usage.V2UsageDataAPI(backend_url, ap) self.usage = usage.V2UsageDataAPI(backend_url, ap)
self.plugin = plugin.V2PluginDataAPI(backend_url, ap) self.plugin = plugin.V2PluginDataAPI(backend_url, ap)
+16 -12
View File
@@ -16,6 +16,7 @@ identifier = {
HOST_ID_FILE = os.path.expanduser('~/.langbot/host_id.json') HOST_ID_FILE = os.path.expanduser('~/.langbot/host_id.json')
INSTANCE_ID_FILE = 'data/labels/instance_id.json' INSTANCE_ID_FILE = 'data/labels/instance_id.json'
def init(): def init():
global identifier global identifier
@@ -23,14 +24,11 @@ def init():
os.mkdir(os.path.expanduser('~/.langbot')) os.mkdir(os.path.expanduser('~/.langbot'))
if not os.path.exists(HOST_ID_FILE): if not os.path.exists(HOST_ID_FILE):
new_host_id = 'host_'+str(uuid.uuid4()) new_host_id = 'host_' + str(uuid.uuid4())
new_host_create_ts = int(time.time()) new_host_create_ts = int(time.time())
with open(HOST_ID_FILE, 'w') as f: with open(HOST_ID_FILE, 'w') as f:
json.dump({ json.dump({'host_id': new_host_id, 'host_create_ts': new_host_create_ts}, f)
'host_id': new_host_id,
'host_create_ts': new_host_create_ts
}, f)
identifier['host_id'] = new_host_id identifier['host_id'] = new_host_id
identifier['host_create_ts'] = new_host_create_ts identifier['host_create_ts'] = new_host_create_ts
@@ -52,19 +50,24 @@ def init():
with open(INSTANCE_ID_FILE, 'r') as f: with open(INSTANCE_ID_FILE, 'r') as f:
instance_id = json.load(f) instance_id = json.load(f)
if instance_id['host_id'] != identifier['host_id']: # 如果实例 id 不是当前主机的,删除 if (
instance_id['host_id'] != identifier['host_id']
): # 如果实例 id 不是当前主机的,删除
os.remove(INSTANCE_ID_FILE) os.remove(INSTANCE_ID_FILE)
if not os.path.exists(INSTANCE_ID_FILE): if not os.path.exists(INSTANCE_ID_FILE):
new_instance_id = 'instance_'+str(uuid.uuid4()) new_instance_id = 'instance_' + str(uuid.uuid4())
new_instance_create_ts = int(time.time()) new_instance_create_ts = int(time.time())
with open(INSTANCE_ID_FILE, 'w') as f: with open(INSTANCE_ID_FILE, 'w') as f:
json.dump({ json.dump(
'host_id': identifier['host_id'], {
'instance_id': new_instance_id, 'host_id': identifier['host_id'],
'instance_create_ts': new_instance_create_ts 'instance_id': new_instance_id,
}, f) 'instance_create_ts': new_instance_create_ts,
},
f,
)
identifier['instance_id'] = new_instance_id identifier['instance_id'] = new_instance_id
identifier['instance_create_ts'] = new_instance_create_ts identifier['instance_create_ts'] = new_instance_create_ts
@@ -80,6 +83,7 @@ def init():
identifier['instance_id'] = loaded_instance_id identifier['instance_id'] = loaded_instance_id
identifier['instance_create_ts'] = loaded_instance_create_ts identifier['instance_create_ts'] = loaded_instance_create_ts
def print_out(): def print_out():
global identifier global identifier
print(identifier) print(identifier)
+28 -29
View File
@@ -3,17 +3,17 @@ from __future__ import annotations
import typing import typing
from ..core import app, entities as core_entities from ..core import app, entities as core_entities
from ..provider import entities as llm_entities
from . import entities, operator, errors from . import entities, operator, errors
from ..config import manager as cfg_mgr from ..utils import importutil
# 引入所有算子以便注册 # 引入所有算子以便注册
from .operators import func, plugin, reset, list as list_cmd, last, next, delc, resend, prompt, cmd, help, version, update, ollama, model from . import operators
importutil.import_modules_in_pkg(operators)
class CommandManager: class CommandManager:
"""命令管理器 """命令管理器"""
"""
ap: app.Application ap: app.Application
@@ -26,7 +26,6 @@ class CommandManager:
self.ap = ap self.ap = ap
async def initialize(self): async def initialize(self):
# 设置各个类的路径 # 设置各个类的路径
def set_path(cls: operator.CommandOperator, ancestors: list[str]): def set_path(cls: operator.CommandOperator, ancestors: list[str]):
cls.path = '.'.join(ancestors + [cls.name]) cls.path = '.'.join(ancestors + [cls.name])
@@ -41,14 +40,18 @@ class CommandManager:
# 应用命令权限配置 # 应用命令权限配置
for cls in operator.preregistered_operators: for cls in operator.preregistered_operators:
if cls.path in self.ap.instance_config.data['command']['privilege']: if cls.path in self.ap.instance_config.data['command']['privilege']:
cls.lowest_privilege = self.ap.instance_config.data['command']['privilege'][cls.path] cls.lowest_privilege = self.ap.instance_config.data['command'][
'privilege'
][cls.path]
# 实例化所有类 # 实例化所有类
self.cmd_list = [cls(self.ap) for cls in operator.preregistered_operators] self.cmd_list = [cls(self.ap) for cls in operator.preregistered_operators]
# 设置所有类的子节点 # 设置所有类的子节点
for cmd in self.cmd_list: for cmd in self.cmd_list:
cmd.children = [child for child in self.cmd_list if child.parent_class == cmd.__class__] cmd.children = [
child for child in self.cmd_list if child.parent_class == cmd.__class__
]
# 初始化所有类 # 初始化所有类
for cmd in self.cmd_list: for cmd in self.cmd_list:
@@ -58,27 +61,25 @@ class CommandManager:
self, self,
context: entities.ExecuteContext, context: entities.ExecuteContext,
operator_list: list[operator.CommandOperator], operator_list: list[operator.CommandOperator],
operator: operator.CommandOperator = None operator: operator.CommandOperator = None,
) -> typing.AsyncGenerator[entities.CommandReturn, None]: ) -> typing.AsyncGenerator[entities.CommandReturn, None]:
"""执行命令 """执行命令"""
"""
found = False found = False
if len(context.crt_params) > 0: # 查找下一个参数是否对应此节点的某个子节点名 if len(context.crt_params) > 0: # 查找下一个参数是否对应此节点的某个子节点名
for oper in operator_list: for oper in operator_list:
if (context.crt_params[0] == oper.name \ if (
or context.crt_params[0] in oper.alias) \ context.crt_params[0] == oper.name
and (oper.parent_class is None or oper.parent_class == operator.__class__): or context.crt_params[0] in oper.alias
) and (
oper.parent_class is None or oper.parent_class == operator.__class__
):
found = True found = True
context.crt_command = context.crt_params[0] context.crt_command = context.crt_params[0]
context.crt_params = context.crt_params[1:] context.crt_params = context.crt_params[1:]
async for ret in self._execute( async for ret in self._execute(context, oper.children, oper):
context,
oper.children,
oper
):
yield ret yield ret
break break
@@ -96,19 +97,20 @@ class CommandManager:
async for ret in operator.execute(context): async for ret in operator.execute(context):
yield ret yield ret
async def execute( async def execute(
self, self,
command_text: str, command_text: str,
query: core_entities.Query, query: core_entities.Query,
session: core_entities.Session session: core_entities.Session,
) -> typing.AsyncGenerator[entities.CommandReturn, None]: ) -> typing.AsyncGenerator[entities.CommandReturn, None]:
"""执行命令 """执行命令"""
"""
privilege = 1 privilege = 1
if f'{query.launcher_type.value}_{query.launcher_id}' in self.ap.instance_config.data['admins']: if (
f'{query.launcher_type.value}_{query.launcher_id}'
in self.ap.instance_config.data['admins']
):
privilege = 2 privilege = 2
ctx = entities.ExecuteContext( ctx = entities.ExecuteContext(
@@ -119,11 +121,8 @@ class CommandManager:
crt_command='', crt_command='',
params=command_text.split(' '), params=command_text.split(' '),
crt_params=command_text.split(' '), crt_params=command_text.split(' '),
privilege=privilege privilege=privilege,
) )
async for ret in self._execute( async for ret in self._execute(ctx, self.cmd_list):
ctx,
self.cmd_list
):
yield ret yield ret
+5 -7
View File
@@ -4,14 +4,13 @@ import typing
import pydantic.v1 as pydantic import pydantic.v1 as pydantic
from ..core import app, entities as core_entities from ..core import entities as core_entities
from . import errors, operator from . import errors
from ..platform.types import message as platform_message from ..platform.types import message as platform_message
class CommandReturn(pydantic.BaseModel): class CommandReturn(pydantic.BaseModel):
"""命令返回值 """命令返回值"""
"""
text: typing.Optional[str] = None text: typing.Optional[str] = None
"""文本 """文本
@@ -24,7 +23,7 @@ class CommandReturn(pydantic.BaseModel):
"""图片链接 """图片链接
""" """
error: typing.Optional[errors.CommandError]= None error: typing.Optional[errors.CommandError] = None
"""错误 """错误
""" """
@@ -33,8 +32,7 @@ class CommandReturn(pydantic.BaseModel):
class ExecuteContext(pydantic.BaseModel): class ExecuteContext(pydantic.BaseModel):
"""单次命令执行上下文 """单次命令执行上下文"""
"""
query: core_entities.Query query: core_entities.Query
"""本次消息的请求对象""" """本次消息的请求对象"""
+4 -11
View File
@@ -1,7 +1,4 @@
class CommandError(Exception): class CommandError(Exception):
def __init__(self, message: str = None): def __init__(self, message: str = None):
self.message = message self.message = message
@@ -10,24 +7,20 @@ class CommandError(Exception):
class CommandNotFoundError(CommandError): class CommandNotFoundError(CommandError):
def __init__(self, message: str = None): def __init__(self, message: str = None):
super().__init__("未知命令: "+message) super().__init__('未知命令: ' + message)
class CommandPrivilegeError(CommandError): class CommandPrivilegeError(CommandError):
def __init__(self, message: str = None): def __init__(self, message: str = None):
super().__init__("权限不足: "+message) super().__init__('权限不足: ' + message)
class ParamNotEnoughError(CommandError): class ParamNotEnoughError(CommandError):
def __init__(self, message: str = None): def __init__(self, message: str = None):
super().__init__("参数不足: "+message) super().__init__('参数不足: ' + message)
class CommandOperationError(CommandError): class CommandOperationError(CommandError):
def __init__(self, message: str = None): def __init__(self, message: str = None):
super().__init__("操作失败: "+message) super().__init__('操作失败: ' + message)
+5 -6
View File
@@ -3,7 +3,7 @@ from __future__ import annotations
import typing import typing
import abc import abc
from ..core import app, entities as core_entities from ..core import app
from . import entities from . import entities
@@ -13,11 +13,11 @@ preregistered_operators: list[typing.Type[CommandOperator]] = []
def operator_class( def operator_class(
name: str, name: str,
help: str = "", help: str = '',
usage: str = None, usage: str = None,
alias: list[str] = [], alias: list[str] = [],
privilege: int=1, # 1为普通用户,2为管理员 privilege: int = 1, # 1为普通用户,2为管理员
parent_class: typing.Type[CommandOperator] = None parent_class: typing.Type[CommandOperator] = None,
) -> typing.Callable[[typing.Type[CommandOperator]], typing.Type[CommandOperator]]: ) -> typing.Callable[[typing.Type[CommandOperator]], typing.Type[CommandOperator]]:
"""命令类装饰器 """命令类装饰器
@@ -96,8 +96,7 @@ class CommandOperator(metaclass=abc.ABCMeta):
@abc.abstractmethod @abc.abstractmethod
async def execute( async def execute(
self, self, context: entities.ExecuteContext
context: entities.ExecuteContext
) -> typing.AsyncGenerator[entities.CommandReturn, None]: ) -> typing.AsyncGenerator[entities.CommandReturn, None]:
"""实现此方法以执行命令 """实现此方法以执行命令
+16 -19
View File
@@ -2,32 +2,25 @@ from __future__ import annotations
import typing import typing
from .. import operator, entities, cmdmgr, errors from .. import operator, entities, errors
@operator.operator_class( @operator.operator_class(name='cmd', help='显示命令列表', usage='!cmd\n!cmd <命令名称>')
name="cmd",
help='显示命令列表',
usage='!cmd\n!cmd <命令名称>'
)
class CmdOperator(operator.CommandOperator): class CmdOperator(operator.CommandOperator):
"""命令列表 """命令列表"""
"""
async def execute( async def execute(
self, self, context: entities.ExecuteContext
context: entities.ExecuteContext
) -> typing.AsyncGenerator[entities.CommandReturn, None]: ) -> typing.AsyncGenerator[entities.CommandReturn, None]:
"""执行 """执行"""
"""
if len(context.crt_params) == 0: if len(context.crt_params) == 0:
reply_str = "当前所有命令: \n\n" reply_str = '当前所有命令: \n\n'
for cmd in self.ap.cmd_mgr.cmd_list: for cmd in self.ap.cmd_mgr.cmd_list:
if cmd.parent_class is None: if cmd.parent_class is None:
reply_str += f"{cmd.name}: {cmd.help}\n" reply_str += f'{cmd.name}: {cmd.help}\n'
reply_str += "\n使用 !cmd <命令名称> 查看命令的详细帮助" reply_str += '\n使用 !cmd <命令名称> 查看命令的详细帮助'
yield entities.CommandReturn(text=reply_str.strip()) yield entities.CommandReturn(text=reply_str.strip())
@@ -37,14 +30,18 @@ class CmdOperator(operator.CommandOperator):
cmd = None cmd = None
for _cmd in self.ap.cmd_mgr.cmd_list: for _cmd in self.ap.cmd_mgr.cmd_list:
if (cmd_name == _cmd.name or cmd_name in _cmd.alias) and (_cmd.parent_class is None): if (cmd_name == _cmd.name or cmd_name in _cmd.alias) and (
_cmd.parent_class is None
):
cmd = _cmd cmd = _cmd
break break
if cmd is None: if cmd is None:
yield entities.CommandReturn(error=errors.CommandNotFoundError(cmd_name)) yield entities.CommandReturn(
error=errors.CommandNotFoundError(cmd_name)
)
else: else:
reply_str = f"{cmd.name}: {cmd.help}\n\n" reply_str = f'{cmd.name}: {cmd.help}\n\n'
reply_str += f"使用方法: \n{cmd.usage}" reply_str += f'使用方法: \n{cmd.usage}'
yield entities.CommandReturn(text=reply_str.strip()) yield entities.CommandReturn(text=reply_str.strip())
+22 -24
View File
@@ -1,62 +1,60 @@
from __future__ import annotations from __future__ import annotations
import typing import typing
import datetime
from .. import operator, entities, cmdmgr, errors from .. import operator, entities, errors
@operator.operator_class( @operator.operator_class(
name="del", name='del', help='删除当前会话的历史记录', usage='!del <序号>\n!del all'
help="删除当前会话的历史记录",
usage='!del <序号>\n!del all'
) )
class DelOperator(operator.CommandOperator): class DelOperator(operator.CommandOperator):
async def execute( async def execute(
self, self, context: entities.ExecuteContext
context: entities.ExecuteContext
) -> typing.AsyncGenerator[entities.CommandReturn, None]: ) -> typing.AsyncGenerator[entities.CommandReturn, None]:
if context.session.conversations: if context.session.conversations:
delete_index = 0 delete_index = 0
if len(context.crt_params) > 0: if len(context.crt_params) > 0:
try: try:
delete_index = int(context.crt_params[0]) delete_index = int(context.crt_params[0])
except: except Exception:
yield entities.CommandReturn(error=errors.CommandOperationError('索引必须是整数')) yield entities.CommandReturn(
error=errors.CommandOperationError('索引必须是整数')
)
return return
if delete_index < 0 or delete_index >= len(context.session.conversations): if delete_index < 0 or delete_index >= len(context.session.conversations):
yield entities.CommandReturn(error=errors.CommandOperationError('索引超出范围')) yield entities.CommandReturn(
error=errors.CommandOperationError('索引超出范围')
)
return return
# 倒序 # 倒序
to_delete_index = len(context.session.conversations)-1-delete_index to_delete_index = len(context.session.conversations) - 1 - delete_index
if context.session.conversations[to_delete_index] == context.session.using_conversation: if (
context.session.conversations[to_delete_index]
== context.session.using_conversation
):
context.session.using_conversation = None context.session.using_conversation = None
del context.session.conversations[to_delete_index] del context.session.conversations[to_delete_index]
yield entities.CommandReturn(text=f"已删除对话: {delete_index}") yield entities.CommandReturn(text=f'已删除对话: {delete_index}')
else: else:
yield entities.CommandReturn(error=errors.CommandOperationError('当前没有对话')) yield entities.CommandReturn(
error=errors.CommandOperationError('当前没有对话')
)
@operator.operator_class( @operator.operator_class(
name="all", name='all', help='删除此会话的所有历史记录', parent_class=DelOperator
help="删除此会话的所有历史记录",
parent_class=DelOperator
) )
class DelAllOperator(operator.CommandOperator): class DelAllOperator(operator.CommandOperator):
async def execute( async def execute(
self, self, context: entities.ExecuteContext
context: entities.ExecuteContext
) -> typing.AsyncGenerator[entities.CommandReturn, None]: ) -> typing.AsyncGenerator[entities.CommandReturn, None]:
context.session.conversations = [] context.session.conversations = []
context.session.using_conversation = None context.session.using_conversation = None
yield entities.CommandReturn(text="已删除所有对话") yield entities.CommandReturn(text='已删除所有对话')
+4 -5
View File
@@ -1,16 +1,15 @@
from __future__ import annotations from __future__ import annotations
from typing import AsyncGenerator from typing import AsyncGenerator
from .. import operator, entities, cmdmgr from .. import operator, entities
from ...plugin import context as plugin_context
@operator.operator_class(name="func", help="查看所有已注册的内容函数", usage='!func') @operator.operator_class(name='func', help='查看所有已注册的内容函数', usage='!func')
class FuncOperator(operator.CommandOperator): class FuncOperator(operator.CommandOperator):
async def execute( async def execute(
self, context: entities.ExecuteContext self, context: entities.ExecuteContext
) -> AsyncGenerator[entities.CommandReturn, None]: ) -> AsyncGenerator[entities.CommandReturn, None]:
reply_str = "当前已启用的内容函数: \n\n" reply_str = '当前已启用的内容函数: \n\n'
index = 1 index = 1
@@ -19,7 +18,7 @@ class FuncOperator(operator.CommandOperator):
) )
for func in all_functions: for func in all_functions:
reply_str += "{}. {}:\n{}\n\n".format( reply_str += '{}. {}:\n{}\n\n'.format(
index, index,
func.name, func.name,
func.description, func.description,
+3 -9
View File
@@ -2,19 +2,13 @@ from __future__ import annotations
import typing import typing
from .. import operator, entities, cmdmgr, errors from .. import operator, entities
@operator.operator_class( @operator.operator_class(name='help', help='显示帮助', usage='!help\n!help <命令名称>')
name='help',
help='显示帮助',
usage='!help\n!help <命令名称>'
)
class HelpOperator(operator.CommandOperator): class HelpOperator(operator.CommandOperator):
async def execute( async def execute(
self, self, context: entities.ExecuteContext
context: entities.ExecuteContext
) -> typing.AsyncGenerator[entities.CommandReturn, None]: ) -> typing.AsyncGenerator[entities.CommandReturn, None]:
help = 'LangBot - 大语言模型原生即时通信机器人平台\n链接:https://langbot.app' help = 'LangBot - 大语言模型原生即时通信机器人平台\n链接:https://langbot.app'
+25 -18
View File
@@ -1,36 +1,43 @@
from __future__ import annotations from __future__ import annotations
import typing import typing
import datetime
from .. import operator, entities, cmdmgr, errors from .. import operator, entities, errors
@operator.operator_class( @operator.operator_class(name='last', help='切换到前一个对话', usage='!last')
name="last",
help="切换到前一个对话",
usage='!last'
)
class LastOperator(operator.CommandOperator): class LastOperator(operator.CommandOperator):
async def execute( async def execute(
self, self, context: entities.ExecuteContext
context: entities.ExecuteContext
) -> typing.AsyncGenerator[entities.CommandReturn, None]: ) -> typing.AsyncGenerator[entities.CommandReturn, None]:
if context.session.conversations: if context.session.conversations:
# 找到当前会话的上一个会话 # 找到当前会话的上一个会话
for index in range(len(context.session.conversations)-1, -1, -1): for index in range(len(context.session.conversations) - 1, -1, -1):
if context.session.conversations[index] == context.session.using_conversation: if (
context.session.conversations[index]
== context.session.using_conversation
):
if index == 0: if index == 0:
yield entities.CommandReturn(error=errors.CommandOperationError('已经是第一个对话了')) yield entities.CommandReturn(
error=errors.CommandOperationError('已经是第一个对话了')
)
return return
else: else:
context.session.using_conversation = context.session.conversations[index-1] context.session.using_conversation = (
time_str = context.session.using_conversation.create_time.strftime("%Y-%m-%d %H:%M:%S") context.session.conversations[index - 1]
)
time_str = (
context.session.using_conversation.create_time.strftime(
'%Y-%m-%d %H:%M:%S'
)
)
yield entities.CommandReturn(text=f"已切换到上一个对话: {index} {time_str}: {context.session.using_conversation.messages[0].readable_str()}") yield entities.CommandReturn(
text=f'已切换到上一个对话: {index} {time_str}: {context.session.using_conversation.messages[0].readable_str()}'
)
return return
else: else:
yield entities.CommandReturn(error=errors.CommandOperationError('当前没有对话')) yield entities.CommandReturn(
error=errors.CommandOperationError('当前没有对话')
)
+13 -17
View File
@@ -1,30 +1,26 @@
from __future__ import annotations from __future__ import annotations
import typing import typing
import datetime
from .. import operator, entities, cmdmgr, errors from .. import operator, entities, errors
@operator.operator_class( @operator.operator_class(
name="list", name='list', help='列出此会话中的所有历史对话', usage='!list\n!list <页码>'
help="列出此会话中的所有历史对话",
usage='!list\n!list <页码>'
) )
class ListOperator(operator.CommandOperator): class ListOperator(operator.CommandOperator):
async def execute( async def execute(
self, self, context: entities.ExecuteContext
context: entities.ExecuteContext
) -> typing.AsyncGenerator[entities.CommandReturn, None]: ) -> typing.AsyncGenerator[entities.CommandReturn, None]:
page = 0 page = 0
if len(context.crt_params) > 0: if len(context.crt_params) > 0:
try: try:
page = int(context.crt_params[0]-1) page = int(context.crt_params[0] - 1)
except: except Exception:
yield entities.CommandReturn(error=errors.CommandOperationError('页码应为整数')) yield entities.CommandReturn(
error=errors.CommandOperationError('页码应为整数')
)
return return
record_per_page = 10 record_per_page = 10
@@ -36,21 +32,21 @@ class ListOperator(operator.CommandOperator):
using_conv_index = 0 using_conv_index = 0
for conv in context.session.conversations[::-1]: for conv in context.session.conversations[::-1]:
time_str = conv.create_time.strftime("%Y-%m-%d %H:%M:%S") time_str = conv.create_time.strftime('%Y-%m-%d %H:%M:%S')
if conv == context.session.using_conversation: if conv == context.session.using_conversation:
using_conv_index = index using_conv_index = index
if index >= page * record_per_page and index < (page + 1) * record_per_page: if index >= page * record_per_page and index < (page + 1) * record_per_page:
content += f"{index} {time_str}: {conv.messages[0].readable_str() if len(conv.messages) > 0 else '无内容'}\n" content += f'{index} {time_str}: {conv.messages[0].readable_str() if len(conv.messages) > 0 else "无内容"}\n'
index += 1 index += 1
if content == '': if content == '':
content = '' content = ''
else: else:
if context.session.using_conversation is None: if context.session.using_conversation is None:
content += "\n当前处于新会话" content += '\n当前处于新会话'
else: else:
content += f"\n当前会话: {using_conv_index} {context.session.using_conversation.create_time.strftime('%Y-%m-%d %H:%M:%S')}: {context.session.using_conversation.messages[0].readable_str() if len(context.session.using_conversation.messages) > 0 else '无内容'}" content += f'\n当前会话: {using_conv_index} {context.session.using_conversation.create_time.strftime("%Y-%m-%d %H:%M:%S")}: {context.session.using_conversation.messages[0].readable_str() if len(context.session.using_conversation.messages) > 0 else "无内容"}'
yield entities.CommandReturn(text=f"{page + 1} 页 (时间倒序):\n{content}") yield entities.CommandReturn(text=f'{page + 1} 页 (时间倒序):\n{content}')
+36 -28
View File
@@ -2,42 +2,44 @@ from __future__ import annotations
import typing import typing
from .. import operator, entities, cmdmgr, errors from .. import operator, entities, errors
@operator.operator_class( @operator.operator_class(
name="model", name='model',
help='显示和切换模型列表', help='显示和切换模型列表',
usage='!model\n!model show <模型名>\n!model set <模型名>', usage='!model\n!model show <模型名>\n!model set <模型名>',
privilege=2 privilege=2,
) )
class ModelOperator(operator.CommandOperator): class ModelOperator(operator.CommandOperator):
"""Model命令""" """Model命令"""
async def execute(self, context: entities.ExecuteContext) -> typing.AsyncGenerator[entities.CommandReturn, None]: async def execute(
self, context: entities.ExecuteContext
) -> typing.AsyncGenerator[entities.CommandReturn, None]:
content = '模型列表:\n' content = '模型列表:\n'
model_list = self.ap.model_mgr.model_list model_list = self.ap.model_mgr.model_list
for model in model_list: for model in model_list:
content += f"\n名称: {model.name}\n" content += f'\n名称: {model.name}\n'
content += f"请求器: {model.requester.name}\n" content += f'请求器: {model.requester.name}\n'
content += f"\n当前对话使用模型: {context.query.use_model.name}\n" content += f'\n当前对话使用模型: {context.query.use_model.name}\n'
content += f"新对话默认使用模型: {self.ap.provider_cfg.data.get('model')}\n" content += f'新对话默认使用模型: {self.ap.provider_cfg.data.get("model")}\n'
yield entities.CommandReturn(text=content.strip()) yield entities.CommandReturn(text=content.strip())
@operator.operator_class( @operator.operator_class(
name="show", name='show', help='显示模型详情', privilege=2, parent_class=ModelOperator
help='显示模型详情',
privilege=2,
parent_class=ModelOperator
) )
class ModelShowOperator(operator.CommandOperator): class ModelShowOperator(operator.CommandOperator):
"""Model Show命令""" """Model Show命令"""
async def execute(self, context: entities.ExecuteContext) -> typing.AsyncGenerator[entities.CommandReturn, None]: async def execute(
self, context: entities.ExecuteContext
) -> typing.AsyncGenerator[entities.CommandReturn, None]:
model_name = context.crt_params[0] model_name = context.crt_params[0]
model = None model = None
@@ -47,29 +49,31 @@ class ModelShowOperator(operator.CommandOperator):
break break
if model is None: if model is None:
yield entities.CommandReturn(error=errors.CommandError(f"未找到模型 {model_name}")) yield entities.CommandReturn(
error=errors.CommandError(f'未找到模型 {model_name}')
)
else: else:
content = f"模型详情\n" content = '模型详情\n'
content += f"名称: {model.name}\n" content += f'名称: {model.name}\n'
if model.model_name is not None: if model.model_name is not None:
content += f"请求模型名称: {model.model_name}\n" content += f'请求模型名称: {model.model_name}\n'
content += f"请求器: {model.requester.name}\n" content += f'请求器: {model.requester.name}\n'
content += f"密钥组: {model.token_mgr.name}\n" content += f'密钥组: {model.token_mgr.name}\n'
content += f"支持视觉: {model.vision_supported}\n" content += f'支持视觉: {model.vision_supported}\n'
content += f"支持工具: {model.tool_call_supported}\n" content += f'支持工具: {model.tool_call_supported}\n'
yield entities.CommandReturn(text=content.strip()) yield entities.CommandReturn(text=content.strip())
@operator.operator_class( @operator.operator_class(
name="set", name='set', help='设置默认使用模型', privilege=2, parent_class=ModelOperator
help='设置默认使用模型',
privilege=2,
parent_class=ModelOperator
) )
class ModelSetOperator(operator.CommandOperator): class ModelSetOperator(operator.CommandOperator):
"""Model Set命令""" """Model Set命令"""
async def execute(self, context: entities.ExecuteContext) -> typing.AsyncGenerator[entities.CommandReturn, None]: async def execute(
self, context: entities.ExecuteContext
) -> typing.AsyncGenerator[entities.CommandReturn, None]:
model_name = context.crt_params[0] model_name = context.crt_params[0]
model = None model = None
@@ -79,8 +83,12 @@ class ModelSetOperator(operator.CommandOperator):
break break
if model is None: if model is None:
yield entities.CommandReturn(error=errors.CommandError(f"未找到模型 {model_name}")) yield entities.CommandReturn(
error=errors.CommandError(f'未找到模型 {model_name}')
)
else: else:
self.ap.provider_cfg.data['model'] = model_name self.ap.provider_cfg.data['model'] = model_name
await self.ap.provider_cfg.dump_config() await self.ap.provider_cfg.dump_config()
yield entities.CommandReturn(text=f"已设置当前使用模型为 {model_name},重置会话以生效") yield entities.CommandReturn(
text=f'已设置当前使用模型为 {model_name},重置会话以生效'
)
+25 -18
View File
@@ -1,35 +1,42 @@
from __future__ import annotations from __future__ import annotations
import typing import typing
import datetime
from .. import operator, entities, cmdmgr, errors from .. import operator, entities, errors
@operator.operator_class( @operator.operator_class(name='next', help='切换到后一个对话', usage='!next')
name="next",
help="切换到后一个对话",
usage='!next'
)
class NextOperator(operator.CommandOperator): class NextOperator(operator.CommandOperator):
async def execute( async def execute(
self, self, context: entities.ExecuteContext
context: entities.ExecuteContext
) -> typing.AsyncGenerator[entities.CommandReturn, None]: ) -> typing.AsyncGenerator[entities.CommandReturn, None]:
if context.session.conversations: if context.session.conversations:
# 找到当前会话的下一个会话 # 找到当前会话的下一个会话
for index in range(len(context.session.conversations)): for index in range(len(context.session.conversations)):
if context.session.conversations[index] == context.session.using_conversation: if (
if index == len(context.session.conversations)-1: context.session.conversations[index]
yield entities.CommandReturn(error=errors.CommandOperationError('已经是最后一个对话了')) == context.session.using_conversation
):
if index == len(context.session.conversations) - 1:
yield entities.CommandReturn(
error=errors.CommandOperationError('已经是最后一个对话了')
)
return return
else: else:
context.session.using_conversation = context.session.conversations[index+1] context.session.using_conversation = (
time_str = context.session.using_conversation.create_time.strftime("%Y-%m-%d %H:%M:%S") context.session.conversations[index + 1]
)
time_str = (
context.session.using_conversation.create_time.strftime(
'%Y-%m-%d %H:%M:%S'
)
)
yield entities.CommandReturn(text=f"已切换到后一个对话: {index} {time_str}: {context.session.using_conversation.messages[0].content}") yield entities.CommandReturn(
text=f'已切换到后一个对话: {index} {time_str}: {context.session.using_conversation.messages[0].content}'
)
return return
else: else:
yield entities.CommandReturn(error=errors.CommandOperationError('当前没有对话')) yield entities.CommandReturn(
error=errors.CommandOperationError('当前没有对话')
)
+36 -35
View File
@@ -2,31 +2,32 @@ from __future__ import annotations
import json import json
import typing import typing
import traceback
import ollama import ollama
from .. import operator, entities, errors from .. import operator, entities, errors
@operator.operator_class( @operator.operator_class(
name="ollama", name='ollama',
help="ollama平台操作", help='ollama平台操作',
usage="!ollama\n!ollama show <模型名>\n!ollama pull <模型名>\n!ollama del <模型名>" usage='!ollama\n!ollama show <模型名>\n!ollama pull <模型名>\n!ollama del <模型名>',
) )
class OllamaOperator(operator.CommandOperator): class OllamaOperator(operator.CommandOperator):
async def execute( async def execute(
self, context: entities.ExecuteContext self, context: entities.ExecuteContext
) -> typing.AsyncGenerator[entities.CommandReturn, None]: ) -> typing.AsyncGenerator[entities.CommandReturn, None]:
try: try:
content: str = '模型列表:\n' content: str = '模型列表:\n'
model_list: list = ollama.list().get('models', []) model_list: list = ollama.list().get('models', [])
for model in model_list: for model in model_list:
content += f"名称: {model['name']}\n" content += f'名称: {model["name"]}\n'
content += f"修改时间: {model['modified_at']}\n" content += f'修改时间: {model["modified_at"]}\n'
content += f"大小: {bytes_to_mb(model['size'])}MB\n\n" content += f'大小: {bytes_to_mb(model["size"])}MB\n\n'
yield entities.CommandReturn(text=f"{content.strip()}") yield entities.CommandReturn(text=f'{content.strip()}')
except ollama.ResponseError as e: except ollama.ResponseError:
yield entities.CommandReturn(error=errors.CommandError(f"无法获取模型列表,请确认 Ollama 服务正常")) yield entities.CommandReturn(
error=errors.CommandError('无法获取模型列表,请确认 Ollama 服务正常')
)
def bytes_to_mb(num_bytes): def bytes_to_mb(num_bytes):
@@ -35,14 +36,11 @@ def bytes_to_mb(num_bytes):
@operator.operator_class( @operator.operator_class(
name="show", name='show', help='ollama模型详情', privilege=2, parent_class=OllamaOperator
help="ollama模型详情",
privilege=2,
parent_class=OllamaOperator
) )
class OllamaShowOperator(operator.CommandOperator): class OllamaShowOperator(operator.CommandOperator):
async def execute( async def execute(
self, context: entities.ExecuteContext self, context: entities.ExecuteContext
) -> typing.AsyncGenerator[entities.CommandReturn, None]: ) -> typing.AsyncGenerator[entities.CommandReturn, None]:
content: str = '模型详情:\n' content: str = '模型详情:\n'
try: try:
@@ -53,31 +51,36 @@ class OllamaShowOperator(operator.CommandOperator):
for key in ['license', 'modelfile']: for key in ['license', 'modelfile']:
show[key] = ignore_show show[key] = ignore_show
for key in ['tokenizer.chat_template.rag', 'tokenizer.chat_template.tool_use']: for key in [
'tokenizer.chat_template.rag',
'tokenizer.chat_template.tool_use',
]:
model_info[key] = ignore_show model_info[key] = ignore_show
content += json.dumps(show, indent=4) content += json.dumps(show, indent=4)
yield entities.CommandReturn(text=content.strip()) yield entities.CommandReturn(text=content.strip())
except ollama.ResponseError as e: except ollama.ResponseError:
yield entities.CommandReturn(error=errors.CommandError(f"无法获取模型详情,请确认 Ollama 服务正常")) yield entities.CommandReturn(
error=errors.CommandError('无法获取模型详情,请确认 Ollama 服务正常')
)
@operator.operator_class( @operator.operator_class(
name="pull", name='pull', help='ollama模型拉取', privilege=2, parent_class=OllamaOperator
help="ollama模型拉取",
privilege=2,
parent_class=OllamaOperator
) )
class OllamaPullOperator(operator.CommandOperator): class OllamaPullOperator(operator.CommandOperator):
async def execute( async def execute(
self, context: entities.ExecuteContext self, context: entities.ExecuteContext
) -> typing.AsyncGenerator[entities.CommandReturn, None]: ) -> typing.AsyncGenerator[entities.CommandReturn, None]:
try: try:
model_list: list = ollama.list().get('models', []) model_list: list = ollama.list().get('models', [])
if context.crt_params[0] in [model['name'] for model in model_list]: if context.crt_params[0] in [model['name'] for model in model_list]:
yield entities.CommandReturn(text="模型已存在") yield entities.CommandReturn(text='模型已存在')
return return
except ollama.ResponseError as e: except ollama.ResponseError:
yield entities.CommandReturn(error=errors.CommandError(f"无法获取模型列表,请确认 Ollama 服务正常")) yield entities.CommandReturn(
error=errors.CommandError('无法获取模型列表,请确认 Ollama 服务正常')
)
return return
on_progress: bool = False on_progress: bool = False
@@ -99,23 +102,21 @@ class OllamaPullOperator(operator.CommandOperator):
if percentage_completed > progress_count: if percentage_completed > progress_count:
progress_count += 10 progress_count += 10
yield entities.CommandReturn( yield entities.CommandReturn(
text=f"下载进度: {completed}/{total} ({percentage_completed:.2f}%)") text=f'下载进度: {completed}/{total} ({percentage_completed:.2f}%)'
)
except ollama.ResponseError as e: except ollama.ResponseError as e:
yield entities.CommandReturn(text=f"拉取失败: {e.error}") yield entities.CommandReturn(text=f'拉取失败: {e.error}')
@operator.operator_class( @operator.operator_class(
name="del", name='del', help='ollama模型删除', privilege=2, parent_class=OllamaOperator
help="ollama模型删除",
privilege=2,
parent_class=OllamaOperator
) )
class OllamaDelOperator(operator.CommandOperator): class OllamaDelOperator(operator.CommandOperator):
async def execute( async def execute(
self, context: entities.ExecuteContext self, context: entities.ExecuteContext
) -> typing.AsyncGenerator[entities.CommandReturn, None]: ) -> typing.AsyncGenerator[entities.CommandReturn, None]:
try: try:
ret: str = ollama.delete(model=context.crt_params[0])['status'] ret: str = ollama.delete(model=context.crt_params[0])['status']
except ollama.ResponseError as e: except ollama.ResponseError as e:
ret = f"{e.error}" ret = f'{e.error}'
yield entities.CommandReturn(text=ret) yield entities.CommandReturn(text=ret)
+101 -94
View File
@@ -2,31 +2,30 @@ from __future__ import annotations
import typing import typing
import traceback import traceback
from .. import operator, entities, cmdmgr, errors from .. import operator, entities, errors
from ...core import app
@operator.operator_class( @operator.operator_class(
name="plugin", name='plugin',
help="插件操作", help='插件操作',
usage="!plugin\n!plugin get <插件仓库地址>\n!plugin update\n!plugin del <插件名>\n!plugin on <插件名>\n!plugin off <插件名>" usage='!plugin\n!plugin get <插件仓库地址>\n!plugin update\n!plugin del <插件名>\n!plugin on <插件名>\n!plugin off <插件名>',
) )
class PluginOperator(operator.CommandOperator): class PluginOperator(operator.CommandOperator):
async def execute( async def execute(
self, self, context: entities.ExecuteContext
context: entities.ExecuteContext
) -> typing.AsyncGenerator[entities.CommandReturn, None]: ) -> typing.AsyncGenerator[entities.CommandReturn, None]:
plugin_list = self.ap.plugin_mgr.plugins() plugin_list = self.ap.plugin_mgr.plugins()
reply_str = "所有插件({}):\n".format(len(plugin_list)) reply_str = '所有插件({}):\n'.format(len(plugin_list))
idx = 0 idx = 0
for plugin in plugin_list: for plugin in plugin_list:
reply_str += "\n#{} {} {}\n{}\nv{}\n作者: {}\n"\ reply_str += '\n#{} {} {}\n{}\nv{}\n作者: {}\n'.format(
.format((idx+1), plugin.plugin_name, (idx + 1),
"[已禁用]" if not plugin.enabled else "", plugin.plugin_name,
plugin.plugin_description, '[已禁用]' if not plugin.enabled else '',
plugin.plugin_version, plugin.plugin_author) plugin.plugin_description,
plugin.plugin_version,
plugin.plugin_author,
)
idx += 1 idx += 1
@@ -34,48 +33,42 @@ class PluginOperator(operator.CommandOperator):
@operator.operator_class( @operator.operator_class(
name="get", name='get', help='安装插件', privilege=2, parent_class=PluginOperator
help="安装插件",
privilege=2,
parent_class=PluginOperator
) )
class PluginGetOperator(operator.CommandOperator): class PluginGetOperator(operator.CommandOperator):
async def execute( async def execute(
self, self, context: entities.ExecuteContext
context: entities.ExecuteContext
) -> typing.AsyncGenerator[entities.CommandReturn, None]: ) -> typing.AsyncGenerator[entities.CommandReturn, None]:
if len(context.crt_params) == 0: if len(context.crt_params) == 0:
yield entities.CommandReturn(error=errors.ParamNotEnoughError('请提供插件仓库地址')) yield entities.CommandReturn(
error=errors.ParamNotEnoughError('请提供插件仓库地址')
)
else: else:
repo = context.crt_params[0] repo = context.crt_params[0]
yield entities.CommandReturn(text="正在安装插件...") yield entities.CommandReturn(text='正在安装插件...')
try: try:
await self.ap.plugin_mgr.install_plugin(repo) await self.ap.plugin_mgr.install_plugin(repo)
yield entities.CommandReturn(text="插件安装成功,请重启程序以加载插件") yield entities.CommandReturn(text='插件安装成功,请重启程序以加载插件')
except Exception as e: except Exception as e:
traceback.print_exc() traceback.print_exc()
yield entities.CommandReturn(error=errors.CommandError("插件安装失败: "+str(e))) yield entities.CommandReturn(
error=errors.CommandError('插件安装失败: ' + str(e))
)
@operator.operator_class( @operator.operator_class(
name="update", name='update', help='更新插件', privilege=2, parent_class=PluginOperator
help="更新插件",
privilege=2,
parent_class=PluginOperator
) )
class PluginUpdateOperator(operator.CommandOperator): class PluginUpdateOperator(operator.CommandOperator):
async def execute( async def execute(
self, self, context: entities.ExecuteContext
context: entities.ExecuteContext
) -> typing.AsyncGenerator[entities.CommandReturn, None]: ) -> typing.AsyncGenerator[entities.CommandReturn, None]:
if len(context.crt_params) == 0: if len(context.crt_params) == 0:
yield entities.CommandReturn(error=errors.ParamNotEnoughError('请提供插件名称')) yield entities.CommandReturn(
error=errors.ParamNotEnoughError('请提供插件名称')
)
else: else:
plugin_name = context.crt_params[0] plugin_name = context.crt_params[0]
@@ -83,36 +76,34 @@ class PluginUpdateOperator(operator.CommandOperator):
plugin_container = self.ap.plugin_mgr.get_plugin_by_name(plugin_name) plugin_container = self.ap.plugin_mgr.get_plugin_by_name(plugin_name)
if plugin_container is not None: if plugin_container is not None:
yield entities.CommandReturn(text="正在更新插件...") yield entities.CommandReturn(text='正在更新插件...')
await self.ap.plugin_mgr.update_plugin(plugin_name) await self.ap.plugin_mgr.update_plugin(plugin_name)
yield entities.CommandReturn(text="插件更新成功,请重启程序以加载插件") yield entities.CommandReturn(
text='插件更新成功,请重启程序以加载插件'
)
else: else:
yield entities.CommandReturn(error=errors.CommandError("插件更新失败: 未找到插件")) yield entities.CommandReturn(
error=errors.CommandError('插件更新失败: 未找到插件')
)
except Exception as e: except Exception as e:
traceback.print_exc() traceback.print_exc()
yield entities.CommandReturn(error=errors.CommandError("插件更新失败: "+str(e))) yield entities.CommandReturn(
error=errors.CommandError('插件更新失败: ' + str(e))
)
@operator.operator_class( @operator.operator_class(
name="all", name='all', help='更新所有插件', privilege=2, parent_class=PluginUpdateOperator
help="更新所有插件",
privilege=2,
parent_class=PluginUpdateOperator
) )
class PluginUpdateAllOperator(operator.CommandOperator): class PluginUpdateAllOperator(operator.CommandOperator):
async def execute( async def execute(
self, self, context: entities.ExecuteContext
context: entities.ExecuteContext
) -> typing.AsyncGenerator[entities.CommandReturn, None]: ) -> typing.AsyncGenerator[entities.CommandReturn, None]:
try: try:
plugins = [ plugins = [p.plugin_name for p in self.ap.plugin_mgr.plugins()]
p.plugin_name
for p in self.ap.plugin_mgr.plugins()
]
if plugins: if plugins:
yield entities.CommandReturn(text="正在更新插件...") yield entities.CommandReturn(text='正在更新插件...')
updated = [] updated = []
try: try:
for plugin_name in plugins: for plugin_name in plugins:
@@ -120,30 +111,32 @@ class PluginUpdateAllOperator(operator.CommandOperator):
updated.append(plugin_name) updated.append(plugin_name)
except Exception as e: except Exception as e:
traceback.print_exc() traceback.print_exc()
yield entities.CommandReturn(error=errors.CommandError("插件更新失败: "+str(e))) yield entities.CommandReturn(
yield entities.CommandReturn(text="已更新插件: {}".format(", ".join(updated))) error=errors.CommandError('插件更新失败: ' + str(e))
)
yield entities.CommandReturn(
text='已更新插件: {}'.format(', '.join(updated))
)
else: else:
yield entities.CommandReturn(text="没有可更新的插件") yield entities.CommandReturn(text='没有可更新的插件')
except Exception as e: except Exception as e:
traceback.print_exc() traceback.print_exc()
yield entities.CommandReturn(error=errors.CommandError("插件更新失败: "+str(e))) yield entities.CommandReturn(
error=errors.CommandError('插件更新失败: ' + str(e))
)
@operator.operator_class( @operator.operator_class(
name="del", name='del', help='删除插件', privilege=2, parent_class=PluginOperator
help="删除插件",
privilege=2,
parent_class=PluginOperator
) )
class PluginDelOperator(operator.CommandOperator): class PluginDelOperator(operator.CommandOperator):
async def execute( async def execute(
self, self, context: entities.ExecuteContext
context: entities.ExecuteContext
) -> typing.AsyncGenerator[entities.CommandReturn, None]: ) -> typing.AsyncGenerator[entities.CommandReturn, None]:
if len(context.crt_params) == 0: if len(context.crt_params) == 0:
yield entities.CommandReturn(error=errors.ParamNotEnoughError('请提供插件名称')) yield entities.CommandReturn(
error=errors.ParamNotEnoughError('请提供插件名称')
)
else: else:
plugin_name = context.crt_params[0] plugin_name = context.crt_params[0]
@@ -151,67 +144,81 @@ class PluginDelOperator(operator.CommandOperator):
plugin_container = self.ap.plugin_mgr.get_plugin_by_name(plugin_name) plugin_container = self.ap.plugin_mgr.get_plugin_by_name(plugin_name)
if plugin_container is not None: if plugin_container is not None:
yield entities.CommandReturn(text="正在删除插件...") yield entities.CommandReturn(text='正在删除插件...')
await self.ap.plugin_mgr.uninstall_plugin(plugin_name) await self.ap.plugin_mgr.uninstall_plugin(plugin_name)
yield entities.CommandReturn(text="插件删除成功,请重启程序以加载插件") yield entities.CommandReturn(
text='插件删除成功,请重启程序以加载插件'
)
else: else:
yield entities.CommandReturn(error=errors.CommandError("插件删除失败: 未找到插件")) yield entities.CommandReturn(
error=errors.CommandError('插件删除失败: 未找到插件')
)
except Exception as e: except Exception as e:
traceback.print_exc() traceback.print_exc()
yield entities.CommandReturn(error=errors.CommandError("插件删除失败: "+str(e))) yield entities.CommandReturn(
error=errors.CommandError('插件删除失败: ' + str(e))
)
@operator.operator_class( @operator.operator_class(
name="on", name='on', help='启用插件', privilege=2, parent_class=PluginOperator
help="启用插件",
privilege=2,
parent_class=PluginOperator
) )
class PluginEnableOperator(operator.CommandOperator): class PluginEnableOperator(operator.CommandOperator):
async def execute( async def execute(
self, self, context: entities.ExecuteContext
context: entities.ExecuteContext
) -> typing.AsyncGenerator[entities.CommandReturn, None]: ) -> typing.AsyncGenerator[entities.CommandReturn, None]:
if len(context.crt_params) == 0: if len(context.crt_params) == 0:
yield entities.CommandReturn(error=errors.ParamNotEnoughError('请提供插件名称')) yield entities.CommandReturn(
error=errors.ParamNotEnoughError('请提供插件名称')
)
else: else:
plugin_name = context.crt_params[0] plugin_name = context.crt_params[0]
try: try:
if await self.ap.plugin_mgr.update_plugin_switch(plugin_name, True): if await self.ap.plugin_mgr.update_plugin_switch(plugin_name, True):
yield entities.CommandReturn(text="已启用插件: {}".format(plugin_name)) yield entities.CommandReturn(
text='已启用插件: {}'.format(plugin_name)
)
else: else:
yield entities.CommandReturn(error=errors.CommandError("插件状态修改失败: 未找到插件 {}".format(plugin_name))) yield entities.CommandReturn(
error=errors.CommandError(
'插件状态修改失败: 未找到插件 {}'.format(plugin_name)
)
)
except Exception as e: except Exception as e:
traceback.print_exc() traceback.print_exc()
yield entities.CommandReturn(error=errors.CommandError("插件状态修改失败: "+str(e))) yield entities.CommandReturn(
error=errors.CommandError('插件状态修改失败: ' + str(e))
)
@operator.operator_class( @operator.operator_class(
name="off", name='off', help='禁用插件', privilege=2, parent_class=PluginOperator
help="禁用插件",
privilege=2,
parent_class=PluginOperator
) )
class PluginDisableOperator(operator.CommandOperator): class PluginDisableOperator(operator.CommandOperator):
async def execute( async def execute(
self, self, context: entities.ExecuteContext
context: entities.ExecuteContext
) -> typing.AsyncGenerator[entities.CommandReturn, None]: ) -> typing.AsyncGenerator[entities.CommandReturn, None]:
if len(context.crt_params) == 0: if len(context.crt_params) == 0:
yield entities.CommandReturn(error=errors.ParamNotEnoughError('请提供插件名称')) yield entities.CommandReturn(
error=errors.ParamNotEnoughError('请提供插件名称')
)
else: else:
plugin_name = context.crt_params[0] plugin_name = context.crt_params[0]
try: try:
if await self.ap.plugin_mgr.update_plugin_switch(plugin_name, False): if await self.ap.plugin_mgr.update_plugin_switch(plugin_name, False):
yield entities.CommandReturn(text="已禁用插件: {}".format(plugin_name)) yield entities.CommandReturn(
text='已禁用插件: {}'.format(plugin_name)
)
else: else:
yield entities.CommandReturn(error=errors.CommandError("插件状态修改失败: 未找到插件 {}".format(plugin_name))) yield entities.CommandReturn(
error=errors.CommandError(
'插件状态修改失败: 未找到插件 {}'.format(plugin_name)
)
)
except Exception as e: except Exception as e:
traceback.print_exc() traceback.print_exc()
yield entities.CommandReturn(error=errors.CommandError("插件状态修改失败: "+str(e))) yield entities.CommandReturn(
error=errors.CommandError('插件状态修改失败: ' + str(e))
)
+8 -13
View File
@@ -2,28 +2,23 @@ from __future__ import annotations
import typing import typing
from .. import operator, entities, cmdmgr, errors from .. import operator, entities, errors
@operator.operator_class( @operator.operator_class(name='prompt', help='查看当前对话的前文', usage='!prompt')
name="prompt",
help="查看当前对话的前文",
usage='!prompt'
)
class PromptOperator(operator.CommandOperator): class PromptOperator(operator.CommandOperator):
async def execute( async def execute(
self, self, context: entities.ExecuteContext
context: entities.ExecuteContext
) -> typing.AsyncGenerator[entities.CommandReturn, None]: ) -> typing.AsyncGenerator[entities.CommandReturn, None]:
"""执行 """执行"""
"""
if context.session.using_conversation is None: if context.session.using_conversation is None:
yield entities.CommandReturn(error=errors.CommandOperationError('当前没有对话')) yield entities.CommandReturn(
error=errors.CommandOperationError('当前没有对话')
)
else: else:
reply_str = '当前对话所有内容:\n\n' reply_str = '当前对话所有内容:\n\n'
for msg in context.session.using_conversation.messages: for msg in context.session.using_conversation.messages:
reply_str += f"{msg.role}: {msg.content}\n" reply_str += f'{msg.role}: {msg.content}\n'
yield entities.CommandReturn(text=reply_str) yield entities.CommandReturn(text=reply_str)
+5 -9
View File
@@ -2,23 +2,19 @@ from __future__ import annotations
import typing import typing
from .. import operator, entities, cmdmgr, errors from .. import operator, entities, errors
@operator.operator_class( @operator.operator_class(
name="resend", name='resend', help='重发当前会话的最后一条消息', usage='!resend'
help="重发当前会话的最后一条消息",
usage='!resend'
) )
class ResendOperator(operator.CommandOperator): class ResendOperator(operator.CommandOperator):
async def execute( async def execute(
self, self, context: entities.ExecuteContext
context: entities.ExecuteContext
) -> typing.AsyncGenerator[entities.CommandReturn, None]: ) -> typing.AsyncGenerator[entities.CommandReturn, None]:
# 回滚到最后一条用户message前 # 回滚到最后一条用户message前
if context.session.using_conversation is None: if context.session.using_conversation is None:
yield entities.CommandReturn(error=errors.CommandError("当前没有对话")) yield entities.CommandReturn(error=errors.CommandError('当前没有对话'))
else: else:
conv_msg = context.session.using_conversation.messages conv_msg = context.session.using_conversation.messages
@@ -31,4 +27,4 @@ class ResendOperator(operator.CommandOperator):
conv_msg.pop() conv_msg.pop()
# 不重发了,提示用户已删除就行了 # 不重发了,提示用户已删除就行了
yield entities.CommandReturn(text="已删除最后一次请求记录") yield entities.CommandReturn(text='已删除最后一次请求记录')
+5 -12
View File
@@ -2,22 +2,15 @@ from __future__ import annotations
import typing import typing
from .. import operator, entities, cmdmgr, errors from .. import operator, entities
@operator.operator_class( @operator.operator_class(name='reset', help='重置当前会话', usage='!reset')
name="reset",
help="重置当前会话",
usage='!reset'
)
class ResetOperator(operator.CommandOperator): class ResetOperator(operator.CommandOperator):
async def execute( async def execute(
self, self, context: entities.ExecuteContext
context: entities.ExecuteContext
) -> typing.AsyncGenerator[entities.CommandReturn, None]: ) -> typing.AsyncGenerator[entities.CommandReturn, None]:
"""执行 """执行"""
"""
context.session.using_conversation = None context.session.using_conversation = None
yield entities.CommandReturn(text="已重置当前会话") yield entities.CommandReturn(text='已重置当前会话')
+9 -15
View File
@@ -3,28 +3,22 @@ from __future__ import annotations
import typing import typing
import traceback import traceback
from .. import operator, entities, cmdmgr, errors from .. import operator, entities, errors
@operator.operator_class( @operator.operator_class(name='update', help='更新程序', usage='!update', privilege=2)
name="update",
help="更新程序",
usage='!update',
privilege=2
)
class UpdateCommand(operator.CommandOperator): class UpdateCommand(operator.CommandOperator):
async def execute( async def execute(
self, self, context: entities.ExecuteContext
context: entities.ExecuteContext
) -> typing.AsyncGenerator[entities.CommandReturn, None]: ) -> typing.AsyncGenerator[entities.CommandReturn, None]:
try: try:
yield entities.CommandReturn(text="正在进行更新...") yield entities.CommandReturn(text='正在进行更新...')
if await self.ap.ver_mgr.update_all(): if await self.ap.ver_mgr.update_all():
yield entities.CommandReturn(text="更新完成,请重启程序以应用更新") yield entities.CommandReturn(text='更新完成,请重启程序以应用更新')
else: else:
yield entities.CommandReturn(text="当前已是最新版本") yield entities.CommandReturn(text='当前已是最新版本')
except Exception as e: except Exception as e:
traceback.print_exc() traceback.print_exc()
yield entities.CommandReturn(error=errors.CommandError("更新失败: "+str(e))) yield entities.CommandReturn(
error=errors.CommandError('更新失败: ' + str(e))
)
+6 -12
View File
@@ -2,26 +2,20 @@ from __future__ import annotations
import typing import typing
from .. import operator, cmdmgr, entities, errors from .. import operator, entities
@operator.operator_class( @operator.operator_class(name='version', help='显示版本信息', usage='!version')
name="version",
help="显示版本信息",
usage='!version'
)
class VersionCommand(operator.CommandOperator): class VersionCommand(operator.CommandOperator):
async def execute( async def execute(
self, self, context: entities.ExecuteContext
context: entities.ExecuteContext
) -> typing.AsyncGenerator[entities.CommandReturn, None]: ) -> typing.AsyncGenerator[entities.CommandReturn, None]:
reply_str = f"当前版本: \n{self.ap.ver_mgr.get_current_version()}" reply_str = f'当前版本: \n{self.ap.ver_mgr.get_current_version()}'
try: try:
if await self.ap.ver_mgr.is_new_version_available(): if await self.ap.ver_mgr.is_new_version_available():
reply_str += "\n\n有新版本可用。" reply_str += '\n\n有新版本可用。'
except: except Exception:
pass pass
yield entities.CommandReturn(text=reply_str.strip()) yield entities.CommandReturn(text=reply_str.strip())
+12 -11
View File
@@ -9,7 +9,10 @@ class JSONConfigFile(file_model.ConfigFile):
"""JSON配置文件""" """JSON配置文件"""
def __init__( def __init__(
self, config_file_name: str, template_file_name: str = None, template_data: dict = None self,
config_file_name: str,
template_file_name: str = None,
template_data: dict = None,
) -> None: ) -> None:
self.config_file_name = config_file_name self.config_file_name = config_file_name
self.template_file_name = template_file_name self.template_file_name = template_file_name
@@ -22,28 +25,26 @@ class JSONConfigFile(file_model.ConfigFile):
if self.template_file_name is not None: if self.template_file_name is not None:
shutil.copyfile(self.template_file_name, self.config_file_name) shutil.copyfile(self.template_file_name, self.config_file_name)
elif self.template_data is not None: elif self.template_data is not None:
with open(self.config_file_name, "w", encoding="utf-8") as f: with open(self.config_file_name, 'w', encoding='utf-8') as f:
json.dump(self.template_data, f, indent=4, ensure_ascii=False) json.dump(self.template_data, f, indent=4, ensure_ascii=False)
else: else:
raise ValueError("template_file_name or template_data must be provided") raise ValueError('template_file_name or template_data must be provided')
async def load(self, completion: bool=True) -> dict:
async def load(self, completion: bool = True) -> dict:
if not self.exists(): if not self.exists():
await self.create() await self.create()
if self.template_file_name is not None: if self.template_file_name is not None:
with open(self.template_file_name, "r", encoding="utf-8") as f: with open(self.template_file_name, 'r', encoding='utf-8') as f:
self.template_data = json.load(f) self.template_data = json.load(f)
with open(self.config_file_name, "r", encoding="utf-8") as f: with open(self.config_file_name, 'r', encoding='utf-8') as f:
try: try:
cfg = json.load(f) cfg = json.load(f)
except json.JSONDecodeError as e: except json.JSONDecodeError as e:
raise Exception(f"配置文件 {self.config_file_name} 语法错误: {e}") raise Exception(f'配置文件 {self.config_file_name} 语法错误: {e}')
if completion: if completion:
for key in self.template_data: for key in self.template_data:
if key not in cfg: if key not in cfg:
cfg[key] = self.template_data[key] cfg[key] = self.template_data[key]
@@ -51,9 +52,9 @@ class JSONConfigFile(file_model.ConfigFile):
return cfg return cfg
async def save(self, cfg: dict): async def save(self, cfg: dict):
with open(self.config_file_name, "w", encoding="utf-8") as f: with open(self.config_file_name, 'w', encoding='utf-8') as f:
json.dump(cfg, f, indent=4, ensure_ascii=False) json.dump(cfg, f, indent=4, ensure_ascii=False)
def save_sync(self, cfg: dict): def save_sync(self, cfg: dict):
with open(self.config_file_name, "w", encoding="utf-8") as f: with open(self.config_file_name, 'w', encoding='utf-8') as f:
json.dump(cfg, f, indent=4, ensure_ascii=False) json.dump(cfg, f, indent=4, ensure_ascii=False)
+1 -1
View File
@@ -25,7 +25,7 @@ class PythonModuleConfigFile(file_model.ConfigFile):
async def create(self): async def create(self):
shutil.copyfile(self.template_file_name, self.config_file_name) shutil.copyfile(self.template_file_name, self.config_file_name)
async def load(self, completion: bool=True) -> dict: async def load(self, completion: bool = True) -> dict:
module_name = os.path.splitext(os.path.basename(self.config_file_name))[0] module_name = os.path.splitext(os.path.basename(self.config_file_name))[0]
module = importlib.import_module(module_name) module = importlib.import_module(module_name)
+12 -11
View File
@@ -9,7 +9,10 @@ class YAMLConfigFile(file_model.ConfigFile):
"""YAML配置文件""" """YAML配置文件"""
def __init__( def __init__(
self, config_file_name: str, template_file_name: str = None, template_data: dict = None self,
config_file_name: str,
template_file_name: str = None,
template_data: dict = None,
) -> None: ) -> None:
self.config_file_name = config_file_name self.config_file_name = config_file_name
self.template_file_name = template_file_name self.template_file_name = template_file_name
@@ -22,28 +25,26 @@ class YAMLConfigFile(file_model.ConfigFile):
if self.template_file_name is not None: if self.template_file_name is not None:
shutil.copyfile(self.template_file_name, self.config_file_name) shutil.copyfile(self.template_file_name, self.config_file_name)
elif self.template_data is not None: elif self.template_data is not None:
with open(self.config_file_name, "w", encoding="utf-8") as f: with open(self.config_file_name, 'w', encoding='utf-8') as f:
yaml.dump(self.template_data, f, indent=4, allow_unicode=True) yaml.dump(self.template_data, f, indent=4, allow_unicode=True)
else: else:
raise ValueError("template_file_name or template_data must be provided") raise ValueError('template_file_name or template_data must be provided')
async def load(self, completion: bool=True) -> dict:
async def load(self, completion: bool = True) -> dict:
if not self.exists(): if not self.exists():
await self.create() await self.create()
if self.template_file_name is not None: if self.template_file_name is not None:
with open(self.template_file_name, "r", encoding="utf-8") as f: with open(self.template_file_name, 'r', encoding='utf-8') as f:
self.template_data = yaml.load(f, Loader=yaml.FullLoader) self.template_data = yaml.load(f, Loader=yaml.FullLoader)
with open(self.config_file_name, "r", encoding="utf-8") as f: with open(self.config_file_name, 'r', encoding='utf-8') as f:
try: try:
cfg = yaml.load(f, Loader=yaml.FullLoader) cfg = yaml.load(f, Loader=yaml.FullLoader)
except yaml.YAMLError as e: except yaml.YAMLError as e:
raise Exception(f"配置文件 {self.config_file_name} 语法错误: {e}") raise Exception(f'配置文件 {self.config_file_name} 语法错误: {e}')
if completion: if completion:
for key in self.template_data: for key in self.template_data:
if key not in cfg: if key not in cfg:
cfg[key] = self.template_data[key] cfg[key] = self.template_data[key]
@@ -51,9 +52,9 @@ class YAMLConfigFile(file_model.ConfigFile):
return cfg return cfg
async def save(self, cfg: dict): async def save(self, cfg: dict):
with open(self.config_file_name, "w", encoding="utf-8") as f: with open(self.config_file_name, 'w', encoding='utf-8') as f:
yaml.dump(cfg, f, indent=4, allow_unicode=True) yaml.dump(cfg, f, indent=4, allow_unicode=True)
def save_sync(self, cfg: dict): def save_sync(self, cfg: dict):
with open(self.config_file_name, "w", encoding="utf-8") as f: with open(self.config_file_name, 'w', encoding='utf-8') as f:
yaml.dump(cfg, f, indent=4, allow_unicode=True) yaml.dump(cfg, f, indent=4, allow_unicode=True)
+19 -18
View File
@@ -31,7 +31,7 @@ class ConfigManager:
self.file = cfg_file self.file = cfg_file
self.data = {} self.data = {}
async def load_config(self, completion: bool=True): async def load_config(self, completion: bool = True):
self.data = await self.file.load(completion=completion) self.data = await self.file.load(completion=completion)
async def dump_config(self): async def dump_config(self):
@@ -41,7 +41,9 @@ class ConfigManager:
self.file.save_sync(self.data) self.file.save_sync(self.data)
async def load_python_module_config(config_name: str, template_name: str, completion: bool=True) -> ConfigManager: async def load_python_module_config(
config_name: str, template_name: str, completion: bool = True
) -> ConfigManager:
"""加载Python模块配置文件 """加载Python模块配置文件
Args: Args:
@@ -52,10 +54,7 @@ async def load_python_module_config(config_name: str, template_name: str, comple
Returns: Returns:
ConfigManager: 配置文件管理器 ConfigManager: 配置文件管理器
""" """
cfg_inst = pymodule.PythonModuleConfigFile( cfg_inst = pymodule.PythonModuleConfigFile(config_name, template_name)
config_name,
template_name
)
cfg_mgr = ConfigManager(cfg_inst) cfg_mgr = ConfigManager(cfg_inst)
await cfg_mgr.load_config(completion=completion) await cfg_mgr.load_config(completion=completion)
@@ -63,7 +62,12 @@ async def load_python_module_config(config_name: str, template_name: str, comple
return cfg_mgr return cfg_mgr
async def load_json_config(config_name: str, template_name: str=None, template_data: dict=None, completion: bool=True) -> ConfigManager: async def load_json_config(
config_name: str,
template_name: str = None,
template_data: dict = None,
completion: bool = True,
) -> ConfigManager:
"""加载JSON配置文件 """加载JSON配置文件
Args: Args:
@@ -72,11 +76,7 @@ async def load_json_config(config_name: str, template_name: str=None, template_d
template_data (dict): 模板数据 template_data (dict): 模板数据
completion (bool): 是否自动补全内存中的配置文件 completion (bool): 是否自动补全内存中的配置文件
""" """
cfg_inst = json_file.JSONConfigFile( cfg_inst = json_file.JSONConfigFile(config_name, template_name, template_data)
config_name,
template_name,
template_data
)
cfg_mgr = ConfigManager(cfg_inst) cfg_mgr = ConfigManager(cfg_inst)
await cfg_mgr.load_config(completion=completion) await cfg_mgr.load_config(completion=completion)
@@ -84,7 +84,12 @@ async def load_json_config(config_name: str, template_name: str=None, template_d
return cfg_mgr return cfg_mgr
async def load_yaml_config(config_name: str, template_name: str=None, template_data: dict=None, completion: bool=True) -> ConfigManager: async def load_yaml_config(
config_name: str,
template_name: str = None,
template_data: dict = None,
completion: bool = True,
) -> ConfigManager:
"""加载YAML配置文件 """加载YAML配置文件
Args: Args:
@@ -96,11 +101,7 @@ async def load_yaml_config(config_name: str, template_name: str=None, template_d
Returns: Returns:
ConfigManager: 配置文件管理器 ConfigManager: 配置文件管理器
""" """
cfg_inst = yaml_file.YAMLConfigFile( cfg_inst = yaml_file.YAMLConfigFile(config_name, template_name, template_data)
config_name,
template_name,
template_data
)
cfg_mgr = ConfigManager(cfg_inst) cfg_mgr = ConfigManager(cfg_inst)
await cfg_mgr.load_config(completion=completion) await cfg_mgr.load_config(completion=completion)
+1 -1
View File
@@ -22,7 +22,7 @@ class ConfigFile(metaclass=abc.ABCMeta):
pass pass
@abc.abstractmethod @abc.abstractmethod
async def load(self, completion: bool=True) -> dict: async def load(self, completion: bool = True) -> dict:
pass pass
@abc.abstractmethod @abc.abstractmethod
+44 -18
View File
@@ -2,9 +2,7 @@ from __future__ import annotations
import logging import logging
import asyncio import asyncio
import threading
import traceback import traceback
import enum
import sys import sys
import os import os
@@ -29,7 +27,6 @@ from ..discover import engine as discover_engine
from ..utils import logcache, ip from ..utils import logcache, ip
from . import taskmgr from . import taskmgr
from . import entities as core_entities from . import entities as core_entities
from .bootutils import config
class Application: class Application:
@@ -123,33 +120,55 @@ class Application:
async def run(self): async def run(self):
try: try:
await self.plugin_mgr.initialize_plugins() await self.plugin_mgr.initialize_plugins()
# 后续可能会允许动态重启其他任务 # 后续可能会允许动态重启其他任务
# 故为了防止程序在非 Ctrl-C 情况下退出,这里创建一个不会结束的协程 # 故为了防止程序在非 Ctrl-C 情况下退出,这里创建一个不会结束的协程
async def never_ending(): async def never_ending():
while True: while True:
await asyncio.sleep(1) await asyncio.sleep(1)
self.task_mgr.create_task(self.platform_mgr.run(), name="platform-manager", scopes=[core_entities.LifecycleControlScope.APPLICATION, core_entities.LifecycleControlScope.PLATFORM]) self.task_mgr.create_task(
self.task_mgr.create_task(self.ctrl.run(), name="query-controller", scopes=[core_entities.LifecycleControlScope.APPLICATION]) self.platform_mgr.run(),
self.task_mgr.create_task(self.http_ctrl.run(), name="http-api-controller", scopes=[core_entities.LifecycleControlScope.APPLICATION]) name='platform-manager',
self.task_mgr.create_task(never_ending(), name="never-ending-task", scopes=[core_entities.LifecycleControlScope.APPLICATION]) scopes=[
core_entities.LifecycleControlScope.APPLICATION,
core_entities.LifecycleControlScope.PLATFORM,
],
)
self.task_mgr.create_task(
self.ctrl.run(),
name='query-controller',
scopes=[core_entities.LifecycleControlScope.APPLICATION],
)
self.task_mgr.create_task(
self.http_ctrl.run(),
name='http-api-controller',
scopes=[core_entities.LifecycleControlScope.APPLICATION],
)
self.task_mgr.create_task(
never_ending(),
name='never-ending-task',
scopes=[core_entities.LifecycleControlScope.APPLICATION],
)
await self.print_web_access_info() await self.print_web_access_info()
await self.task_mgr.wait_all() await self.task_mgr.wait_all()
except asyncio.CancelledError: except asyncio.CancelledError:
pass pass
except Exception as e: except Exception as e:
self.logger.error(f"应用运行致命异常: {e}") self.logger.error(f'应用运行致命异常: {e}')
self.logger.debug(f"Traceback: {traceback.format_exc()}") self.logger.debug(f'Traceback: {traceback.format_exc()}')
async def print_web_access_info(self): async def print_web_access_info(self):
"""打印访问 webui 的提示""" """打印访问 webui 的提示"""
if not os.path.exists(os.path.join(".", "web/out")): if not os.path.exists(os.path.join('.', 'web/out')):
self.logger.warning("WebUI 文件缺失,请根据文档获取:https://docs.langbot.app/webui/intro.html") self.logger.warning(
'WebUI 文件缺失,请根据文档获取:https://docs.langbot.app/webui/intro.html'
)
return return
host_ip = "127.0.0.1" host_ip = '127.0.0.1'
public_ip = await ip.get_myip() public_ip = await ip.get_myip()
@@ -170,7 +189,7 @@ class Application:
🤯 WebUI 仍处于 Beta 测试阶段如有问题或建议请反馈到 https://github.com/RockChinQ/LangBot/issues 🤯 WebUI 仍处于 Beta 测试阶段如有问题或建议请反馈到 https://github.com/RockChinQ/LangBot/issues
======================================= =======================================
""".strip() """.strip()
for line in tips.split("\n"): for line in tips.split('\n'):
self.logger.info(line) self.logger.info(line)
async def reload( async def reload(
@@ -179,21 +198,28 @@ class Application:
): ):
match scope: match scope:
case core_entities.LifecycleControlScope.PLATFORM.value: case core_entities.LifecycleControlScope.PLATFORM.value:
self.logger.info("执行热重载 scope="+scope) self.logger.info('执行热重载 scope=' + scope)
await self.platform_mgr.shutdown() await self.platform_mgr.shutdown()
self.platform_mgr = im_mgr.PlatformManager(self) self.platform_mgr = im_mgr.PlatformManager(self)
await self.platform_mgr.initialize() await self.platform_mgr.initialize()
self.task_mgr.create_task(self.platform_mgr.run(), name="platform-manager", scopes=[core_entities.LifecycleControlScope.APPLICATION, core_entities.LifecycleControlScope.PLATFORM]) self.task_mgr.create_task(
self.platform_mgr.run(),
name='platform-manager',
scopes=[
core_entities.LifecycleControlScope.APPLICATION,
core_entities.LifecycleControlScope.PLATFORM,
],
)
case core_entities.LifecycleControlScope.PLUGIN.value: case core_entities.LifecycleControlScope.PLUGIN.value:
self.logger.info("执行热重载 scope="+scope) self.logger.info('执行热重载 scope=' + scope)
await self.plugin_mgr.destroy_plugins() await self.plugin_mgr.destroy_plugins()
# 删除 sys.module 中所有的 plugins/* 下的模块 # 删除 sys.module 中所有的 plugins/* 下的模块
for mod in list(sys.modules.keys()): for mod in list(sys.modules.keys()):
if mod.startswith("plugins."): if mod.startswith('plugins.'):
del sys.modules[mod] del sys.modules[mod]
self.plugin_mgr = plugin_mgr.PluginManager(self) self.plugin_mgr = plugin_mgr.PluginManager(self)
@@ -204,7 +230,7 @@ class Application:
await self.plugin_mgr.load_plugins() await self.plugin_mgr.load_plugins()
await self.plugin_mgr.initialize_plugins() await self.plugin_mgr.initialize_plugins()
case core_entities.LifecycleControlScope.PROVIDER.value: case core_entities.LifecycleControlScope.PROVIDER.value:
self.logger.info("执行热重载 scope="+scope) self.logger.info('执行热重载 scope=' + scope)
await self.tool_mgr.shutdown() await self.tool_mgr.shutdown()
+13 -16
View File
@@ -7,29 +7,30 @@ import os
from . import app from . import app
from ..audit import identifier from ..audit import identifier
from . import stage from . import stage
from ..utils import constants from ..utils import constants, importutil
# 引入启动阶段实现以便注册 # 引入启动阶段实现以便注册
from .stages import load_config, setup_logger, build_app, migrate, show_notes, genkeys from . import stages
importutil.import_modules_in_pkg(stages)
stage_order = [ stage_order = [
"LoadConfigStage", 'LoadConfigStage',
"MigrationStage", 'MigrationStage',
"GenKeysStage", 'GenKeysStage',
"SetupLoggerStage", 'SetupLoggerStage',
"BuildAppStage", 'BuildAppStage',
"ShowNotesStage" 'ShowNotesStage',
] ]
async def make_app(loop: asyncio.AbstractEventLoop) -> app.Application: async def make_app(loop: asyncio.AbstractEventLoop) -> app.Application:
# 生成标识符 # 生成标识符
identifier.init() identifier.init()
# 确定是否为调试模式 # 确定是否为调试模式
if "DEBUG" in os.environ and os.environ["DEBUG"] in ["true", "1"]: if 'DEBUG' in os.environ and os.environ['DEBUG'] in ['true', '1']:
constants.debug_mode = True constants.debug_mode = True
ap = app.Application() ap = app.Application()
@@ -50,21 +51,17 @@ async def make_app(loop: asyncio.AbstractEventLoop) -> app.Application:
async def main(loop: asyncio.AbstractEventLoop): async def main(loop: asyncio.AbstractEventLoop):
try: try:
# 挂系统信号处理 # 挂系统信号处理
import signal import signal
ap: app.Application
def signal_handler(sig, frame): def signal_handler(sig, frame):
print("[Signal] 程序退出.") print('[Signal] 程序退出.')
# ap.shutdown() # ap.shutdown()
os._exit(0) os._exit(0)
signal.signal(signal.SIGINT, signal_handler) signal.signal(signal.SIGINT, signal_handler)
app_inst = await make_app(loop) app_inst = await make_app(loop)
ap = app_inst
await app_inst.run() await app_inst.run()
except Exception as e: except Exception:
traceback.print_exc() traceback.print_exc()
-2
View File
@@ -1,9 +1,7 @@
from __future__ import annotations from __future__ import annotations
import json
from ...config import manager as config_mgr from ...config import manager as config_mgr
from ...config.impls import pymodule
load_python_module_config = config_mgr.load_python_module_config load_python_module_config = config_mgr.load_python_module_config
+43 -38
View File
@@ -5,39 +5,39 @@ from ...utils import pkgmgr
# 检查依赖,防止用户未安装 # 检查依赖,防止用户未安装
# 左边为引入名称,右边为依赖名称 # 左边为引入名称,右边为依赖名称
required_deps = { required_deps = {
"requests": "requests", 'requests': 'requests',
"openai": "openai", 'openai': 'openai',
"anthropic": "anthropic", 'anthropic': 'anthropic',
"colorlog": "colorlog", 'colorlog': 'colorlog',
"aiocqhttp": "aiocqhttp", 'aiocqhttp': 'aiocqhttp',
"botpy": "qq-botpy-rc", 'botpy': 'qq-botpy-rc',
"PIL": "pillow", 'PIL': 'pillow',
"nakuru": "nakuru-project-idk", 'nakuru': 'nakuru-project-idk',
"tiktoken": "tiktoken", 'tiktoken': 'tiktoken',
"yaml": "pyyaml", 'yaml': 'pyyaml',
"aiohttp": "aiohttp", 'aiohttp': 'aiohttp',
"psutil": "psutil", 'psutil': 'psutil',
"async_lru": "async-lru", 'async_lru': 'async-lru',
"ollama": "ollama", 'ollama': 'ollama',
"quart": "quart", 'quart': 'quart',
"quart_cors": "quart-cors", 'quart_cors': 'quart-cors',
"sqlalchemy": "sqlalchemy[asyncio]", 'sqlalchemy': 'sqlalchemy[asyncio]',
"aiosqlite": "aiosqlite", 'aiosqlite': 'aiosqlite',
"aiofiles": "aiofiles", 'aiofiles': 'aiofiles',
"aioshutil": "aioshutil", 'aioshutil': 'aioshutil',
"argon2": "argon2-cffi", 'argon2': 'argon2-cffi',
"jwt": "pyjwt", 'jwt': 'pyjwt',
"Crypto": "pycryptodome", 'Crypto': 'pycryptodome',
"lark_oapi": "lark-oapi", 'lark_oapi': 'lark-oapi',
"discord": "discord.py", 'discord': 'discord.py',
"cryptography": "cryptography", 'cryptography': 'cryptography',
"gewechat_client": "gewechat-client", 'gewechat_client': 'gewechat-client',
"dingtalk_stream": "dingtalk_stream", 'dingtalk_stream': 'dingtalk_stream',
"dashscope": "dashscope", 'dashscope': 'dashscope',
"telegram": "python-telegram-bot", 'telegram': 'python-telegram-bot',
"certifi": "certifi", 'certifi': 'certifi',
"mcp": "mcp", 'mcp': 'mcp',
"sqlmodel": "sqlmodel", 'sqlmodel': 'sqlmodel',
} }
@@ -52,20 +52,25 @@ async def check_deps() -> list[str]:
missing_deps.append(dep) missing_deps.append(dep)
return missing_deps return missing_deps
async def install_deps(deps: list[str]): async def install_deps(deps: list[str]):
global required_deps global required_deps
for dep in deps: for dep in deps:
pip.main(["install", required_deps[dep]]) pip.main(['install', required_deps[dep]])
async def precheck_plugin_deps(): async def precheck_plugin_deps():
print('[Startup] Prechecking plugin dependencies...') print('[Startup] Prechecking plugin dependencies...')
# 只有在plugins目录存在时才执行插件依赖安装 # 只有在plugins目录存在时才执行插件依赖安装
if os.path.exists("plugins"): if os.path.exists('plugins'):
for dir in os.listdir("plugins"): for dir in os.listdir('plugins'):
subdir = os.path.join("plugins", dir) subdir = os.path.join('plugins', dir)
if not os.path.isdir(subdir): if not os.path.isdir(subdir):
continue continue
if 'requirements.txt' in os.listdir(subdir): if 'requirements.txt' in os.listdir(subdir):
pkgmgr.install_requirements(os.path.join(subdir, 'requirements.txt'), extra_params=['-q', '-q', '-q']) pkgmgr.install_requirements(
os.path.join(subdir, 'requirements.txt'),
extra_params=['-q', '-q', '-q'],
)
+9 -9
View File
@@ -2,23 +2,23 @@ from __future__ import annotations
import os import os
import shutil import shutil
import sys
required_files = { required_files = {
"plugins/__init__.py": "templates/__init__.py", 'plugins/__init__.py': 'templates/__init__.py',
"data/config.yaml": "templates/config.yaml", 'data/config.yaml': 'templates/config.yaml',
} }
required_paths = [ required_paths = [
"temp", 'temp',
"data", 'data',
"data/metadata", 'data/metadata',
"data/logs", 'data/logs',
"data/labels", 'data/labels',
"plugins" 'plugins',
] ]
async def generate_files() -> list[str]: async def generate_files() -> list[str]:
global required_files, required_paths global required_files, required_paths
+20 -16
View File
@@ -1,5 +1,4 @@
import logging import logging
import os
import sys import sys
import time import time
@@ -9,11 +8,11 @@ from ...utils import constants
log_colors_config = { log_colors_config = {
"DEBUG": "green", # cyan white 'DEBUG': 'green', # cyan white
"INFO": "white", 'INFO': 'white',
"WARNING": "yellow", 'WARNING': 'yellow',
"ERROR": "red", 'ERROR': 'red',
"CRITICAL": "cyan", 'CRITICAL': 'cyan',
} }
@@ -27,26 +26,31 @@ async def init_logging(extra_handlers: list[logging.Handler] = None) -> logging.
if constants.debug_mode: if constants.debug_mode:
level = logging.DEBUG level = logging.DEBUG
log_file_name = "data/logs/langbot-%s.log" % time.strftime( log_file_name = 'data/logs/langbot-%s.log' % time.strftime(
"%Y-%m-%d", time.localtime() '%Y-%m-%d', time.localtime()
) )
qcg_logger = logging.getLogger("langbot") qcg_logger = logging.getLogger('langbot')
qcg_logger.setLevel(level) qcg_logger.setLevel(level)
color_formatter = colorlog.ColoredFormatter( color_formatter = colorlog.ColoredFormatter(
fmt="%(log_color)s[%(asctime)s.%(msecs)03d] %(filename)s (%(lineno)d) - [%(levelname)s] : %(message)s", fmt='%(log_color)s[%(asctime)s.%(msecs)03d] %(filename)s (%(lineno)d) - [%(levelname)s] : %(message)s',
datefmt="%m-%d %H:%M:%S", datefmt='%m-%d %H:%M:%S',
log_colors=log_colors_config, log_colors=log_colors_config,
) )
stream_handler = logging.StreamHandler(sys.stdout) stream_handler = logging.StreamHandler(sys.stdout)
# stream_handler.setLevel(level) # stream_handler.setLevel(level)
# stream_handler.setFormatter(color_formatter) # stream_handler.setFormatter(color_formatter)
stream_handler.stream = open(sys.stdout.fileno(), mode='w', encoding='utf-8', buffering=1) stream_handler.stream = open(
sys.stdout.fileno(), mode='w', encoding='utf-8', buffering=1
)
log_handlers: list[logging.Handler] = [stream_handler, logging.FileHandler(log_file_name, encoding='utf-8')] log_handlers: list[logging.Handler] = [
stream_handler,
logging.FileHandler(log_file_name, encoding='utf-8'),
]
log_handlers += extra_handlers if extra_handlers is not None else [] log_handlers += extra_handlers if extra_handlers is not None else []
for handler in log_handlers: for handler in log_handlers:
@@ -54,13 +58,13 @@ async def init_logging(extra_handlers: list[logging.Handler] = None) -> logging.
handler.setFormatter(color_formatter) handler.setFormatter(color_formatter)
qcg_logger.addHandler(handler) qcg_logger.addHandler(handler)
qcg_logger.debug("日志初始化完成,日志级别:%s" % level) qcg_logger.debug('日志初始化完成,日志级别:%s' % level)
logging.basicConfig( logging.basicConfig(
level=logging.CRITICAL, # 设置日志输出格式 level=logging.CRITICAL, # 设置日志输出格式
format="[DEPR][%(asctime)s.%(msecs)03d] %(pathname)s (%(lineno)d) - [%(levelname)s] :\n%(message)s", format='[DEPR][%(asctime)s.%(msecs)03d] %(pathname)s (%(lineno)d) - [%(levelname)s] :\n%(message)s',
# 日志输出的格式 # 日志输出的格式
# -8表示占位符,让输出左对齐,输出长度都为8位 # -8表示占位符,让输出左对齐,输出长度都为8位
datefmt="%Y-%m-%d %H:%M:%S", # 时间输出的格式 datefmt='%Y-%m-%d %H:%M:%S', # 时间输出的格式
handlers=[logging.NullHandler()], handlers=[logging.NullHandler()],
) )
+26 -15
View File
@@ -8,21 +8,18 @@ import asyncio
import pydantic.v1 as pydantic import pydantic.v1 as pydantic
from ..provider import entities as llm_entities from ..provider import entities as llm_entities
from ..provider.modelmgr import entities, modelmgr, requester from ..provider.modelmgr import requester
from ..provider.tools import entities as tools_entities from ..provider.tools import entities as tools_entities
from ..platform import adapter as msadapter from ..platform import adapter as msadapter
from ..platform.types import message as platform_message from ..platform.types import message as platform_message
from ..platform.types import events as platform_events from ..platform.types import events as platform_events
from ..platform.types import entities as platform_entities
class LifecycleControlScope(enum.Enum): class LifecycleControlScope(enum.Enum):
APPLICATION = 'application'
APPLICATION = "application" PLATFORM = 'platform'
PLATFORM = "platform" PLUGIN = 'plugin'
PLUGIN = "plugin" PROVIDER = 'provider'
PROVIDER = "provider"
class LauncherTypes(enum.Enum): class LauncherTypes(enum.Enum):
@@ -89,14 +86,17 @@ class Query(pydantic.BaseModel):
use_funcs: typing.Optional[list[tools_entities.LLMFunction]] = None use_funcs: typing.Optional[list[tools_entities.LLMFunction]] = None
"""使用的函数,由前置处理器阶段设置""" """使用的函数,由前置处理器阶段设置"""
resp_messages: typing.Optional[list[llm_entities.Message]] | typing.Optional[list[platform_message.MessageChain]] = [] resp_messages: (
typing.Optional[list[llm_entities.Message]]
| typing.Optional[list[platform_message.MessageChain]]
) = []
"""由Process阶段生成的回复消息对象列表""" """由Process阶段生成的回复消息对象列表"""
resp_message_chain: typing.Optional[list[platform_message.MessageChain]] = None resp_message_chain: typing.Optional[list[platform_message.MessageChain]] = None
"""回复消息链,从resp_messages包装而得""" """回复消息链,从resp_messages包装而得"""
# ======= 内部保留 ======= # ======= 内部保留 =======
current_stage: "pkg.pipeline.pipelinemgr.StageInstContainer" = None current_stage = None # pkg.pipeline.pipelinemgr.StageInstContainer
"""当前所处阶段""" """当前所处阶段"""
class Config: class Config:
@@ -130,9 +130,13 @@ class Conversation(pydantic.BaseModel):
messages: list[llm_entities.Message] messages: list[llm_entities.Message]
create_time: typing.Optional[datetime.datetime] = pydantic.Field(default_factory=datetime.datetime.now) create_time: typing.Optional[datetime.datetime] = pydantic.Field(
default_factory=datetime.datetime.now
)
update_time: typing.Optional[datetime.datetime] = pydantic.Field(default_factory=datetime.datetime.now) update_time: typing.Optional[datetime.datetime] = pydantic.Field(
default_factory=datetime.datetime.now
)
use_llm_model: requester.RuntimeLLMModel use_llm_model: requester.RuntimeLLMModel
@@ -147,6 +151,7 @@ class Conversation(pydantic.BaseModel):
class Session(pydantic.BaseModel): class Session(pydantic.BaseModel):
"""会话,一个 Session 对应一个 {launcher_type.value}_{launcher_id}""" """会话,一个 Session 对应一个 {launcher_type.value}_{launcher_id}"""
launcher_type: LauncherTypes launcher_type: LauncherTypes
launcher_id: typing.Union[int, str] launcher_id: typing.Union[int, str]
@@ -157,11 +162,17 @@ class Session(pydantic.BaseModel):
using_conversation: typing.Optional[Conversation] = None using_conversation: typing.Optional[Conversation] = None
conversations: typing.Optional[list[Conversation]] = pydantic.Field(default_factory=list) conversations: typing.Optional[list[Conversation]] = pydantic.Field(
default_factory=list
)
create_time: typing.Optional[datetime.datetime] = pydantic.Field(default_factory=datetime.datetime.now) create_time: typing.Optional[datetime.datetime] = pydantic.Field(
default_factory=datetime.datetime.now
)
update_time: typing.Optional[datetime.datetime] = pydantic.Field(default_factory=datetime.datetime.now) update_time: typing.Optional[datetime.datetime] = pydantic.Field(
default_factory=datetime.datetime.now
)
semaphore: typing.Optional[asyncio.Semaphore] = None semaphore: typing.Optional[asyncio.Semaphore] = None
"""当前会话的信号量,用于限制并发""" """当前会话的信号量,用于限制并发"""
+6 -8
View File
@@ -9,9 +9,10 @@ from . import app
preregistered_migrations: list[typing.Type[Migration]] = [] preregistered_migrations: list[typing.Type[Migration]] = []
"""当前阶段暂不支持扩展""" """当前阶段暂不支持扩展"""
def migration_class(name: str, number: int): def migration_class(name: str, number: int):
"""注册一个迁移 """注册一个迁移"""
"""
def decorator(cls: typing.Type[Migration]) -> typing.Type[Migration]: def decorator(cls: typing.Type[Migration]) -> typing.Type[Migration]:
cls.name = name cls.name = name
cls.number = number cls.number = number
@@ -22,8 +23,7 @@ def migration_class(name: str, number: int):
class Migration(abc.ABC): class Migration(abc.ABC):
"""一个版本的迁移 """一个版本的迁移"""
"""
name: str name: str
@@ -36,12 +36,10 @@ class Migration(abc.ABC):
@abc.abstractmethod @abc.abstractmethod
async def need_migrate(self) -> bool: async def need_migrate(self) -> bool:
"""判断当前环境是否需要运行此迁移 """判断当前环境是否需要运行此迁移"""
"""
pass pass
@abc.abstractmethod @abc.abstractmethod
async def run(self): async def run(self):
"""执行迁移 """执行迁移"""
"""
pass pass
@@ -1,26 +1,26 @@
from __future__ import annotations from __future__ import annotations
import os import os
import sys
from .. import migration from .. import migration
@migration.migration_class("sensitive-word-migration", 1) @migration.migration_class('sensitive-word-migration', 1)
class SensitiveWordMigration(migration.Migration): class SensitiveWordMigration(migration.Migration):
"""敏感词迁移 """敏感词迁移"""
"""
async def need_migrate(self) -> bool: async def need_migrate(self) -> bool:
"""判断当前环境是否需要运行此迁移 """判断当前环境是否需要运行此迁移"""
""" return os.path.exists(
return os.path.exists("data/config/sensitive-words.json") and not os.path.exists("data/metadata/sensitive-words.json") 'data/config/sensitive-words.json'
) and not os.path.exists('data/metadata/sensitive-words.json')
async def run(self): async def run(self):
"""执行迁移 """执行迁移"""
"""
# 移动文件 # 移动文件
os.rename("data/config/sensitive-words.json", "data/metadata/sensitive-words.json") os.rename(
'data/config/sensitive-words.json', 'data/metadata/sensitive-words.json'
)
# 重新加载配置 # 重新加载配置
await self.ap.sensitive_meta.load_config() await self.ap.sensitive_meta.load_config()
@@ -3,19 +3,16 @@ from __future__ import annotations
from .. import migration from .. import migration
@migration.migration_class("openai-config-migration", 2) @migration.migration_class('openai-config-migration', 2)
class OpenAIConfigMigration(migration.Migration): class OpenAIConfigMigration(migration.Migration):
"""OpenAI配置迁移 """OpenAI配置迁移"""
"""
async def need_migrate(self) -> bool: async def need_migrate(self) -> bool:
"""判断当前环境是否需要运行此迁移 """判断当前环境是否需要运行此迁移"""
"""
return 'openai-config' in self.ap.provider_cfg.data return 'openai-config' in self.ap.provider_cfg.data
async def run(self): async def run(self):
"""执行迁移 """执行迁移"""
"""
old_openai_config = self.ap.provider_cfg.data['openai-config'].copy() old_openai_config = self.ap.provider_cfg.data['openai-config'].copy()
if 'keys' not in self.ap.provider_cfg.data: if 'keys' not in self.ap.provider_cfg.data:
@@ -26,7 +23,9 @@ class OpenAIConfigMigration(migration.Migration):
self.ap.provider_cfg.data['keys']['openai'] = old_openai_config['api-keys'] self.ap.provider_cfg.data['keys']['openai'] = old_openai_config['api-keys']
self.ap.provider_cfg.data['model'] = old_openai_config['chat-completions-params']['model'] self.ap.provider_cfg.data['model'] = old_openai_config[
'chat-completions-params'
]['model']
del old_openai_config['chat-completions-params']['model'] del old_openai_config['chat-completions-params']['model']
@@ -3,26 +3,23 @@ from __future__ import annotations
from .. import migration from .. import migration
@migration.migration_class("anthropic-requester-config-completion", 3) @migration.migration_class('anthropic-requester-config-completion', 3)
class AnthropicRequesterConfigCompletionMigration(migration.Migration): class AnthropicRequesterConfigCompletionMigration(migration.Migration):
"""OpenAI配置迁移 """OpenAI配置迁移"""
"""
async def need_migrate(self) -> bool: async def need_migrate(self) -> bool:
"""判断当前环境是否需要运行此迁移 """判断当前环境是否需要运行此迁移"""
""" return (
return 'anthropic-messages' not in self.ap.provider_cfg.data['requester'] \ 'anthropic-messages' not in self.ap.provider_cfg.data['requester']
or 'anthropic' not in self.ap.provider_cfg.data['keys'] or 'anthropic' not in self.ap.provider_cfg.data['keys']
)
async def run(self): async def run(self):
"""执行迁移 """执行迁移"""
"""
if 'anthropic-messages' not in self.ap.provider_cfg.data['requester']: if 'anthropic-messages' not in self.ap.provider_cfg.data['requester']:
self.ap.provider_cfg.data['requester']['anthropic-messages'] = { self.ap.provider_cfg.data['requester']['anthropic-messages'] = {
'base-url': 'https://api.anthropic.com', 'base-url': 'https://api.anthropic.com',
'args': { 'args': {'max_tokens': 1024},
'max_tokens': 1024
},
'timeout': 120, 'timeout': 120,
} }
@@ -3,20 +3,19 @@ from __future__ import annotations
from .. import migration from .. import migration
@migration.migration_class("moonshot-config-completion", 4) @migration.migration_class('moonshot-config-completion', 4)
class MoonshotConfigCompletionMigration(migration.Migration): class MoonshotConfigCompletionMigration(migration.Migration):
"""OpenAI配置迁移 """OpenAI配置迁移"""
"""
async def need_migrate(self) -> bool: async def need_migrate(self) -> bool:
"""判断当前环境是否需要运行此迁移 """判断当前环境是否需要运行此迁移"""
""" return (
return 'moonshot-chat-completions' not in self.ap.provider_cfg.data['requester'] \ 'moonshot-chat-completions' not in self.ap.provider_cfg.data['requester']
or 'moonshot' not in self.ap.provider_cfg.data['keys'] or 'moonshot' not in self.ap.provider_cfg.data['keys']
)
async def run(self): async def run(self):
"""执行迁移 """执行迁移"""
"""
if 'moonshot-chat-completions' not in self.ap.provider_cfg.data['requester']: if 'moonshot-chat-completions' not in self.ap.provider_cfg.data['requester']:
self.ap.provider_cfg.data['requester']['moonshot-chat-completions'] = { self.ap.provider_cfg.data['requester']['moonshot-chat-completions'] = {
'base-url': 'https://api.moonshot.cn/v1', 'base-url': 'https://api.moonshot.cn/v1',
@@ -3,20 +3,19 @@ from __future__ import annotations
from .. import migration from .. import migration
@migration.migration_class("deepseek-config-completion", 5) @migration.migration_class('deepseek-config-completion', 5)
class DeepseekConfigCompletionMigration(migration.Migration): class DeepseekConfigCompletionMigration(migration.Migration):
"""OpenAI配置迁移 """OpenAI配置迁移"""
"""
async def need_migrate(self) -> bool: async def need_migrate(self) -> bool:
"""判断当前环境是否需要运行此迁移 """判断当前环境是否需要运行此迁移"""
""" return (
return 'deepseek-chat-completions' not in self.ap.provider_cfg.data['requester'] \ 'deepseek-chat-completions' not in self.ap.provider_cfg.data['requester']
or 'deepseek' not in self.ap.provider_cfg.data['keys'] or 'deepseek' not in self.ap.provider_cfg.data['keys']
)
async def run(self): async def run(self):
"""执行迁移 """执行迁移"""
"""
if 'deepseek-chat-completions' not in self.ap.provider_cfg.data['requester']: if 'deepseek-chat-completions' not in self.ap.provider_cfg.data['requester']:
self.ap.provider_cfg.data['requester']['deepseek-chat-completions'] = { self.ap.provider_cfg.data['requester']['deepseek-chat-completions'] = {
'base-url': 'https://api.deepseek.com', 'base-url': 'https://api.deepseek.com',
+4 -4
View File
@@ -3,17 +3,17 @@ from __future__ import annotations
from .. import migration from .. import migration
@migration.migration_class("vision-config", 6) @migration.migration_class('vision-config', 6)
class VisionConfigMigration(migration.Migration): class VisionConfigMigration(migration.Migration):
"""迁移""" """迁移"""
async def need_migrate(self) -> bool: async def need_migrate(self) -> bool:
"""判断当前环境是否需要运行此迁移""" """判断当前环境是否需要运行此迁移"""
return "enable-vision" not in self.ap.provider_cfg.data return 'enable-vision' not in self.ap.provider_cfg.data
async def run(self): async def run(self):
"""执行迁移""" """执行迁移"""
if "enable-vision" not in self.ap.provider_cfg.data: if 'enable-vision' not in self.ap.provider_cfg.data:
self.ap.provider_cfg.data["enable-vision"] = False self.ap.provider_cfg.data['enable-vision'] = False
await self.ap.provider_cfg.dump_config() await self.ap.provider_cfg.dump_config()
+6 -4
View File
@@ -3,18 +3,20 @@ from __future__ import annotations
from .. import migration from .. import migration
@migration.migration_class("qcg-center-url-config", 7) @migration.migration_class('qcg-center-url-config', 7)
class QCGCenterURLConfigMigration(migration.Migration): class QCGCenterURLConfigMigration(migration.Migration):
"""迁移""" """迁移"""
async def need_migrate(self) -> bool: async def need_migrate(self) -> bool:
"""判断当前环境是否需要运行此迁移""" """判断当前环境是否需要运行此迁移"""
return "qcg-center-url" not in self.ap.system_cfg.data return 'qcg-center-url' not in self.ap.system_cfg.data
async def run(self): async def run(self):
"""执行迁移""" """执行迁移"""
if "qcg-center-url" not in self.ap.system_cfg.data: if 'qcg-center-url' not in self.ap.system_cfg.data:
self.ap.system_cfg.data["qcg-center-url"] = "https://api.qchatgpt.rockchin.top/api/v2" self.ap.system_cfg.data['qcg-center-url'] = (
'https://api.qchatgpt.rockchin.top/api/v2'
)
await self.ap.system_cfg.dump_config() await self.ap.system_cfg.dump_config()
@@ -3,27 +3,27 @@ from __future__ import annotations
from .. import migration from .. import migration
@migration.migration_class("ad-fixwin-cfg-migration", 8) @migration.migration_class('ad-fixwin-cfg-migration', 8)
class AdFixwinConfigMigration(migration.Migration): class AdFixwinConfigMigration(migration.Migration):
"""迁移""" """迁移"""
async def need_migrate(self) -> bool: async def need_migrate(self) -> bool:
"""判断当前环境是否需要运行此迁移""" """判断当前环境是否需要运行此迁移"""
return isinstance( return isinstance(
self.ap.pipeline_cfg.data["rate-limit"]["fixwin"]["default"], self.ap.pipeline_cfg.data['rate-limit']['fixwin']['default'], int
int
) )
async def run(self): async def run(self):
"""执行迁移""" """执行迁移"""
for session_name in self.ap.pipeline_cfg.data["rate-limit"]["fixwin"]: for session_name in self.ap.pipeline_cfg.data['rate-limit']['fixwin']:
temp_dict = { temp_dict = {
"window-size": 60, 'window-size': 60,
"limit": self.ap.pipeline_cfg.data["rate-limit"]["fixwin"][session_name] 'limit': self.ap.pipeline_cfg.data['rate-limit']['fixwin'][
session_name
],
} }
self.ap.pipeline_cfg.data["rate-limit"]["fixwin"][session_name] = temp_dict self.ap.pipeline_cfg.data['rate-limit']['fixwin'][session_name] = temp_dict
await self.ap.pipeline_cfg.dump_config() await self.ap.pipeline_cfg.dump_config()
@@ -3,7 +3,7 @@ from __future__ import annotations
from .. import migration from .. import migration
@migration.migration_class("msg-truncator-cfg-migration", 9) @migration.migration_class('msg-truncator-cfg-migration', 9)
class MsgTruncatorConfigMigration(migration.Migration): class MsgTruncatorConfigMigration(migration.Migration):
"""迁移""" """迁移"""
@@ -16,9 +16,7 @@ class MsgTruncatorConfigMigration(migration.Migration):
self.ap.pipeline_cfg.data['msg-truncate'] = { self.ap.pipeline_cfg.data['msg-truncate'] = {
'method': 'round', 'method': 'round',
'round': { 'round': {'max-round': 10},
'max-round': 10
}
} }
await self.ap.pipeline_cfg.dump_config() await self.ap.pipeline_cfg.dump_config()
@@ -3,7 +3,7 @@ from __future__ import annotations
from .. import migration from .. import migration
@migration.migration_class("ollama-requester-config", 10) @migration.migration_class('ollama-requester-config', 10)
class MsgTruncatorConfigMigration(migration.Migration): class MsgTruncatorConfigMigration(migration.Migration):
"""迁移""" """迁移"""
@@ -15,9 +15,9 @@ class MsgTruncatorConfigMigration(migration.Migration):
"""执行迁移""" """执行迁移"""
self.ap.provider_cfg.data['requester']['ollama-chat'] = { self.ap.provider_cfg.data['requester']['ollama-chat'] = {
"base-url": "http://127.0.0.1:11434", 'base-url': 'http://127.0.0.1:11434',
"args": {}, 'args': {},
"timeout": 600 'timeout': 600,
} }
await self.ap.provider_cfg.dump_config() await self.ap.provider_cfg.dump_config()
@@ -3,7 +3,7 @@ from __future__ import annotations
from .. import migration from .. import migration
@migration.migration_class("command-prefix-config", 11) @migration.migration_class('command-prefix-config', 11)
class CommandPrefixConfigMigration(migration.Migration): class CommandPrefixConfigMigration(migration.Migration):
"""迁移""" """迁移"""
@@ -14,8 +14,6 @@ class CommandPrefixConfigMigration(migration.Migration):
async def run(self): async def run(self):
"""执行迁移""" """执行迁移"""
self.ap.command_cfg.data['command-prefix'] = [ self.ap.command_cfg.data['command-prefix'] = ['!', '']
"!", ""
]
await self.ap.command_cfg.dump_config() await self.ap.command_cfg.dump_config()
+1 -1
View File
@@ -3,7 +3,7 @@ from __future__ import annotations
from .. import migration from .. import migration
@migration.migration_class("runner-config", 12) @migration.migration_class('runner-config', 12)
class RunnerConfigMigration(migration.Migration): class RunnerConfigMigration(migration.Migration):
"""迁移""" """迁移"""
+11 -10
View File
@@ -3,29 +3,30 @@ from __future__ import annotations
from .. import migration from .. import migration
@migration.migration_class("http-api-config", 13) @migration.migration_class('http-api-config', 13)
class HttpApiConfigMigration(migration.Migration): class HttpApiConfigMigration(migration.Migration):
"""迁移""" """迁移"""
async def need_migrate(self) -> bool: async def need_migrate(self) -> bool:
"""判断当前环境是否需要运行此迁移""" """判断当前环境是否需要运行此迁移"""
return 'http-api' not in self.ap.system_cfg.data or "persistence" not in self.ap.system_cfg.data return (
'http-api' not in self.ap.system_cfg.data
or 'persistence' not in self.ap.system_cfg.data
)
async def run(self): async def run(self):
"""执行迁移""" """执行迁移"""
self.ap.system_cfg.data['http-api'] = { self.ap.system_cfg.data['http-api'] = {
"enable": True, 'enable': True,
"host": "0.0.0.0", 'host': '0.0.0.0',
"port": 5300, 'port': 5300,
"jwt-expire": 604800 'jwt-expire': 604800,
} }
self.ap.system_cfg.data['persistence'] = { self.ap.system_cfg.data['persistence'] = {
"sqlite": { 'sqlite': {'path': 'data/persistence.db'},
"path": "data/persistence.db" 'use': 'sqlite',
},
"use": "sqlite"
} }
await self.ap.system_cfg.dump_config() await self.ap.system_cfg.dump_config()
@@ -3,20 +3,20 @@ from __future__ import annotations
from .. import migration from .. import migration
@migration.migration_class("force-delay-config", 14) @migration.migration_class('force-delay-config', 14)
class ForceDelayConfigMigration(migration.Migration): class ForceDelayConfigMigration(migration.Migration):
"""迁移""" """迁移"""
async def need_migrate(self) -> bool: async def need_migrate(self) -> bool:
"""判断当前环境是否需要运行此迁移""" """判断当前环境是否需要运行此迁移"""
return type(self.ap.platform_cfg.data['force-delay']) == list return isinstance(self.ap.platform_cfg.data['force-delay'], list)
async def run(self): async def run(self):
"""执行迁移""" """执行迁移"""
self.ap.platform_cfg.data['force-delay'] = { self.ap.platform_cfg.data['force-delay'] = {
"min": self.ap.platform_cfg.data['force-delay'][0], 'min': self.ap.platform_cfg.data['force-delay'][0],
"max": self.ap.platform_cfg.data['force-delay'][1] 'max': self.ap.platform_cfg.data['force-delay'][1],
} }
await self.ap.platform_cfg.dump_config() await self.ap.platform_cfg.dump_config()
+9 -8
View File
@@ -3,24 +3,25 @@ from __future__ import annotations
from .. import migration from .. import migration
@migration.migration_class("gitee-ai-config", 15) @migration.migration_class('gitee-ai-config', 15)
class GiteeAIConfigMigration(migration.Migration): class GiteeAIConfigMigration(migration.Migration):
"""迁移""" """迁移"""
async def need_migrate(self) -> bool: async def need_migrate(self) -> bool:
"""判断当前环境是否需要运行此迁移""" """判断当前环境是否需要运行此迁移"""
return 'gitee-ai-chat-completions' not in self.ap.provider_cfg.data['requester'] or 'gitee-ai' not in self.ap.provider_cfg.data['keys'] return (
'gitee-ai-chat-completions' not in self.ap.provider_cfg.data['requester']
or 'gitee-ai' not in self.ap.provider_cfg.data['keys']
)
async def run(self): async def run(self):
"""执行迁移""" """执行迁移"""
self.ap.provider_cfg.data['requester']['gitee-ai-chat-completions'] = { self.ap.provider_cfg.data['requester']['gitee-ai-chat-completions'] = {
"base-url": "https://ai.gitee.com/v1", 'base-url': 'https://ai.gitee.com/v1',
"args": {}, 'args': {},
"timeout": 120 'timeout': 120,
} }
self.ap.provider_cfg.data['keys']['gitee-ai'] = [ self.ap.provider_cfg.data['keys']['gitee-ai'] = ['XXXXX']
"XXXXX"
]
await self.ap.provider_cfg.dump_config() await self.ap.provider_cfg.dump_config()
+5 -10
View File
@@ -3,7 +3,7 @@ from __future__ import annotations
from .. import migration from .. import migration
@migration.migration_class("dify-service-api-config", 16) @migration.migration_class('dify-service-api-config', 16)
class DifyServiceAPICfgMigration(migration.Migration): class DifyServiceAPICfgMigration(migration.Migration):
"""迁移""" """迁移"""
@@ -14,15 +14,10 @@ class DifyServiceAPICfgMigration(migration.Migration):
async def run(self): async def run(self):
"""执行迁移""" """执行迁移"""
self.ap.provider_cfg.data['dify-service-api'] = { self.ap.provider_cfg.data['dify-service-api'] = {
"base-url": "https://api.dify.ai/v1", 'base-url': 'https://api.dify.ai/v1',
"app-type": "chat", 'app-type': 'chat',
"chat": { 'chat': {'api-key': 'app-1234567890'},
"api-key": "app-1234567890" 'workflow': {'api-key': 'app-1234567890', 'output-key': 'summary'},
},
"workflow": {
"api-key": "app-1234567890",
"output-key": "summary"
}
} }
await self.ap.provider_cfg.dump_config() await self.ap.provider_cfg.dump_config()
@@ -3,22 +3,26 @@ from __future__ import annotations
from .. import migration from .. import migration
@migration.migration_class("dify-api-timeout-params", 17) @migration.migration_class('dify-api-timeout-params', 17)
class DifyAPITimeoutParamsMigration(migration.Migration): class DifyAPITimeoutParamsMigration(migration.Migration):
"""迁移""" """迁移"""
async def need_migrate(self) -> bool: async def need_migrate(self) -> bool:
"""判断当前环境是否需要运行此迁移""" """判断当前环境是否需要运行此迁移"""
return 'timeout' not in self.ap.provider_cfg.data['dify-service-api']['chat'] or 'timeout' not in self.ap.provider_cfg.data['dify-service-api']['workflow'] \ return (
'timeout' not in self.ap.provider_cfg.data['dify-service-api']['chat']
or 'timeout'
not in self.ap.provider_cfg.data['dify-service-api']['workflow']
or 'agent' not in self.ap.provider_cfg.data['dify-service-api'] or 'agent' not in self.ap.provider_cfg.data['dify-service-api']
)
async def run(self): async def run(self):
"""执行迁移""" """执行迁移"""
self.ap.provider_cfg.data['dify-service-api']['chat']['timeout'] = 120 self.ap.provider_cfg.data['dify-service-api']['chat']['timeout'] = 120
self.ap.provider_cfg.data['dify-service-api']['workflow']['timeout'] = 120 self.ap.provider_cfg.data['dify-service-api']['workflow']['timeout'] = 120
self.ap.provider_cfg.data['dify-service-api']['agent'] = { self.ap.provider_cfg.data['dify-service-api']['agent'] = {
"api-key": "app-1234567890", 'api-key': 'app-1234567890',
"timeout": 120 'timeout': 120,
} }
await self.ap.provider_cfg.dump_config() await self.ap.provider_cfg.dump_config()
+5 -7
View File
@@ -3,7 +3,7 @@ from __future__ import annotations
from .. import migration from .. import migration
@migration.migration_class("xai-config", 18) @migration.migration_class('xai-config', 18)
class XaiConfigMigration(migration.Migration): class XaiConfigMigration(migration.Migration):
"""迁移""" """迁移"""
@@ -14,12 +14,10 @@ class XaiConfigMigration(migration.Migration):
async def run(self): async def run(self):
"""执行迁移""" """执行迁移"""
self.ap.provider_cfg.data['requester']['xai-chat-completions'] = { self.ap.provider_cfg.data['requester']['xai-chat-completions'] = {
"base-url": "https://api.x.ai/v1", 'base-url': 'https://api.x.ai/v1',
"args": {}, 'args': {},
"timeout": 120 'timeout': 120,
} }
self.ap.provider_cfg.data['keys']['xai'] = [ self.ap.provider_cfg.data['keys']['xai'] = ['xai-1234567890']
"xai-1234567890"
]
await self.ap.provider_cfg.dump_config() await self.ap.provider_cfg.dump_config()
+5 -7
View File
@@ -3,7 +3,7 @@ from __future__ import annotations
from .. import migration from .. import migration
@migration.migration_class("zhipuai-config", 19) @migration.migration_class('zhipuai-config', 19)
class ZhipuaiConfigMigration(migration.Migration): class ZhipuaiConfigMigration(migration.Migration):
"""迁移""" """迁移"""
@@ -14,12 +14,10 @@ class ZhipuaiConfigMigration(migration.Migration):
async def run(self): async def run(self):
"""执行迁移""" """执行迁移"""
self.ap.provider_cfg.data['requester']['zhipuai-chat-completions'] = { self.ap.provider_cfg.data['requester']['zhipuai-chat-completions'] = {
"base-url": "https://open.bigmodel.cn/api/paas/v4", 'base-url': 'https://open.bigmodel.cn/api/paas/v4',
"args": {}, 'args': {},
"timeout": 120 'timeout': 120,
} }
self.ap.provider_cfg.data['keys']['zhipuai'] = [ self.ap.provider_cfg.data['keys']['zhipuai'] = ['xxxxxxx']
"xxxxxxx"
]
await self.ap.provider_cfg.dump_config() await self.ap.provider_cfg.dump_config()
+14 -12
View File
@@ -3,7 +3,7 @@ from __future__ import annotations
from .. import migration from .. import migration
@migration.migration_class("wecom-config", 20) @migration.migration_class('wecom-config', 20)
class WecomConfigMigration(migration.Migration): class WecomConfigMigration(migration.Migration):
"""迁移""" """迁移"""
@@ -19,16 +19,18 @@ class WecomConfigMigration(migration.Migration):
async def run(self): async def run(self):
"""执行迁移""" """执行迁移"""
self.ap.platform_cfg.data['platform-adapters'].append({ self.ap.platform_cfg.data['platform-adapters'].append(
"adapter": "wecom", {
"enable": False, 'adapter': 'wecom',
"host": "0.0.0.0", 'enable': False,
"port": 2290, 'host': '0.0.0.0',
"corpid": "", 'port': 2290,
"secret": "", 'corpid': '',
"token": "", 'secret': '',
"EncodingAESKey": "", 'token': '',
"contacts_secret": "" 'EncodingAESKey': '',
}) 'contacts_secret': '',
}
)
await self.ap.platform_cfg.dump_config() await self.ap.platform_cfg.dump_config()
+13 -11
View File
@@ -3,7 +3,7 @@ from __future__ import annotations
from .. import migration from .. import migration
@migration.migration_class("lark-config", 21) @migration.migration_class('lark-config', 21)
class LarkConfigMigration(migration.Migration): class LarkConfigMigration(migration.Migration):
"""迁移""" """迁移"""
@@ -19,15 +19,17 @@ class LarkConfigMigration(migration.Migration):
async def run(self): async def run(self):
"""执行迁移""" """执行迁移"""
self.ap.platform_cfg.data['platform-adapters'].append({ self.ap.platform_cfg.data['platform-adapters'].append(
"adapter": "lark", {
"enable": False, 'adapter': 'lark',
"app_id": "cli_abcdefgh", 'enable': False,
"app_secret": "XXXXXXXXXX", 'app_id': 'cli_abcdefgh',
"bot_name": "LangBot", 'app_secret': 'XXXXXXXXXX',
"enable-webhook": False, 'bot_name': 'LangBot',
"port": 2285, 'enable-webhook': False,
"encrypt-key": "xxxxxxxxx" 'port': 2285,
}) 'encrypt-key': 'xxxxxxxxx',
}
)
await self.ap.platform_cfg.dump_config() await self.ap.platform_cfg.dump_config()
+4 -4
View File
@@ -3,7 +3,7 @@ from __future__ import annotations
from .. import migration from .. import migration
@migration.migration_class("lmstudio-config", 22) @migration.migration_class('lmstudio-config', 22)
class LmStudioConfigMigration(migration.Migration): class LmStudioConfigMigration(migration.Migration):
"""迁移""" """迁移"""
@@ -15,9 +15,9 @@ class LmStudioConfigMigration(migration.Migration):
async def run(self): async def run(self):
"""执行迁移""" """执行迁移"""
self.ap.provider_cfg.data['requester']['lmstudio-chat-completions'] = { self.ap.provider_cfg.data['requester']['lmstudio-chat-completions'] = {
"base-url": "http://127.0.0.1:1234/v1", 'base-url': 'http://127.0.0.1:1234/v1',
"args": {}, 'args': {},
"timeout": 120 'timeout': 120,
} }
await self.ap.provider_cfg.dump_config() await self.ap.provider_cfg.dump_config()
@@ -3,25 +3,25 @@ from __future__ import annotations
from .. import migration from .. import migration
@migration.migration_class("siliconflow-config", 23) @migration.migration_class('siliconflow-config', 23)
class SiliconFlowConfigMigration(migration.Migration): class SiliconFlowConfigMigration(migration.Migration):
"""迁移""" """迁移"""
async def need_migrate(self) -> bool: async def need_migrate(self) -> bool:
"""判断当前环境是否需要运行此迁移""" """判断当前环境是否需要运行此迁移"""
return 'siliconflow-chat-completions' not in self.ap.provider_cfg.data['requester'] return (
'siliconflow-chat-completions' not in self.ap.provider_cfg.data['requester']
)
async def run(self): async def run(self):
"""执行迁移""" """执行迁移"""
self.ap.provider_cfg.data['keys']['siliconflow'] = [ self.ap.provider_cfg.data['keys']['siliconflow'] = ['xxxxxxx']
"xxxxxxx"
]
self.ap.provider_cfg.data['requester']['siliconflow-chat-completions'] = { self.ap.provider_cfg.data['requester']['siliconflow-chat-completions'] = {
"base-url": "https://api.siliconflow.cn/v1", 'base-url': 'https://api.siliconflow.cn/v1',
"args": {}, 'args': {},
"timeout": 120 'timeout': 120,
} }
await self.ap.provider_cfg.dump_config() await self.ap.provider_cfg.dump_config()
+9 -7
View File
@@ -3,7 +3,7 @@ from __future__ import annotations
from .. import migration from .. import migration
@migration.migration_class("discord-config", 24) @migration.migration_class('discord-config', 24)
class DiscordConfigMigration(migration.Migration): class DiscordConfigMigration(migration.Migration):
"""迁移""" """迁移"""
@@ -19,11 +19,13 @@ class DiscordConfigMigration(migration.Migration):
async def run(self): async def run(self):
"""执行迁移""" """执行迁移"""
self.ap.platform_cfg.data['platform-adapters'].append({ self.ap.platform_cfg.data['platform-adapters'].append(
"adapter": "discord", {
"enable": False, 'adapter': 'discord',
"client_id": "1234567890", 'enable': False,
"token": "XXXXXXXXXX" 'client_id': '1234567890',
}) 'token': 'XXXXXXXXXX',
}
)
await self.ap.platform_cfg.dump_config() await self.ap.platform_cfg.dump_config()
+13 -11
View File
@@ -3,7 +3,7 @@ from __future__ import annotations
from .. import migration from .. import migration
@migration.migration_class("gewechat-config", 25) @migration.migration_class('gewechat-config', 25)
class GewechatConfigMigration(migration.Migration): class GewechatConfigMigration(migration.Migration):
"""迁移""" """迁移"""
@@ -19,15 +19,17 @@ class GewechatConfigMigration(migration.Migration):
async def run(self): async def run(self):
"""执行迁移""" """执行迁移"""
self.ap.platform_cfg.data['platform-adapters'].append({ self.ap.platform_cfg.data['platform-adapters'].append(
"adapter": "gewechat", {
"enable": False, 'adapter': 'gewechat',
"gewechat_url": "http://your-gewechat-server:2531", 'enable': False,
"gewechat_file_url": "http://your-gewechat-server:2532", 'gewechat_url': 'http://your-gewechat-server:2531',
"port": 2286, 'gewechat_file_url': 'http://your-gewechat-server:2532',
"callback_url": "http://your-callback-url:2286/gewechat/callback", 'port': 2286,
"app_id": "", 'callback_url': 'http://your-callback-url:2286/gewechat/callback',
"token": "" 'app_id': '',
}) 'token': '',
}
)
await self.ap.platform_cfg.dump_config() await self.ap.platform_cfg.dump_config()
+11 -9
View File
@@ -3,7 +3,7 @@ from __future__ import annotations
from .. import migration from .. import migration
@migration.migration_class("qqofficial-config", 26) @migration.migration_class('qqofficial-config', 26)
class QQOfficialConfigMigration(migration.Migration): class QQOfficialConfigMigration(migration.Migration):
"""迁移""" """迁移"""
@@ -19,13 +19,15 @@ class QQOfficialConfigMigration(migration.Migration):
async def run(self): async def run(self):
"""执行迁移""" """执行迁移"""
self.ap.platform_cfg.data['platform-adapters'].append({ self.ap.platform_cfg.data['platform-adapters'].append(
"adapter": "qqofficial", {
"enable": False, 'adapter': 'qqofficial',
"appid": "", 'enable': False,
"secret": "", 'appid': '',
"port": 2284, 'secret': '',
"token": "" 'port': 2284,
}) 'token': '',
}
)
await self.ap.platform_cfg.dump_config() await self.ap.platform_cfg.dump_config()
@@ -3,7 +3,7 @@ from __future__ import annotations
from .. import migration from .. import migration
@migration.migration_class("wx-official-account-config", 27) @migration.migration_class('wx-official-account-config', 27)
class WXOfficialAccountConfigMigration(migration.Migration): class WXOfficialAccountConfigMigration(migration.Migration):
"""迁移""" """迁移"""
@@ -19,15 +19,17 @@ class WXOfficialAccountConfigMigration(migration.Migration):
async def run(self): async def run(self):
"""执行迁移""" """执行迁移"""
self.ap.platform_cfg.data['platform-adapters'].append({ self.ap.platform_cfg.data['platform-adapters'].append(
"adapter": "officialaccount", {
"enable": False, 'adapter': 'officialaccount',
"token": "", 'enable': False,
"EncodingAESKey": "", 'token': '',
"AppID": "", 'EncodingAESKey': '',
"AppSecret": "", 'AppID': '',
"host": "0.0.0.0", 'AppSecret': '',
"port": 2287 'host': '0.0.0.0',
}) 'port': 2287,
}
)
await self.ap.platform_cfg.dump_config() await self.ap.platform_cfg.dump_config()
@@ -3,7 +3,7 @@ from __future__ import annotations
from .. import migration from .. import migration
@migration.migration_class("bailian-requester-config", 28) @migration.migration_class('bailian-requester-config', 28)
class BailianRequesterConfigMigration(migration.Migration): class BailianRequesterConfigMigration(migration.Migration):
"""迁移""" """迁移"""
@@ -14,14 +14,12 @@ class BailianRequesterConfigMigration(migration.Migration):
async def run(self): async def run(self):
"""执行迁移""" """执行迁移"""
self.ap.provider_cfg.data['keys']['bailian'] = [ self.ap.provider_cfg.data['keys']['bailian'] = ['sk-xxxxxxx']
"sk-xxxxxxx"
]
self.ap.provider_cfg.data['requester']['bailian-chat-completions'] = { self.ap.provider_cfg.data['requester']['bailian-chat-completions'] = {
"base-url": "https://dashscope.aliyuncs.com/compatible-mode/v1", 'base-url': 'https://dashscope.aliyuncs.com/compatible-mode/v1',
"args": {}, 'args': {},
"timeout": 120 'timeout': 120,
} }
await self.ap.provider_cfg.dump_config() await self.ap.provider_cfg.dump_config()
@@ -3,7 +3,7 @@ from __future__ import annotations
from .. import migration from .. import migration
@migration.migration_class("dashscope-app-api-config", 29) @migration.migration_class('dashscope-app-api-config', 29)
class DashscopeAppAPICfgMigration(migration.Migration): class DashscopeAppAPICfgMigration(migration.Migration):
"""迁移""" """迁移"""
@@ -14,20 +14,14 @@ class DashscopeAppAPICfgMigration(migration.Migration):
async def run(self): async def run(self):
"""执行迁移""" """执行迁移"""
self.ap.provider_cfg.data['dashscope-app-api'] = { self.ap.provider_cfg.data['dashscope-app-api'] = {
"app-type": "agent", 'app-type': 'agent',
"api-key": "sk-1234567890", 'api-key': 'sk-1234567890',
"agent": { 'agent': {'app-id': 'Your_app_id', 'references_quote': '参考资料来自:'},
"app-id": "Your_app_id", 'workflow': {
"references_quote": "参考资料来自:" 'app-id': 'Your_app_id',
'references_quote': '参考资料来自:',
'biz_params': {'city': '北京', 'date': '2023-08-10'},
}, },
"workflow": {
"app-id": "Your_app_id",
"references_quote": "参考资料来自:",
"biz_params": {
"city": "北京",
"date": "2023-08-10"
}
}
} }
await self.ap.provider_cfg.dump_config() await self.ap.provider_cfg.dump_config()

Some files were not shown because too many files have changed in this diff Show More