mirror of
https://github.com/langbot-app/LangBot.git
synced 2026-06-27 16:04:21 +00:00
style: introduce ruff as linter and formatter (#1356)
* style: remove necessary imports * style: fix F841 * style: fix F401 * style: fix F811 * style: fix E402 * style: fix E721 * style: fix E722 * style: fix E722 * style: fix F541 * style: ruff format * style: all passed * style: add ruff in deps * style: more ignores in ruff.toml * style: add pre-commit
This commit is contained in:
committed by
GitHub
parent
09e70d70e9
commit
209f16af76
@@ -29,6 +29,7 @@ qcapi
|
|||||||
claude.json
|
claude.json
|
||||||
bard.json
|
bard.json
|
||||||
/*yaml
|
/*yaml
|
||||||
|
!.pre-commit-config.yaml
|
||||||
!components.yaml
|
!components.yaml
|
||||||
!/docker-compose.yaml
|
!/docker-compose.yaml
|
||||||
data/labels/instance_id.json
|
data/labels/instance_id.json
|
||||||
|
|||||||
@@ -0,0 +1,9 @@
|
|||||||
|
repos:
|
||||||
|
- repo: https://github.com/astral-sh/ruff-pre-commit
|
||||||
|
# Ruff version.
|
||||||
|
rev: v0.11.7
|
||||||
|
hooks:
|
||||||
|
# Run the linter.
|
||||||
|
- id: ruff
|
||||||
|
# Run the formatter.
|
||||||
|
- id: ruff-format
|
||||||
@@ -1,2 +1,4 @@
|
|||||||
from .v1 import client
|
from .v1 import client as client
|
||||||
from .v1 import errors
|
from .v1 import errors as errors
|
||||||
|
|
||||||
|
__all__ = ['client', 'errors']
|
||||||
|
|||||||
@@ -8,25 +8,33 @@ import json
|
|||||||
|
|
||||||
class TestDifyClient:
|
class TestDifyClient:
|
||||||
async def test_chat_messages(self):
|
async def test_chat_messages(self):
|
||||||
cln = client.AsyncDifyServiceClient(api_key=os.getenv("DIFY_API_KEY"), base_url=os.getenv("DIFY_BASE_URL"))
|
cln = client.AsyncDifyServiceClient(
|
||||||
|
api_key=os.getenv('DIFY_API_KEY'), base_url=os.getenv('DIFY_BASE_URL')
|
||||||
|
)
|
||||||
|
|
||||||
async for chunk in cln.chat_messages(inputs={}, query="调用工具查看现在几点?", user="test"):
|
async for chunk in cln.chat_messages(
|
||||||
|
inputs={}, query='调用工具查看现在几点?', user='test'
|
||||||
|
):
|
||||||
print(json.dumps(chunk, ensure_ascii=False, indent=4))
|
print(json.dumps(chunk, ensure_ascii=False, indent=4))
|
||||||
|
|
||||||
async def test_upload_file(self):
|
async def test_upload_file(self):
|
||||||
cln = client.AsyncDifyServiceClient(api_key=os.getenv("DIFY_API_KEY"), base_url=os.getenv("DIFY_BASE_URL"))
|
cln = client.AsyncDifyServiceClient(
|
||||||
|
api_key=os.getenv('DIFY_API_KEY'), base_url=os.getenv('DIFY_BASE_URL')
|
||||||
|
)
|
||||||
|
|
||||||
file_bytes = open("img.png", "rb").read()
|
file_bytes = open('img.png', 'rb').read()
|
||||||
|
|
||||||
print(type(file_bytes))
|
print(type(file_bytes))
|
||||||
|
|
||||||
file = ("img2.png", file_bytes, "image/png")
|
file = ('img2.png', file_bytes, 'image/png')
|
||||||
|
|
||||||
resp = await cln.upload_file(file=file, user="test")
|
resp = await cln.upload_file(file=file, user='test')
|
||||||
print(json.dumps(resp, ensure_ascii=False, indent=4))
|
print(json.dumps(resp, ensure_ascii=False, indent=4))
|
||||||
|
|
||||||
async def test_workflow_run(self):
|
async def test_workflow_run(self):
|
||||||
cln = client.AsyncDifyServiceClient(api_key=os.getenv("DIFY_API_KEY"), base_url=os.getenv("DIFY_BASE_URL"))
|
cln = client.AsyncDifyServiceClient(
|
||||||
|
api_key=os.getenv('DIFY_API_KEY'), base_url=os.getenv('DIFY_BASE_URL')
|
||||||
|
)
|
||||||
|
|
||||||
# resp = await cln.workflow_run(inputs={}, user="test")
|
# resp = await cln.workflow_run(inputs={}, user="test")
|
||||||
# # print(json.dumps(resp, ensure_ascii=False, indent=4))
|
# # print(json.dumps(resp, ensure_ascii=False, indent=4))
|
||||||
@@ -34,11 +42,12 @@ class TestDifyClient:
|
|||||||
chunks = []
|
chunks = []
|
||||||
|
|
||||||
ignored_events = ['text_chunk']
|
ignored_events = ['text_chunk']
|
||||||
async for chunk in cln.workflow_run(inputs={}, user="test"):
|
async for chunk in cln.workflow_run(inputs={}, user='test'):
|
||||||
if chunk['event'] in ignored_events:
|
if chunk['event'] in ignored_events:
|
||||||
continue
|
continue
|
||||||
chunks.append(chunk)
|
chunks.append(chunk)
|
||||||
print(json.dumps(chunks, ensure_ascii=False, indent=4))
|
print(json.dumps(chunks, ensure_ascii=False, indent=4))
|
||||||
|
|
||||||
if __name__ == "__main__":
|
|
||||||
|
if __name__ == '__main__':
|
||||||
asyncio.run(TestDifyClient().test_chat_messages())
|
asyncio.run(TestDifyClient().test_chat_messages())
|
||||||
|
|||||||
@@ -16,7 +16,7 @@ class AsyncDifyServiceClient:
|
|||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
api_key: str,
|
api_key: str,
|
||||||
base_url: str = "https://api.dify.ai/v1",
|
base_url: str = 'https://api.dify.ai/v1',
|
||||||
) -> None:
|
) -> None:
|
||||||
self.api_key = api_key
|
self.api_key = api_key
|
||||||
self.base_url = base_url
|
self.base_url = base_url
|
||||||
@@ -26,14 +26,14 @@ class AsyncDifyServiceClient:
|
|||||||
inputs: dict[str, typing.Any],
|
inputs: dict[str, typing.Any],
|
||||||
query: str,
|
query: str,
|
||||||
user: str,
|
user: str,
|
||||||
response_mode: str = "streaming", # 当前不支持 blocking
|
response_mode: str = 'streaming', # 当前不支持 blocking
|
||||||
conversation_id: str = "",
|
conversation_id: str = '',
|
||||||
files: list[dict[str, typing.Any]] = [],
|
files: list[dict[str, typing.Any]] = [],
|
||||||
timeout: float = 30.0,
|
timeout: float = 30.0,
|
||||||
) -> typing.AsyncGenerator[dict[str, typing.Any], None]:
|
) -> typing.AsyncGenerator[dict[str, typing.Any], None]:
|
||||||
"""发送消息"""
|
"""发送消息"""
|
||||||
if response_mode != "streaming":
|
if response_mode != 'streaming':
|
||||||
raise DifyAPIError("当前仅支持 streaming 模式")
|
raise DifyAPIError('当前仅支持 streaming 模式')
|
||||||
|
|
||||||
async with httpx.AsyncClient(
|
async with httpx.AsyncClient(
|
||||||
base_url=self.base_url,
|
base_url=self.base_url,
|
||||||
@@ -41,61 +41,66 @@ class AsyncDifyServiceClient:
|
|||||||
timeout=timeout,
|
timeout=timeout,
|
||||||
) as client:
|
) as client:
|
||||||
async with client.stream(
|
async with client.stream(
|
||||||
"POST",
|
'POST',
|
||||||
"/chat-messages",
|
'/chat-messages',
|
||||||
headers={"Authorization": f"Bearer {self.api_key}", "Content-Type": "application/json"},
|
headers={
|
||||||
|
'Authorization': f'Bearer {self.api_key}',
|
||||||
|
'Content-Type': 'application/json',
|
||||||
|
},
|
||||||
json={
|
json={
|
||||||
"inputs": inputs,
|
'inputs': inputs,
|
||||||
"query": query,
|
'query': query,
|
||||||
"user": user,
|
'user': user,
|
||||||
"response_mode": response_mode,
|
'response_mode': response_mode,
|
||||||
"conversation_id": conversation_id,
|
'conversation_id': conversation_id,
|
||||||
"files": files,
|
'files': files,
|
||||||
},
|
},
|
||||||
) as r:
|
) as r:
|
||||||
async for chunk in r.aiter_lines():
|
async for chunk in r.aiter_lines():
|
||||||
if r.status_code != 200:
|
if r.status_code != 200:
|
||||||
raise DifyAPIError(f"{r.status_code} {chunk}")
|
raise DifyAPIError(f'{r.status_code} {chunk}')
|
||||||
if chunk.strip() == "":
|
if chunk.strip() == '':
|
||||||
continue
|
continue
|
||||||
if chunk.startswith("data:"):
|
if chunk.startswith('data:'):
|
||||||
yield json.loads(chunk[5:])
|
yield json.loads(chunk[5:])
|
||||||
|
|
||||||
async def workflow_run(
|
async def workflow_run(
|
||||||
self,
|
self,
|
||||||
inputs: dict[str, typing.Any],
|
inputs: dict[str, typing.Any],
|
||||||
user: str,
|
user: str,
|
||||||
response_mode: str = "streaming", # 当前不支持 blocking
|
response_mode: str = 'streaming', # 当前不支持 blocking
|
||||||
files: list[dict[str, typing.Any]] = [],
|
files: list[dict[str, typing.Any]] = [],
|
||||||
timeout: float = 30.0,
|
timeout: float = 30.0,
|
||||||
) -> typing.AsyncGenerator[dict[str, typing.Any], None]:
|
) -> typing.AsyncGenerator[dict[str, typing.Any], None]:
|
||||||
"""运行工作流"""
|
"""运行工作流"""
|
||||||
if response_mode != "streaming":
|
if response_mode != 'streaming':
|
||||||
raise DifyAPIError("当前仅支持 streaming 模式")
|
raise DifyAPIError('当前仅支持 streaming 模式')
|
||||||
|
|
||||||
async with httpx.AsyncClient(
|
async with httpx.AsyncClient(
|
||||||
base_url=self.base_url,
|
base_url=self.base_url,
|
||||||
trust_env=True,
|
trust_env=True,
|
||||||
timeout=timeout,
|
timeout=timeout,
|
||||||
) as client:
|
) as client:
|
||||||
|
|
||||||
async with client.stream(
|
async with client.stream(
|
||||||
"POST",
|
'POST',
|
||||||
"/workflows/run",
|
'/workflows/run',
|
||||||
headers={"Authorization": f"Bearer {self.api_key}", "Content-Type": "application/json"},
|
headers={
|
||||||
|
'Authorization': f'Bearer {self.api_key}',
|
||||||
|
'Content-Type': 'application/json',
|
||||||
|
},
|
||||||
json={
|
json={
|
||||||
"inputs": inputs,
|
'inputs': inputs,
|
||||||
"user": user,
|
'user': user,
|
||||||
"response_mode": response_mode,
|
'response_mode': response_mode,
|
||||||
"files": files,
|
'files': files,
|
||||||
},
|
},
|
||||||
) as r:
|
) as r:
|
||||||
async for chunk in r.aiter_lines():
|
async for chunk in r.aiter_lines():
|
||||||
if r.status_code != 200:
|
if r.status_code != 200:
|
||||||
raise DifyAPIError(f"{r.status_code} {chunk}")
|
raise DifyAPIError(f'{r.status_code} {chunk}')
|
||||||
if chunk.strip() == "":
|
if chunk.strip() == '':
|
||||||
continue
|
continue
|
||||||
if chunk.startswith("data:"):
|
if chunk.startswith('data:'):
|
||||||
yield json.loads(chunk[5:])
|
yield json.loads(chunk[5:])
|
||||||
|
|
||||||
async def upload_file(
|
async def upload_file(
|
||||||
@@ -112,15 +117,15 @@ class AsyncDifyServiceClient:
|
|||||||
) as client:
|
) as client:
|
||||||
# multipart/form-data
|
# multipart/form-data
|
||||||
response = await client.post(
|
response = await client.post(
|
||||||
"/files/upload",
|
'/files/upload',
|
||||||
headers={"Authorization": f"Bearer {self.api_key}"},
|
headers={'Authorization': f'Bearer {self.api_key}'},
|
||||||
files={
|
files={
|
||||||
"file": file,
|
'file': file,
|
||||||
"user": (None, user),
|
'user': (None, user),
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
|
|
||||||
if response.status_code != 201:
|
if response.status_code != 201:
|
||||||
raise DifyAPIError(f"{response.status_code} {response.text}")
|
raise DifyAPIError(f'{response.status_code} {response.text}')
|
||||||
|
|
||||||
return response.json()
|
return response.json()
|
||||||
|
|||||||
@@ -7,11 +7,11 @@ import os
|
|||||||
|
|
||||||
class TestDifyClient:
|
class TestDifyClient:
|
||||||
async def test_chat_messages(self):
|
async def test_chat_messages(self):
|
||||||
cln = client.DifyClient(api_key=os.getenv("DIFY_API_KEY"))
|
cln = client.DifyClient(api_key=os.getenv('DIFY_API_KEY'))
|
||||||
|
|
||||||
resp = await cln.chat_messages(inputs={}, query="Who are you?", user_id="test")
|
resp = await cln.chat_messages(inputs={}, query='Who are you?', user_id='test')
|
||||||
print(resp)
|
print(resp)
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == '__main__':
|
||||||
asyncio.run(TestDifyClient().test_chat_messages())
|
asyncio.run(TestDifyClient().test_chat_messages())
|
||||||
|
|||||||
@@ -1,8 +1,8 @@
|
|||||||
import asyncio
|
import asyncio
|
||||||
import json
|
|
||||||
import dingtalk_stream
|
import dingtalk_stream
|
||||||
from dingtalk_stream import AckMessage
|
from dingtalk_stream import AckMessage
|
||||||
|
|
||||||
|
|
||||||
class EchoTextHandler(dingtalk_stream.ChatbotHandler):
|
class EchoTextHandler(dingtalk_stream.ChatbotHandler):
|
||||||
def __init__(self, client):
|
def __init__(self, client):
|
||||||
self.msg_id = ''
|
self.msg_id = ''
|
||||||
@@ -10,6 +10,7 @@ class EchoTextHandler(dingtalk_stream.ChatbotHandler):
|
|||||||
self.client = client # 用于更新 DingTalkClient 中的 incoming_message
|
self.client = client # 用于更新 DingTalkClient 中的 incoming_message
|
||||||
|
|
||||||
"""处理钉钉消息"""
|
"""处理钉钉消息"""
|
||||||
|
|
||||||
async def process(self, callback: dingtalk_stream.CallbackMessage):
|
async def process(self, callback: dingtalk_stream.CallbackMessage):
|
||||||
incoming_message = dingtalk_stream.ChatbotMessage.from_dict(callback.data)
|
incoming_message = dingtalk_stream.ChatbotMessage.from_dict(callback.data)
|
||||||
if incoming_message.message_id != self.msg_id:
|
if incoming_message.message_id != self.msg_id:
|
||||||
@@ -26,6 +27,8 @@ class EchoTextHandler(dingtalk_stream.ChatbotHandler):
|
|||||||
|
|
||||||
return self.incoming_message
|
return self.incoming_message
|
||||||
|
|
||||||
|
|
||||||
async def get_dingtalk_client(client_id, client_secret):
|
async def get_dingtalk_client(client_id, client_secret):
|
||||||
from api import DingTalkClient # 延迟导入,避免循环导入
|
from api import DingTalkClient # 延迟导入,避免循环导入
|
||||||
|
|
||||||
return DingTalkClient(client_id, client_secret)
|
return DingTalkClient(client_id, client_secret)
|
||||||
|
|||||||
+70
-83
@@ -10,7 +10,9 @@ import traceback
|
|||||||
|
|
||||||
|
|
||||||
class DingTalkClient:
|
class DingTalkClient:
|
||||||
def __init__(self, client_id: str, client_secret: str,robot_name:str,robot_code:str):
|
def __init__(
|
||||||
|
self, client_id: str, client_secret: str, robot_name: str, robot_code: str
|
||||||
|
):
|
||||||
"""初始化 WebSocket 连接并自动启动"""
|
"""初始化 WebSocket 连接并自动启动"""
|
||||||
self.credential = dingtalk_stream.Credential(client_id, client_secret)
|
self.credential = dingtalk_stream.Credential(client_id, client_secret)
|
||||||
self.client = dingtalk_stream.DingTalkStreamClient(self.credential)
|
self.client = dingtalk_stream.DingTalkStreamClient(self.credential)
|
||||||
@@ -18,38 +20,32 @@ class DingTalkClient:
|
|||||||
self.secret = client_secret
|
self.secret = client_secret
|
||||||
# 在 DingTalkClient 中传入自己作为参数,避免循环导入
|
# 在 DingTalkClient 中传入自己作为参数,避免循环导入
|
||||||
self.EchoTextHandler = EchoTextHandler(self)
|
self.EchoTextHandler = EchoTextHandler(self)
|
||||||
self.client.register_callback_handler(dingtalk_stream.chatbot.ChatbotMessage.TOPIC, self.EchoTextHandler)
|
self.client.register_callback_handler(
|
||||||
|
dingtalk_stream.chatbot.ChatbotMessage.TOPIC, self.EchoTextHandler
|
||||||
|
)
|
||||||
self._message_handlers = {
|
self._message_handlers = {
|
||||||
"example":[],
|
'example': [],
|
||||||
}
|
}
|
||||||
self.access_token = ''
|
self.access_token = ''
|
||||||
self.robot_name = robot_name
|
self.robot_name = robot_name
|
||||||
self.robot_code = robot_code
|
self.robot_code = robot_code
|
||||||
self.access_token_expiry_time = ''
|
self.access_token_expiry_time = ''
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
async def get_access_token(self):
|
async def get_access_token(self):
|
||||||
url = "https://api.dingtalk.com/v1.0/oauth2/accessToken"
|
url = 'https://api.dingtalk.com/v1.0/oauth2/accessToken'
|
||||||
headers = {
|
headers = {'Content-Type': 'application/json'}
|
||||||
"Content-Type": "application/json"
|
data = {'appKey': self.key, 'appSecret': self.secret}
|
||||||
}
|
|
||||||
data = {
|
|
||||||
"appKey": self.key,
|
|
||||||
"appSecret": self.secret
|
|
||||||
}
|
|
||||||
async with httpx.AsyncClient() as client:
|
async with httpx.AsyncClient() as client:
|
||||||
try:
|
try:
|
||||||
response = await client.post(url,json=data,headers=headers)
|
response = await client.post(url, json=data, headers=headers)
|
||||||
if response.status_code == 200:
|
if response.status_code == 200:
|
||||||
response_data = response.json()
|
response_data = response.json()
|
||||||
self.access_token = response_data.get("accessToken")
|
self.access_token = response_data.get('accessToken')
|
||||||
expires_in = int(response_data.get("expireIn",7200))
|
expires_in = int(response_data.get('expireIn', 7200))
|
||||||
self.access_token_expiry_time = time.time() + expires_in - 60
|
self.access_token_expiry_time = time.time() + expires_in - 60
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
raise Exception(e)
|
raise Exception(e)
|
||||||
|
|
||||||
|
|
||||||
async def is_token_expired(self):
|
async def is_token_expired(self):
|
||||||
"""检查token是否过期"""
|
"""检查token是否过期"""
|
||||||
if self.access_token_expiry_time is None:
|
if self.access_token_expiry_time is None:
|
||||||
@@ -61,62 +57,53 @@ class DingTalkClient:
|
|||||||
return False
|
return False
|
||||||
return bool(self.access_token and self.access_token.strip())
|
return bool(self.access_token and self.access_token.strip())
|
||||||
|
|
||||||
async def download_image(self,download_code:str):
|
async def download_image(self, download_code: str):
|
||||||
if not await self.check_access_token():
|
if not await self.check_access_token():
|
||||||
await self.get_access_token()
|
await self.get_access_token()
|
||||||
url = 'https://api.dingtalk.com/v1.0/robot/messageFiles/download'
|
url = 'https://api.dingtalk.com/v1.0/robot/messageFiles/download'
|
||||||
params = {
|
params = {'downloadCode': download_code, 'robotCode': self.robot_code}
|
||||||
"downloadCode":download_code,
|
headers = {'x-acs-dingtalk-access-token': self.access_token}
|
||||||
"robotCode":self.robot_code
|
|
||||||
}
|
|
||||||
headers ={
|
|
||||||
"x-acs-dingtalk-access-token": self.access_token
|
|
||||||
}
|
|
||||||
async with httpx.AsyncClient() as client:
|
async with httpx.AsyncClient() as client:
|
||||||
response = await client.post(url, headers=headers, json=params)
|
response = await client.post(url, headers=headers, json=params)
|
||||||
if response.status_code == 200:
|
if response.status_code == 200:
|
||||||
result = response.json()
|
result = response.json()
|
||||||
download_url = result.get("downloadUrl")
|
download_url = result.get('downloadUrl')
|
||||||
else:
|
else:
|
||||||
raise Exception(f"Error: {response.status_code}, {response.text}")
|
raise Exception(f'Error: {response.status_code}, {response.text}')
|
||||||
|
|
||||||
if download_url:
|
if download_url:
|
||||||
return await self.download_url_to_base64(download_url)
|
return await self.download_url_to_base64(download_url)
|
||||||
|
|
||||||
async def download_url_to_base64(self,download_url):
|
async def download_url_to_base64(self, download_url):
|
||||||
async with httpx.AsyncClient() as client:
|
async with httpx.AsyncClient() as client:
|
||||||
response = await client.get(download_url)
|
response = await client.get(download_url)
|
||||||
|
|
||||||
if response.status_code == 200:
|
if response.status_code == 200:
|
||||||
|
|
||||||
file_bytes = response.content
|
file_bytes = response.content
|
||||||
base64_str = base64.b64encode(file_bytes).decode('utf-8') # 返回字符串格式
|
base64_str = base64.b64encode(file_bytes).decode(
|
||||||
|
'utf-8'
|
||||||
|
) # 返回字符串格式
|
||||||
return base64_str
|
return base64_str
|
||||||
else:
|
else:
|
||||||
raise Exception("获取文件失败")
|
raise Exception('获取文件失败')
|
||||||
|
|
||||||
async def get_audio_url(self,download_code:str):
|
async def get_audio_url(self, download_code: str):
|
||||||
if not await self.check_access_token():
|
if not await self.check_access_token():
|
||||||
await self.get_access_token()
|
await self.get_access_token()
|
||||||
url = 'https://api.dingtalk.com/v1.0/robot/messageFiles/download'
|
url = 'https://api.dingtalk.com/v1.0/robot/messageFiles/download'
|
||||||
params = {
|
params = {'downloadCode': download_code, 'robotCode': self.robot_code}
|
||||||
"downloadCode":download_code,
|
headers = {'x-acs-dingtalk-access-token': self.access_token}
|
||||||
"robotCode":self.robot_code
|
|
||||||
}
|
|
||||||
headers ={
|
|
||||||
"x-acs-dingtalk-access-token": self.access_token
|
|
||||||
}
|
|
||||||
async with httpx.AsyncClient() as client:
|
async with httpx.AsyncClient() as client:
|
||||||
response = await client.post(url, headers=headers, json=params)
|
response = await client.post(url, headers=headers, json=params)
|
||||||
if response.status_code == 200:
|
if response.status_code == 200:
|
||||||
result = response.json()
|
result = response.json()
|
||||||
download_url = result.get("downloadUrl")
|
download_url = result.get('downloadUrl')
|
||||||
if download_url:
|
if download_url:
|
||||||
return await self.download_url_to_base64(download_url)
|
return await self.download_url_to_base64(download_url)
|
||||||
else:
|
else:
|
||||||
raise Exception("获取音频失败")
|
raise Exception('获取音频失败')
|
||||||
else:
|
else:
|
||||||
raise Exception(f"Error: {response.status_code}, {response.text}")
|
raise Exception(f'Error: {response.status_code}, {response.text}')
|
||||||
|
|
||||||
async def update_incoming_message(self, message):
|
async def update_incoming_message(self, message):
|
||||||
"""异步更新 DingTalkClient 中的 incoming_message"""
|
"""异步更新 DingTalkClient 中的 incoming_message"""
|
||||||
@@ -126,23 +113,20 @@ class DingTalkClient:
|
|||||||
if event:
|
if event:
|
||||||
await self._handle_message(event)
|
await self._handle_message(event)
|
||||||
|
|
||||||
|
async def send_message(self, content: str, incoming_message):
|
||||||
async def send_message(self,content:str,incoming_message):
|
self.EchoTextHandler.reply_text(content, incoming_message)
|
||||||
self.EchoTextHandler.reply_text(content,incoming_message)
|
|
||||||
|
|
||||||
|
|
||||||
async def get_incoming_message(self):
|
async def get_incoming_message(self):
|
||||||
"""获取收到的消息"""
|
"""获取收到的消息"""
|
||||||
return await self.EchoTextHandler.get_incoming_message()
|
return await self.EchoTextHandler.get_incoming_message()
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
def on_message(self, msg_type: str):
|
def on_message(self, msg_type: str):
|
||||||
def decorator(func: Callable[[DingTalkEvent], None]):
|
def decorator(func: Callable[[DingTalkEvent], None]):
|
||||||
if msg_type not in self._message_handlers:
|
if msg_type not in self._message_handlers:
|
||||||
self._message_handlers[msg_type] = []
|
self._message_handlers[msg_type] = []
|
||||||
self._message_handlers[msg_type].append(func)
|
self._message_handlers[msg_type].append(func)
|
||||||
return func
|
return func
|
||||||
|
|
||||||
return decorator
|
return decorator
|
||||||
|
|
||||||
async def _handle_message(self, event: DingTalkEvent):
|
async def _handle_message(self, event: DingTalkEvent):
|
||||||
@@ -154,40 +138,44 @@ class DingTalkClient:
|
|||||||
for handler in self._message_handlers[msg_type]:
|
for handler in self._message_handlers[msg_type]:
|
||||||
await handler(event)
|
await handler(event)
|
||||||
|
|
||||||
|
async def get_message(
|
||||||
async def get_message(self,incoming_message:dingtalk_stream.chatbot.ChatbotMessage):
|
self, incoming_message: dingtalk_stream.chatbot.ChatbotMessage
|
||||||
|
):
|
||||||
try:
|
try:
|
||||||
|
|
||||||
# print(json.dumps(incoming_message.to_dict(), indent=4, ensure_ascii=False))
|
# print(json.dumps(incoming_message.to_dict(), indent=4, ensure_ascii=False))
|
||||||
message_data = {
|
message_data = {
|
||||||
"IncomingMessage":incoming_message,
|
'IncomingMessage': incoming_message,
|
||||||
}
|
}
|
||||||
if str(incoming_message.conversation_type) == '1':
|
if str(incoming_message.conversation_type) == '1':
|
||||||
message_data["conversation_type"] = 'FriendMessage'
|
message_data['conversation_type'] = 'FriendMessage'
|
||||||
elif str(incoming_message.conversation_type) == '2':
|
elif str(incoming_message.conversation_type) == '2':
|
||||||
message_data["conversation_type"] = 'GroupMessage'
|
message_data['conversation_type'] = 'GroupMessage'
|
||||||
|
|
||||||
|
|
||||||
if incoming_message.message_type == 'richText':
|
if incoming_message.message_type == 'richText':
|
||||||
|
|
||||||
data = incoming_message.rich_text_content.to_dict()
|
data = incoming_message.rich_text_content.to_dict()
|
||||||
for item in data['richText']:
|
for item in data['richText']:
|
||||||
if 'text' in item:
|
if 'text' in item:
|
||||||
message_data["Content"] = item['text']
|
message_data['Content'] = item['text']
|
||||||
if incoming_message.get_image_list()[0]:
|
if incoming_message.get_image_list()[0]:
|
||||||
message_data["Picture"] = await self.download_image(incoming_message.get_image_list()[0])
|
message_data['Picture'] = await self.download_image(
|
||||||
message_data["Type"] = 'text'
|
incoming_message.get_image_list()[0]
|
||||||
|
)
|
||||||
|
message_data['Type'] = 'text'
|
||||||
|
|
||||||
elif incoming_message.message_type == 'text':
|
elif incoming_message.message_type == 'text':
|
||||||
message_data['Content'] = incoming_message.get_text_list()[0]
|
message_data['Content'] = incoming_message.get_text_list()[0]
|
||||||
|
|
||||||
message_data["Type"] = 'text'
|
message_data['Type'] = 'text'
|
||||||
elif incoming_message.message_type == 'picture':
|
elif incoming_message.message_type == 'picture':
|
||||||
message_data['Picture'] = await self.download_image(incoming_message.get_image_list()[0])
|
message_data['Picture'] = await self.download_image(
|
||||||
|
incoming_message.get_image_list()[0]
|
||||||
|
)
|
||||||
|
|
||||||
message_data['Type'] = 'image'
|
message_data['Type'] = 'image'
|
||||||
elif incoming_message.message_type == 'audio':
|
elif incoming_message.message_type == 'audio':
|
||||||
message_data['Audio'] = await self.get_audio_url(incoming_message.to_dict()['content']['downloadCode'])
|
message_data['Audio'] = await self.get_audio_url(
|
||||||
|
incoming_message.to_dict()['content']['downloadCode']
|
||||||
|
)
|
||||||
|
|
||||||
message_data['Type'] = 'audio'
|
message_data['Type'] = 'audio'
|
||||||
|
|
||||||
@@ -199,50 +187,49 @@ class DingTalkClient:
|
|||||||
|
|
||||||
return message_data
|
return message_data
|
||||||
|
|
||||||
async def send_proactive_message_to_one(self,target_id:str,content:str):
|
async def send_proactive_message_to_one(self, target_id: str, content: str):
|
||||||
if not await self.check_access_token():
|
if not await self.check_access_token():
|
||||||
await self.get_access_token()
|
await self.get_access_token()
|
||||||
|
|
||||||
url = 'https://api.dingtalk.com/v1.0/robot/oToMessages/batchSend'
|
url = 'https://api.dingtalk.com/v1.0/robot/oToMessages/batchSend'
|
||||||
|
|
||||||
headers ={
|
headers = {
|
||||||
"x-acs-dingtalk-access-token":self.access_token,
|
'x-acs-dingtalk-access-token': self.access_token,
|
||||||
"Content-Type":"application/json",
|
'Content-Type': 'application/json',
|
||||||
}
|
}
|
||||||
|
|
||||||
data ={
|
data = {
|
||||||
"robotCode":self.robot_code,
|
'robotCode': self.robot_code,
|
||||||
"userIds":[target_id],
|
'userIds': [target_id],
|
||||||
"msgKey": "sampleText",
|
'msgKey': 'sampleText',
|
||||||
"msgParam": json.dumps({"content":content}),
|
'msgParam': json.dumps({'content': content}),
|
||||||
}
|
}
|
||||||
try:
|
try:
|
||||||
async with httpx.AsyncClient() as client:
|
async with httpx.AsyncClient() as client:
|
||||||
response = await client.post(url,headers=headers,json=data)
|
await client.post(url, headers=headers, json=data)
|
||||||
except Exception:
|
except Exception:
|
||||||
traceback.print_exc()
|
traceback.print_exc()
|
||||||
|
|
||||||
|
async def send_proactive_message_to_group(self, target_id: str, content: str):
|
||||||
async def send_proactive_message_to_group(self,target_id:str,content:str):
|
|
||||||
if not await self.check_access_token():
|
if not await self.check_access_token():
|
||||||
await self.get_access_token()
|
await self.get_access_token()
|
||||||
|
|
||||||
url = 'https://api.dingtalk.com/v1.0/robot/groupMessages/send'
|
url = 'https://api.dingtalk.com/v1.0/robot/groupMessages/send'
|
||||||
|
|
||||||
headers ={
|
headers = {
|
||||||
"x-acs-dingtalk-access-token":self.access_token,
|
'x-acs-dingtalk-access-token': self.access_token,
|
||||||
"Content-Type":"application/json",
|
'Content-Type': 'application/json',
|
||||||
}
|
}
|
||||||
|
|
||||||
data ={
|
data = {
|
||||||
"robotCode":self.robot_code,
|
'robotCode': self.robot_code,
|
||||||
"openConversationId":target_id,
|
'openConversationId': target_id,
|
||||||
"msgKey": "sampleText",
|
'msgKey': 'sampleText',
|
||||||
"msgParam": json.dumps({"content":content}),
|
'msgParam': json.dumps({'content': content}),
|
||||||
}
|
}
|
||||||
try:
|
try:
|
||||||
async with httpx.AsyncClient() as client:
|
async with httpx.AsyncClient() as client:
|
||||||
response = await client.post(url,headers=headers,json=data)
|
await client.post(url, headers=headers, json=data)
|
||||||
except Exception:
|
except Exception:
|
||||||
traceback.print_exc()
|
traceback.print_exc()
|
||||||
|
|
||||||
|
|||||||
@@ -1,41 +1,39 @@
|
|||||||
from typing import Dict, Any, Optional
|
from typing import Dict, Any, Optional
|
||||||
import dingtalk_stream
|
import dingtalk_stream
|
||||||
|
|
||||||
|
|
||||||
class DingTalkEvent(dict):
|
class DingTalkEvent(dict):
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def from_payload(payload: Dict[str, Any]) -> Optional["DingTalkEvent"]:
|
def from_payload(payload: Dict[str, Any]) -> Optional['DingTalkEvent']:
|
||||||
try:
|
try:
|
||||||
event = DingTalkEvent(payload)
|
event = DingTalkEvent(payload)
|
||||||
return event
|
return event
|
||||||
except KeyError:
|
except KeyError:
|
||||||
return None
|
return None
|
||||||
|
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def content(self):
|
def content(self):
|
||||||
return self.get("Content","")
|
return self.get('Content', '')
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def incoming_message(self) -> Optional["dingtalk_stream.chatbot.ChatbotMessage"]:
|
def incoming_message(self) -> Optional['dingtalk_stream.chatbot.ChatbotMessage']:
|
||||||
return self.get("IncomingMessage")
|
return self.get('IncomingMessage')
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def type(self):
|
def type(self):
|
||||||
return self.get("Type","")
|
return self.get('Type', '')
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def picture(self):
|
def picture(self):
|
||||||
return self.get("Picture","")
|
return self.get('Picture', '')
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def audio(self):
|
def audio(self):
|
||||||
return self.get("Audio","")
|
return self.get('Audio', '')
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def conversation(self):
|
def conversation(self):
|
||||||
return self.get("conversation_type","")
|
return self.get('conversation_type', '')
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
def __getattr__(self, key: str) -> Optional[Any]:
|
def __getattr__(self, key: str) -> Optional[Any]:
|
||||||
"""
|
"""
|
||||||
@@ -66,4 +64,4 @@ class DingTalkEvent(dict):
|
|||||||
Returns:
|
Returns:
|
||||||
str: 字符串表示。
|
str: 字符串表示。
|
||||||
"""
|
"""
|
||||||
return f"<DingTalkEvent {super().__repr__()}>"
|
return f'<DingTalkEvent {super().__repr__()}>'
|
||||||
|
|||||||
+115
-105
@@ -1,20 +1,14 @@
|
|||||||
# 微信公众号的加解密算法与企业微信一样,所以直接使用企业微信的加解密算法文件
|
# 微信公众号的加解密算法与企业微信一样,所以直接使用企业微信的加解密算法文件
|
||||||
from collections import deque
|
|
||||||
import time
|
import time
|
||||||
import traceback
|
import traceback
|
||||||
from ..wecom_api.WXBizMsgCrypt3 import WXBizMsgCrypt
|
from ..wecom_api.WXBizMsgCrypt3 import WXBizMsgCrypt
|
||||||
import xml.etree.ElementTree as ET
|
import xml.etree.ElementTree as ET
|
||||||
from quart import Quart,request
|
from quart import Quart, request
|
||||||
import hashlib
|
import hashlib
|
||||||
from typing import Callable, Dict, Any
|
from typing import Callable
|
||||||
from .oaevent import OAEvent
|
from .oaevent import OAEvent
|
||||||
import httpx
|
|
||||||
|
|
||||||
import asyncio
|
import asyncio
|
||||||
import time
|
|
||||||
import xml.etree.ElementTree as ET
|
|
||||||
from pkg.platform.sources import officialaccount as oa
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
xml_template = """
|
xml_template = """
|
||||||
@@ -28,9 +22,8 @@ xml_template = """
|
|||||||
"""
|
"""
|
||||||
|
|
||||||
|
|
||||||
class OAClient():
|
class OAClient:
|
||||||
|
def __init__(self, token: str, EncodingAESKey: str, AppID: str, Appsecret: str):
|
||||||
def __init__(self,token:str,EncodingAESKey:str,AppID:str,Appsecret:str):
|
|
||||||
self.token = token
|
self.token = token
|
||||||
self.aes = EncodingAESKey
|
self.aes = EncodingAESKey
|
||||||
self.appid = AppID
|
self.appid = AppID
|
||||||
@@ -38,65 +31,71 @@ class OAClient():
|
|||||||
self.base_url = 'https://api.weixin.qq.com'
|
self.base_url = 'https://api.weixin.qq.com'
|
||||||
self.access_token = ''
|
self.access_token = ''
|
||||||
self.app = Quart(__name__)
|
self.app = Quart(__name__)
|
||||||
self.app.add_url_rule('/callback/command', 'handle_callback', self.handle_callback_request, methods=['GET', 'POST'])
|
self.app.add_url_rule(
|
||||||
|
'/callback/command',
|
||||||
|
'handle_callback',
|
||||||
|
self.handle_callback_request,
|
||||||
|
methods=['GET', 'POST'],
|
||||||
|
)
|
||||||
self._message_handlers = {
|
self._message_handlers = {
|
||||||
"example":[],
|
'example': [],
|
||||||
}
|
}
|
||||||
self.access_token_expiry_time = None
|
self.access_token_expiry_time = None
|
||||||
self.msg_id_map = {}
|
self.msg_id_map = {}
|
||||||
self.generated_content = {}
|
self.generated_content = {}
|
||||||
|
|
||||||
async def handle_callback_request(self):
|
async def handle_callback_request(self):
|
||||||
|
|
||||||
try:
|
try:
|
||||||
# 每隔100毫秒查询是否生成ai回答
|
# 每隔100毫秒查询是否生成ai回答
|
||||||
start_time = time.time()
|
start_time = time.time()
|
||||||
signature = request.args.get("signature", "")
|
signature = request.args.get('signature', '')
|
||||||
timestamp = request.args.get("timestamp", "")
|
timestamp = request.args.get('timestamp', '')
|
||||||
nonce = request.args.get("nonce", "")
|
nonce = request.args.get('nonce', '')
|
||||||
echostr = request.args.get("echostr", "")
|
echostr = request.args.get('echostr', '')
|
||||||
msg_signature = request.args.get("msg_signature","")
|
msg_signature = request.args.get('msg_signature', '')
|
||||||
if msg_signature is None:
|
if msg_signature is None:
|
||||||
raise Exception("msg_signature不在请求体中")
|
raise Exception('msg_signature不在请求体中')
|
||||||
|
|
||||||
if request.method == 'GET':
|
if request.method == 'GET':
|
||||||
# 校验签名
|
# 校验签名
|
||||||
check_str = "".join(sorted([self.token, timestamp, nonce]))
|
check_str = ''.join(sorted([self.token, timestamp, nonce]))
|
||||||
check_signature = hashlib.sha1(check_str.encode("utf-8")).hexdigest()
|
check_signature = hashlib.sha1(check_str.encode('utf-8')).hexdigest()
|
||||||
|
|
||||||
if check_signature == signature:
|
if check_signature == signature:
|
||||||
return echostr # 验证成功返回echostr
|
return echostr # 验证成功返回echostr
|
||||||
else:
|
else:
|
||||||
raise Exception("拒绝请求")
|
raise Exception('拒绝请求')
|
||||||
elif request.method == "POST":
|
elif request.method == 'POST':
|
||||||
encryt_msg = await request.data
|
encryt_msg = await request.data
|
||||||
wxcpt = WXBizMsgCrypt(self.token,self.aes,self.appid)
|
wxcpt = WXBizMsgCrypt(self.token, self.aes, self.appid)
|
||||||
ret,xml_msg = wxcpt.DecryptMsg(encryt_msg,msg_signature,timestamp,nonce)
|
ret, xml_msg = wxcpt.DecryptMsg(
|
||||||
|
encryt_msg, msg_signature, timestamp, nonce
|
||||||
|
)
|
||||||
xml_msg = xml_msg.decode('utf-8')
|
xml_msg = xml_msg.decode('utf-8')
|
||||||
|
|
||||||
if ret != 0:
|
if ret != 0:
|
||||||
raise Exception("消息解密失败")
|
raise Exception('消息解密失败')
|
||||||
|
|
||||||
message_data = await self.get_message(xml_msg)
|
message_data = await self.get_message(xml_msg)
|
||||||
if message_data :
|
if message_data:
|
||||||
event = OAEvent.from_payload(message_data)
|
event = OAEvent.from_payload(message_data)
|
||||||
if event:
|
if event:
|
||||||
await self._handle_message(event)
|
await self._handle_message(event)
|
||||||
|
|
||||||
root = ET.fromstring(xml_msg)
|
root = ET.fromstring(xml_msg)
|
||||||
from_user = root.find("FromUserName").text # 发送者
|
from_user = root.find('FromUserName').text # 发送者
|
||||||
to_user = root.find("ToUserName").text # 机器人
|
to_user = root.find('ToUserName').text # 机器人
|
||||||
|
|
||||||
timeout = 4.80
|
timeout = 4.80
|
||||||
interval = 0.1
|
interval = 0.1
|
||||||
while True:
|
while True:
|
||||||
content = self.generated_content.pop(message_data["MsgId"], None)
|
content = self.generated_content.pop(message_data['MsgId'], None)
|
||||||
if content:
|
if content:
|
||||||
response_xml = xml_template.format(
|
response_xml = xml_template.format(
|
||||||
to_user=from_user,
|
to_user=from_user,
|
||||||
from_user=to_user,
|
from_user=to_user,
|
||||||
create_time=int(time.time()),
|
create_time=int(time.time()),
|
||||||
content = content
|
content=content,
|
||||||
)
|
)
|
||||||
|
|
||||||
return response_xml
|
return response_xml
|
||||||
@@ -106,53 +105,56 @@ class OAClient():
|
|||||||
|
|
||||||
await asyncio.sleep(interval)
|
await asyncio.sleep(interval)
|
||||||
|
|
||||||
if self.msg_id_map.get(message_data["MsgId"], 1) == 3:
|
if self.msg_id_map.get(message_data['MsgId'], 1) == 3:
|
||||||
|
|
||||||
# response_xml = xml_template.format(
|
# response_xml = xml_template.format(
|
||||||
# to_user=from_user,
|
# to_user=from_user,
|
||||||
# from_user=to_user,
|
# from_user=to_user,
|
||||||
# create_time=int(time.time()),
|
# create_time=int(time.time()),
|
||||||
# content = "请求失效:暂不支持公众号超过15秒的请求,如有需求,请联系 LangBot 团队。"
|
# content = "请求失效:暂不支持公众号超过15秒的请求,如有需求,请联系 LangBot 团队。"
|
||||||
# )
|
# )
|
||||||
print("请求失效:暂不支持公众号超过15秒的请求,如有需求,请联系 LangBot 团队。")
|
print(
|
||||||
|
'请求失效:暂不支持公众号超过15秒的请求,如有需求,请联系 LangBot 团队。'
|
||||||
|
)
|
||||||
return ''
|
return ''
|
||||||
|
|
||||||
except Exception as e:
|
except Exception:
|
||||||
traceback.print_exc()
|
traceback.print_exc()
|
||||||
|
|
||||||
|
|
||||||
async def get_message(self, xml_msg: str):
|
async def get_message(self, xml_msg: str):
|
||||||
|
|
||||||
root = ET.fromstring(xml_msg)
|
root = ET.fromstring(xml_msg)
|
||||||
|
|
||||||
message_data = {
|
message_data = {
|
||||||
"ToUserName": root.find("ToUserName").text,
|
'ToUserName': root.find('ToUserName').text,
|
||||||
"FromUserName": root.find("FromUserName").text,
|
'FromUserName': root.find('FromUserName').text,
|
||||||
"CreateTime": int(root.find("CreateTime").text),
|
'CreateTime': int(root.find('CreateTime').text),
|
||||||
"MsgType": root.find("MsgType").text,
|
'MsgType': root.find('MsgType').text,
|
||||||
"Content": root.find("Content").text if root.find("Content") is not None else None,
|
'Content': root.find('Content').text
|
||||||
"MsgId": int(root.find("MsgId").text) if root.find("MsgId") is not None else None,
|
if root.find('Content') is not None
|
||||||
|
else None,
|
||||||
|
'MsgId': int(root.find('MsgId').text)
|
||||||
|
if root.find('MsgId') is not None
|
||||||
|
else None,
|
||||||
}
|
}
|
||||||
|
|
||||||
return message_data
|
return message_data
|
||||||
|
|
||||||
|
|
||||||
async def run_task(self, host: str, port: int, *args, **kwargs):
|
async def run_task(self, host: str, port: int, *args, **kwargs):
|
||||||
"""
|
"""
|
||||||
启动 Quart 应用。
|
启动 Quart 应用。
|
||||||
"""
|
"""
|
||||||
await self.app.run_task(host=host, port=port, *args, **kwargs)
|
await self.app.run_task(host=host, port=port, *args, **kwargs)
|
||||||
|
|
||||||
|
|
||||||
def on_message(self, msg_type: str):
|
def on_message(self, msg_type: str):
|
||||||
"""
|
"""
|
||||||
注册消息类型处理器。
|
注册消息类型处理器。
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def decorator(func: Callable[[OAEvent], None]):
|
def decorator(func: Callable[[OAEvent], None]):
|
||||||
if msg_type not in self._message_handlers:
|
if msg_type not in self._message_handlers:
|
||||||
self._message_handlers[msg_type] = []
|
self._message_handlers[msg_type] = []
|
||||||
self._message_handlers[msg_type].append(func)
|
self._message_handlers[msg_type].append(func)
|
||||||
return func
|
return func
|
||||||
|
|
||||||
return decorator
|
return decorator
|
||||||
|
|
||||||
async def _handle_message(self, event: OAEvent):
|
async def _handle_message(self, event: OAEvent):
|
||||||
@@ -170,14 +172,19 @@ class OAClient():
|
|||||||
for handler in self._message_handlers[msg_type]:
|
for handler in self._message_handlers[msg_type]:
|
||||||
await handler(event)
|
await handler(event)
|
||||||
|
|
||||||
async def set_message(self,msg_id:int,content:str):
|
async def set_message(self, msg_id: int, content: str):
|
||||||
self.generated_content[msg_id] = content
|
self.generated_content[msg_id] = content
|
||||||
|
|
||||||
|
|
||||||
|
class OAClientForLongerResponse:
|
||||||
class OAClientForLongerResponse():
|
def __init__(
|
||||||
|
self,
|
||||||
def __init__(self,token:str,EncodingAESKey:str,AppID:str,Appsecret:str,LoadingMessage:str):
|
token: str,
|
||||||
|
EncodingAESKey: str,
|
||||||
|
AppID: str,
|
||||||
|
Appsecret: str,
|
||||||
|
LoadingMessage: str,
|
||||||
|
):
|
||||||
self.token = token
|
self.token = token
|
||||||
self.aes = EncodingAESKey
|
self.aes = EncodingAESKey
|
||||||
self.appid = AppID
|
self.appid = AppID
|
||||||
@@ -185,9 +192,14 @@ class OAClientForLongerResponse():
|
|||||||
self.base_url = 'https://api.weixin.qq.com'
|
self.base_url = 'https://api.weixin.qq.com'
|
||||||
self.access_token = ''
|
self.access_token = ''
|
||||||
self.app = Quart(__name__)
|
self.app = Quart(__name__)
|
||||||
self.app.add_url_rule('/callback/command', 'handle_callback', self.handle_callback_request, methods=['GET', 'POST'])
|
self.app.add_url_rule(
|
||||||
|
'/callback/command',
|
||||||
|
'handle_callback',
|
||||||
|
self.handle_callback_request,
|
||||||
|
methods=['GET', 'POST'],
|
||||||
|
)
|
||||||
self._message_handlers = {
|
self._message_handlers = {
|
||||||
"example":[],
|
'example': [],
|
||||||
}
|
}
|
||||||
self.access_token_expiry_time = None
|
self.access_token_expiry_time = None
|
||||||
self.loading_message = LoadingMessage
|
self.loading_message = LoadingMessage
|
||||||
@@ -196,50 +208,55 @@ class OAClientForLongerResponse():
|
|||||||
|
|
||||||
async def handle_callback_request(self):
|
async def handle_callback_request(self):
|
||||||
try:
|
try:
|
||||||
start_time = time.time()
|
signature = request.args.get('signature', '')
|
||||||
signature = request.args.get("signature", "")
|
timestamp = request.args.get('timestamp', '')
|
||||||
timestamp = request.args.get("timestamp", "")
|
nonce = request.args.get('nonce', '')
|
||||||
nonce = request.args.get("nonce", "")
|
echostr = request.args.get('echostr', '')
|
||||||
echostr = request.args.get("echostr", "")
|
msg_signature = request.args.get('msg_signature', '')
|
||||||
msg_signature = request.args.get("msg_signature", "")
|
|
||||||
|
|
||||||
if msg_signature is None:
|
if msg_signature is None:
|
||||||
raise Exception("msg_signature不在请求体中")
|
raise Exception('msg_signature不在请求体中')
|
||||||
|
|
||||||
if request.method == 'GET':
|
if request.method == 'GET':
|
||||||
check_str = "".join(sorted([self.token, timestamp, nonce]))
|
check_str = ''.join(sorted([self.token, timestamp, nonce]))
|
||||||
check_signature = hashlib.sha1(check_str.encode("utf-8")).hexdigest()
|
check_signature = hashlib.sha1(check_str.encode('utf-8')).hexdigest()
|
||||||
return echostr if check_signature == signature else "拒绝请求"
|
return echostr if check_signature == signature else '拒绝请求'
|
||||||
|
|
||||||
elif request.method == "POST":
|
elif request.method == 'POST':
|
||||||
encryt_msg = await request.data
|
encryt_msg = await request.data
|
||||||
wxcpt = WXBizMsgCrypt(self.token, self.aes, self.appid)
|
wxcpt = WXBizMsgCrypt(self.token, self.aes, self.appid)
|
||||||
ret, xml_msg = wxcpt.DecryptMsg(encryt_msg, msg_signature, timestamp, nonce)
|
ret, xml_msg = wxcpt.DecryptMsg(
|
||||||
|
encryt_msg, msg_signature, timestamp, nonce
|
||||||
|
)
|
||||||
xml_msg = xml_msg.decode('utf-8')
|
xml_msg = xml_msg.decode('utf-8')
|
||||||
|
|
||||||
if ret != 0:
|
if ret != 0:
|
||||||
raise Exception("消息解密失败")
|
raise Exception('消息解密失败')
|
||||||
|
|
||||||
# 解析 XML
|
# 解析 XML
|
||||||
root = ET.fromstring(xml_msg)
|
root = ET.fromstring(xml_msg)
|
||||||
from_user = root.find("FromUserName").text
|
from_user = root.find('FromUserName').text
|
||||||
to_user = root.find("ToUserName").text
|
to_user = root.find('ToUserName').text
|
||||||
|
|
||||||
|
|
||||||
if self.msg_queue.get(from_user) and self.msg_queue[from_user][0]["content"]:
|
|
||||||
|
|
||||||
|
if (
|
||||||
|
self.msg_queue.get(from_user)
|
||||||
|
and self.msg_queue[from_user][0]['content']
|
||||||
|
):
|
||||||
queue_top = self.msg_queue[from_user].pop(0)
|
queue_top = self.msg_queue[from_user].pop(0)
|
||||||
queue_content = queue_top["content"]
|
queue_content = queue_top['content']
|
||||||
|
|
||||||
# 弹出用户消息
|
# 弹出用户消息
|
||||||
if self.user_msg_queue.get(from_user) and self.user_msg_queue[from_user]:
|
if (
|
||||||
|
self.user_msg_queue.get(from_user)
|
||||||
|
and self.user_msg_queue[from_user]
|
||||||
|
):
|
||||||
self.user_msg_queue[from_user].pop(0)
|
self.user_msg_queue[from_user].pop(0)
|
||||||
|
|
||||||
response_xml = xml_template.format(
|
response_xml = xml_template.format(
|
||||||
to_user=from_user,
|
to_user=from_user,
|
||||||
from_user=to_user,
|
from_user=to_user,
|
||||||
create_time=int(time.time()),
|
create_time=int(time.time()),
|
||||||
content=queue_content
|
content=queue_content,
|
||||||
)
|
)
|
||||||
return response_xml
|
return response_xml
|
||||||
|
|
||||||
@@ -248,10 +265,13 @@ class OAClientForLongerResponse():
|
|||||||
to_user=from_user,
|
to_user=from_user,
|
||||||
from_user=to_user,
|
from_user=to_user,
|
||||||
create_time=int(time.time()),
|
create_time=int(time.time()),
|
||||||
content=self.loading_message
|
content=self.loading_message,
|
||||||
)
|
)
|
||||||
|
|
||||||
if self.user_msg_queue.get(from_user) and self.user_msg_queue[from_user][0]["content"]:
|
if (
|
||||||
|
self.user_msg_queue.get(from_user)
|
||||||
|
and self.user_msg_queue[from_user][0]['content']
|
||||||
|
):
|
||||||
return response_xml
|
return response_xml
|
||||||
else:
|
else:
|
||||||
message_data = await self.get_message(xml_msg)
|
message_data = await self.get_message(xml_msg)
|
||||||
@@ -259,54 +279,53 @@ class OAClientForLongerResponse():
|
|||||||
if message_data:
|
if message_data:
|
||||||
event = OAEvent.from_payload(message_data)
|
event = OAEvent.from_payload(message_data)
|
||||||
if event:
|
if event:
|
||||||
self.user_msg_queue.setdefault(from_user,[]).append(
|
self.user_msg_queue.setdefault(from_user, []).append(
|
||||||
{
|
{
|
||||||
"content":event.message,
|
'content': event.message,
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
await self._handle_message(event)
|
await self._handle_message(event)
|
||||||
|
|
||||||
return response_xml
|
return response_xml
|
||||||
|
|
||||||
except Exception as e:
|
except Exception:
|
||||||
traceback.print_exc()
|
traceback.print_exc()
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
async def get_message(self, xml_msg: str):
|
async def get_message(self, xml_msg: str):
|
||||||
|
|
||||||
root = ET.fromstring(xml_msg)
|
root = ET.fromstring(xml_msg)
|
||||||
|
|
||||||
message_data = {
|
message_data = {
|
||||||
"ToUserName": root.find("ToUserName").text,
|
'ToUserName': root.find('ToUserName').text,
|
||||||
"FromUserName": root.find("FromUserName").text,
|
'FromUserName': root.find('FromUserName').text,
|
||||||
"CreateTime": int(root.find("CreateTime").text),
|
'CreateTime': int(root.find('CreateTime').text),
|
||||||
"MsgType": root.find("MsgType").text,
|
'MsgType': root.find('MsgType').text,
|
||||||
"Content": root.find("Content").text if root.find("Content") is not None else None,
|
'Content': root.find('Content').text
|
||||||
"MsgId": int(root.find("MsgId").text) if root.find("MsgId") is not None else None,
|
if root.find('Content') is not None
|
||||||
|
else None,
|
||||||
|
'MsgId': int(root.find('MsgId').text)
|
||||||
|
if root.find('MsgId') is not None
|
||||||
|
else None,
|
||||||
}
|
}
|
||||||
|
|
||||||
return message_data
|
return message_data
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
async def run_task(self, host: str, port: int, *args, **kwargs):
|
async def run_task(self, host: str, port: int, *args, **kwargs):
|
||||||
"""
|
"""
|
||||||
启动 Quart 应用。
|
启动 Quart 应用。
|
||||||
"""
|
"""
|
||||||
await self.app.run_task(host=host, port=port, *args, **kwargs)
|
await self.app.run_task(host=host, port=port, *args, **kwargs)
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
def on_message(self, msg_type: str):
|
def on_message(self, msg_type: str):
|
||||||
"""
|
"""
|
||||||
注册消息类型处理器。
|
注册消息类型处理器。
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def decorator(func: Callable[[OAEvent], None]):
|
def decorator(func: Callable[[OAEvent], None]):
|
||||||
if msg_type not in self._message_handlers:
|
if msg_type not in self._message_handlers:
|
||||||
self._message_handlers[msg_type] = []
|
self._message_handlers[msg_type] = []
|
||||||
self._message_handlers[msg_type].append(func)
|
self._message_handlers[msg_type].append(func)
|
||||||
return func
|
return func
|
||||||
|
|
||||||
return decorator
|
return decorator
|
||||||
|
|
||||||
async def _handle_message(self, event: OAEvent):
|
async def _handle_message(self, event: OAEvent):
|
||||||
@@ -319,22 +338,13 @@ class OAClientForLongerResponse():
|
|||||||
for handler in self._message_handlers[msg_type]:
|
for handler in self._message_handlers[msg_type]:
|
||||||
await handler(event)
|
await handler(event)
|
||||||
|
|
||||||
async def set_message(self,from_user:int,message_id:int,content:str):
|
async def set_message(self, from_user: int, message_id: int, content: str):
|
||||||
if from_user not in self.msg_queue:
|
if from_user not in self.msg_queue:
|
||||||
self.msg_queue[from_user] = []
|
self.msg_queue[from_user] = []
|
||||||
|
|
||||||
self.msg_queue[from_user].append(
|
self.msg_queue[from_user].append(
|
||||||
{
|
{
|
||||||
"msg_id":message_id,
|
'msg_id': message_id,
|
||||||
"content":content,
|
'content': content,
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -9,7 +9,7 @@ class OAEvent(dict):
|
|||||||
"""
|
"""
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def from_payload(payload: Dict[str, Any]) -> Optional["OAEvent"]:
|
def from_payload(payload: Dict[str, Any]) -> Optional['OAEvent']:
|
||||||
"""
|
"""
|
||||||
从微信公众号事件数据构造 `WecomEvent` 对象。
|
从微信公众号事件数据构造 `WecomEvent` 对象。
|
||||||
|
|
||||||
@@ -34,14 +34,14 @@ class OAEvent(dict):
|
|||||||
Returns:
|
Returns:
|
||||||
str: 事件类型。
|
str: 事件类型。
|
||||||
"""
|
"""
|
||||||
return self.get("MsgType", "")
|
return self.get('MsgType', '')
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def picurl(self) -> str:
|
def picurl(self) -> str:
|
||||||
"""
|
"""
|
||||||
图片链接
|
图片链接
|
||||||
"""
|
"""
|
||||||
return self.get("PicUrl","")
|
return self.get('PicUrl', '')
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def detail_type(self) -> str:
|
def detail_type(self) -> str:
|
||||||
@@ -53,8 +53,8 @@ class OAEvent(dict):
|
|||||||
Returns:
|
Returns:
|
||||||
str: 事件详细类型。
|
str: 事件详细类型。
|
||||||
"""
|
"""
|
||||||
if self.type == "event":
|
if self.type == 'event':
|
||||||
return self.get("Event", "")
|
return self.get('Event', '')
|
||||||
return self.type
|
return self.type
|
||||||
|
|
||||||
@property
|
@property
|
||||||
@@ -65,15 +65,14 @@ class OAEvent(dict):
|
|||||||
Returns:
|
Returns:
|
||||||
str: 事件名。
|
str: 事件名。
|
||||||
"""
|
"""
|
||||||
return f"{self.type}.{self.detail_type}"
|
return f'{self.type}.{self.detail_type}'
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def user_id(self) -> Optional[str]:
|
def user_id(self) -> Optional[str]:
|
||||||
"""
|
"""
|
||||||
发送方账号
|
发送方账号
|
||||||
"""
|
"""
|
||||||
return self.get("FromUserName")
|
return self.get('FromUserName')
|
||||||
|
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def receiver_id(self) -> Optional[str]:
|
def receiver_id(self) -> Optional[str]:
|
||||||
@@ -83,7 +82,7 @@ class OAEvent(dict):
|
|||||||
Returns:
|
Returns:
|
||||||
Optional[str]: 接收者 ID。
|
Optional[str]: 接收者 ID。
|
||||||
"""
|
"""
|
||||||
return self.get("ToUserName")
|
return self.get('ToUserName')
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def message_id(self) -> Optional[str]:
|
def message_id(self) -> Optional[str]:
|
||||||
@@ -93,7 +92,7 @@ class OAEvent(dict):
|
|||||||
Returns:
|
Returns:
|
||||||
Optional[str]: 消息 ID。
|
Optional[str]: 消息 ID。
|
||||||
"""
|
"""
|
||||||
return self.get("MsgId")
|
return self.get('MsgId')
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def message(self) -> Optional[str]:
|
def message(self) -> Optional[str]:
|
||||||
@@ -103,7 +102,7 @@ class OAEvent(dict):
|
|||||||
Returns:
|
Returns:
|
||||||
Optional[str]: 消息内容。
|
Optional[str]: 消息内容。
|
||||||
"""
|
"""
|
||||||
return self.get("Content")
|
return self.get('Content')
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def media_id(self) -> Optional[str]:
|
def media_id(self) -> Optional[str]:
|
||||||
@@ -113,7 +112,7 @@ class OAEvent(dict):
|
|||||||
Returns:
|
Returns:
|
||||||
Optional[str]: 媒体文件 ID。
|
Optional[str]: 媒体文件 ID。
|
||||||
"""
|
"""
|
||||||
return self.get("MediaId")
|
return self.get('MediaId')
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def timestamp(self) -> Optional[int]:
|
def timestamp(self) -> Optional[int]:
|
||||||
@@ -123,7 +122,7 @@ class OAEvent(dict):
|
|||||||
Returns:
|
Returns:
|
||||||
Optional[int]: 时间戳。
|
Optional[int]: 时间戳。
|
||||||
"""
|
"""
|
||||||
return self.get("CreateTime")
|
return self.get('CreateTime')
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def event_key(self) -> Optional[str]:
|
def event_key(self) -> Optional[str]:
|
||||||
@@ -133,7 +132,7 @@ class OAEvent(dict):
|
|||||||
Returns:
|
Returns:
|
||||||
Optional[str]: 事件 Key。
|
Optional[str]: 事件 Key。
|
||||||
"""
|
"""
|
||||||
return self.get("EventKey")
|
return self.get('EventKey')
|
||||||
|
|
||||||
def __getattr__(self, key: str) -> Optional[Any]:
|
def __getattr__(self, key: str) -> Optional[Any]:
|
||||||
"""
|
"""
|
||||||
@@ -164,4 +163,4 @@ class OAEvent(dict):
|
|||||||
Returns:
|
Returns:
|
||||||
str: 字符串表示。
|
str: 字符串表示。
|
||||||
"""
|
"""
|
||||||
return f"<WecomEvent {super().__repr__()}>"
|
return f'<WecomEvent {super().__repr__()}>'
|
||||||
|
|||||||
+96
-104
@@ -1,24 +1,16 @@
|
|||||||
import time
|
import time
|
||||||
from quart import request
|
from quart import request
|
||||||
import base64
|
|
||||||
import binascii
|
|
||||||
import httpx
|
import httpx
|
||||||
from quart import Quart
|
from quart import Quart
|
||||||
import xml.etree.ElementTree as ET
|
|
||||||
from typing import Callable, Dict, Any
|
from typing import Callable, Dict, Any
|
||||||
from pkg.platform.types import events as platform_events, message as platform_message
|
from pkg.platform.types import events as platform_events
|
||||||
import aiofiles
|
|
||||||
from .qqofficialevent import QQOfficialEvent
|
from .qqofficialevent import QQOfficialEvent
|
||||||
import json
|
import json
|
||||||
import hmac
|
|
||||||
import base64
|
|
||||||
import hashlib
|
|
||||||
import traceback
|
import traceback
|
||||||
from cryptography.hazmat.primitives.asymmetric import ed25519
|
from cryptography.hazmat.primitives.asymmetric import ed25519
|
||||||
from .qqofficialevent import QQOfficialEvent
|
|
||||||
|
|
||||||
def handle_validation(body: dict, bot_secret: str):
|
def handle_validation(body: dict, bot_secret: str):
|
||||||
|
|
||||||
# bot正确的secert是32位的,此处仅为了适配演示demo
|
# bot正确的secert是32位的,此处仅为了适配演示demo
|
||||||
while len(bot_secret) < 32:
|
while len(bot_secret) < 32:
|
||||||
bot_secret = bot_secret * 2
|
bot_secret = bot_secret * 2
|
||||||
@@ -36,29 +28,26 @@ def handle_validation(body: dict, bot_secret: str):
|
|||||||
|
|
||||||
signature_hex = signature.hex()
|
signature_hex = signature.hex()
|
||||||
|
|
||||||
response = {
|
response = {'plain_token': body['d']['plain_token'], 'signature': signature_hex}
|
||||||
"plain_token": body['d']['plain_token'],
|
|
||||||
"signature": signature_hex
|
|
||||||
}
|
|
||||||
|
|
||||||
return response
|
return response
|
||||||
|
|
||||||
|
|
||||||
class QQOfficialClient:
|
class QQOfficialClient:
|
||||||
def __init__(self, secret: str, token: str, app_id: str):
|
def __init__(self, secret: str, token: str, app_id: str):
|
||||||
self.app = Quart(__name__)
|
self.app = Quart(__name__)
|
||||||
self.app.add_url_rule(
|
self.app.add_url_rule(
|
||||||
"/callback/command",
|
'/callback/command',
|
||||||
"handle_callback",
|
'handle_callback',
|
||||||
self.handle_callback_request,
|
self.handle_callback_request,
|
||||||
methods=["GET", "POST"],
|
methods=['GET', 'POST'],
|
||||||
)
|
)
|
||||||
self.secret = secret
|
self.secret = secret
|
||||||
self.token = token
|
self.token = token
|
||||||
self.app_id = app_id
|
self.app_id = app_id
|
||||||
self._message_handlers = {
|
self._message_handlers = {}
|
||||||
}
|
self.base_url = 'https://api.sgroup.qq.com'
|
||||||
self.base_url = "https://api.sgroup.qq.com"
|
self.access_token = ''
|
||||||
self.access_token = ""
|
|
||||||
self.access_token_expiry_time = None
|
self.access_token_expiry_time = None
|
||||||
|
|
||||||
async def check_access_token(self):
|
async def check_access_token(self):
|
||||||
@@ -69,27 +58,26 @@ class QQOfficialClient:
|
|||||||
|
|
||||||
async def get_access_token(self):
|
async def get_access_token(self):
|
||||||
"""获取access_token"""
|
"""获取access_token"""
|
||||||
url = "https://bots.qq.com/app/getAppAccessToken"
|
url = 'https://bots.qq.com/app/getAppAccessToken'
|
||||||
async with httpx.AsyncClient() as client:
|
async with httpx.AsyncClient() as client:
|
||||||
params = {
|
params = {
|
||||||
"appId":self.app_id,
|
'appId': self.app_id,
|
||||||
"clientSecret":self.secret,
|
'clientSecret': self.secret,
|
||||||
}
|
}
|
||||||
headers = {
|
headers = {
|
||||||
"content-type":"application/json",
|
'content-type': 'application/json',
|
||||||
}
|
}
|
||||||
try:
|
try:
|
||||||
response = await client.post(url,json=params,headers=headers)
|
response = await client.post(url, json=params, headers=headers)
|
||||||
if response.status_code == 200:
|
if response.status_code == 200:
|
||||||
response_data = response.json()
|
response_data = response.json()
|
||||||
access_token = response_data.get("access_token")
|
access_token = response_data.get('access_token')
|
||||||
expires_in = int(response_data.get("expires_in",7200))
|
expires_in = int(response_data.get('expires_in', 7200))
|
||||||
self.access_token_expiry_time = time.time() + expires_in - 60
|
self.access_token_expiry_time = time.time() + expires_in - 60
|
||||||
if access_token:
|
if access_token:
|
||||||
self.access_token = access_token
|
self.access_token = access_token
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
raise Exception(f"获取access_token失败: {e}")
|
raise Exception(f'获取access_token失败: {e}')
|
||||||
|
|
||||||
|
|
||||||
async def handle_callback_request(self):
|
async def handle_callback_request(self):
|
||||||
"""处理回调请求"""
|
"""处理回调请求"""
|
||||||
@@ -98,27 +86,24 @@ class QQOfficialClient:
|
|||||||
body = await request.get_data()
|
body = await request.get_data()
|
||||||
payload = json.loads(body)
|
payload = json.loads(body)
|
||||||
|
|
||||||
|
|
||||||
# 验证是否为回调验证请求
|
# 验证是否为回调验证请求
|
||||||
if payload.get("op") == 13:
|
if payload.get('op') == 13:
|
||||||
# 生成签名
|
# 生成签名
|
||||||
response = handle_validation(payload, self.secret)
|
response = handle_validation(payload, self.secret)
|
||||||
|
|
||||||
return response
|
return response
|
||||||
|
|
||||||
if payload.get("op") == 0:
|
if payload.get('op') == 0:
|
||||||
message_data = await self.get_message(payload)
|
message_data = await self.get_message(payload)
|
||||||
if message_data:
|
if message_data:
|
||||||
event = QQOfficialEvent.from_payload(message_data)
|
event = QQOfficialEvent.from_payload(message_data)
|
||||||
await self._handle_message(event)
|
await self._handle_message(event)
|
||||||
|
|
||||||
return {"code": 0, "message": "success"}
|
return {'code': 0, 'message': 'success'}
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
traceback.print_exc()
|
traceback.print_exc()
|
||||||
return {"error": str(e)}, 400
|
return {'error': str(e)}, 400
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
async def run_task(self, host: str, port: int, *args, **kwargs):
|
async def run_task(self, host: str, port: int, *args, **kwargs):
|
||||||
"""启动 Quart 应用"""
|
"""启动 Quart 应用"""
|
||||||
@@ -135,133 +120,140 @@ class QQOfficialClient:
|
|||||||
|
|
||||||
return decorator
|
return decorator
|
||||||
|
|
||||||
async def _handle_message(self, event:QQOfficialEvent):
|
async def _handle_message(self, event: QQOfficialEvent):
|
||||||
"""处理消息事件"""
|
"""处理消息事件"""
|
||||||
msg_type = event.t
|
msg_type = event.t
|
||||||
if msg_type in self._message_handlers:
|
if msg_type in self._message_handlers:
|
||||||
for handler in self._message_handlers[msg_type]:
|
for handler in self._message_handlers[msg_type]:
|
||||||
await handler(event)
|
await handler(event)
|
||||||
|
|
||||||
|
async def get_message(self, msg: dict) -> Dict[str, Any]:
|
||||||
async def get_message(self,msg:dict) -> Dict[str,Any]:
|
|
||||||
"""获取消息"""
|
"""获取消息"""
|
||||||
message_data = {
|
message_data = {
|
||||||
"t": msg.get("t",{}),
|
't': msg.get('t', {}),
|
||||||
"user_openid": msg.get("d",{}).get("author",{}).get("user_openid",{}),
|
'user_openid': msg.get('d', {}).get('author', {}).get('user_openid', {}),
|
||||||
"timestamp": msg.get("d",{}).get("timestamp",{}),
|
'timestamp': msg.get('d', {}).get('timestamp', {}),
|
||||||
"d_author_id": msg.get("d",{}).get("author",{}).get("id",{}),
|
'd_author_id': msg.get('d', {}).get('author', {}).get('id', {}),
|
||||||
"content": msg.get("d",{}).get("content",{}),
|
'content': msg.get('d', {}).get('content', {}),
|
||||||
"d_id": msg.get("d",{}).get("id",{}),
|
'd_id': msg.get('d', {}).get('id', {}),
|
||||||
"id": msg.get("id",{}),
|
'id': msg.get('id', {}),
|
||||||
"channel_id": msg.get("d",{}).get("channel_id",{}),
|
'channel_id': msg.get('d', {}).get('channel_id', {}),
|
||||||
"username": msg.get("d",{}).get("author",{}).get("username",{}),
|
'username': msg.get('d', {}).get('author', {}).get('username', {}),
|
||||||
"guild_id": msg.get("d",{}).get("guild_id",{}),
|
'guild_id': msg.get('d', {}).get('guild_id', {}),
|
||||||
"member_openid": msg.get("d",{}).get("author",{}).get("openid",{}),
|
'member_openid': msg.get('d', {}).get('author', {}).get('openid', {}),
|
||||||
"group_openid": msg.get("d",{}).get("group_openid",{})
|
'group_openid': msg.get('d', {}).get('group_openid', {}),
|
||||||
}
|
}
|
||||||
attachments = msg.get("d", {}).get("attachments", [])
|
attachments = msg.get('d', {}).get('attachments', [])
|
||||||
image_attachments = [attachment['url'] for attachment in attachments if await self.is_image(attachment)]
|
image_attachments = [
|
||||||
image_attachments_type = [attachment['content_type'] for attachment in attachments if await self.is_image(attachment)]
|
attachment['url']
|
||||||
|
for attachment in attachments
|
||||||
|
if await self.is_image(attachment)
|
||||||
|
]
|
||||||
|
image_attachments_type = [
|
||||||
|
attachment['content_type']
|
||||||
|
for attachment in attachments
|
||||||
|
if await self.is_image(attachment)
|
||||||
|
]
|
||||||
if image_attachments:
|
if image_attachments:
|
||||||
message_data["image_attachments"] = image_attachments[0]
|
message_data['image_attachments'] = image_attachments[0]
|
||||||
message_data["content_type"] = image_attachments_type[0]
|
message_data['content_type'] = image_attachments_type[0]
|
||||||
else:
|
else:
|
||||||
|
message_data['image_attachments'] = None
|
||||||
message_data["image_attachments"] = None
|
|
||||||
|
|
||||||
return message_data
|
return message_data
|
||||||
|
|
||||||
|
async def is_image(self, attachment: dict) -> bool:
|
||||||
async def is_image(self,attachment:dict) -> bool:
|
|
||||||
"""判断是否为图片附件"""
|
"""判断是否为图片附件"""
|
||||||
content_type = attachment.get("content_type","")
|
content_type = attachment.get('content_type', '')
|
||||||
return content_type.startswith("image/")
|
return content_type.startswith('image/')
|
||||||
|
|
||||||
|
async def send_private_text_msg(self, user_openid: str, content: str, msg_id: str):
|
||||||
async def send_private_text_msg(self,user_openid:str,content:str,msg_id:str):
|
|
||||||
"""发送私聊消息"""
|
"""发送私聊消息"""
|
||||||
if not await self.check_access_token():
|
if not await self.check_access_token():
|
||||||
await self.get_access_token()
|
await self.get_access_token()
|
||||||
|
|
||||||
url = self.base_url + "/v2/users/" + user_openid + "/messages"
|
url = self.base_url + '/v2/users/' + user_openid + '/messages'
|
||||||
async with httpx.AsyncClient() as client:
|
async with httpx.AsyncClient() as client:
|
||||||
headers = {
|
headers = {
|
||||||
"Authorization": f"QQBot {self.access_token}",
|
'Authorization': f'QQBot {self.access_token}',
|
||||||
"Content-Type": "application/json",
|
'Content-Type': 'application/json',
|
||||||
}
|
}
|
||||||
data = {
|
data = {
|
||||||
"content": content,
|
'content': content,
|
||||||
"msg_type": 0,
|
'msg_type': 0,
|
||||||
"msg_id": msg_id,
|
'msg_id': msg_id,
|
||||||
}
|
}
|
||||||
response = await client.post(url,headers=headers,json=data)
|
response = await client.post(url, headers=headers, json=data)
|
||||||
if response.status_code == 200:
|
if response.status_code == 200:
|
||||||
return
|
return
|
||||||
else:
|
else:
|
||||||
raise ValueError(response)
|
raise ValueError(response)
|
||||||
|
|
||||||
|
async def send_group_text_msg(self, group_openid: str, content: str, msg_id: str):
|
||||||
async def send_group_text_msg(self,group_openid:str,content:str,msg_id:str):
|
|
||||||
"""发送群聊消息"""
|
"""发送群聊消息"""
|
||||||
if not await self.check_access_token():
|
if not await self.check_access_token():
|
||||||
await self.get_access_token()
|
await self.get_access_token()
|
||||||
|
|
||||||
url = self.base_url + "/v2/groups/" + group_openid + "/messages"
|
url = self.base_url + '/v2/groups/' + group_openid + '/messages'
|
||||||
async with httpx.AsyncClient() as client:
|
async with httpx.AsyncClient() as client:
|
||||||
headers = {
|
headers = {
|
||||||
"Authorization": f"QQBot {self.access_token}",
|
'Authorization': f'QQBot {self.access_token}',
|
||||||
"Content-Type": "application/json",
|
'Content-Type': 'application/json',
|
||||||
}
|
}
|
||||||
data = {
|
data = {
|
||||||
"content": content,
|
'content': content,
|
||||||
"msg_type": 0,
|
'msg_type': 0,
|
||||||
"msg_id": msg_id,
|
'msg_id': msg_id,
|
||||||
}
|
}
|
||||||
response = await client.post(url,headers=headers,json=data)
|
response = await client.post(url, headers=headers, json=data)
|
||||||
if response.status_code == 200:
|
if response.status_code == 200:
|
||||||
return
|
return
|
||||||
else:
|
else:
|
||||||
raise Exception(response.read().decode())
|
raise Exception(response.read().decode())
|
||||||
|
|
||||||
async def send_channle_group_text_msg(self,channel_id:str,content:str,msg_id:str):
|
async def send_channle_group_text_msg(
|
||||||
|
self, channel_id: str, content: str, msg_id: str
|
||||||
|
):
|
||||||
"""发送频道群聊消息"""
|
"""发送频道群聊消息"""
|
||||||
if not await self.check_access_token():
|
if not await self.check_access_token():
|
||||||
await self.get_access_token()
|
await self.get_access_token()
|
||||||
|
|
||||||
url = self.base_url + "/channels/" + channel_id + "/messages"
|
url = self.base_url + '/channels/' + channel_id + '/messages'
|
||||||
async with httpx.AsyncClient() as client:
|
async with httpx.AsyncClient() as client:
|
||||||
headers = {
|
headers = {
|
||||||
"Authorization": f"QQBot {self.access_token}",
|
'Authorization': f'QQBot {self.access_token}',
|
||||||
"Content-Type": "application/json",
|
'Content-Type': 'application/json',
|
||||||
}
|
}
|
||||||
params = {
|
params = {
|
||||||
"content": content,
|
'content': content,
|
||||||
"msg_type": 0,
|
'msg_type': 0,
|
||||||
"msg_id": msg_id,
|
'msg_id': msg_id,
|
||||||
}
|
}
|
||||||
response = await client.post(url,headers=headers,json=params)
|
response = await client.post(url, headers=headers, json=params)
|
||||||
if response.status_code == 200:
|
if response.status_code == 200:
|
||||||
return True
|
return True
|
||||||
else:
|
else:
|
||||||
raise Exception(response)
|
raise Exception(response)
|
||||||
|
|
||||||
async def send_channle_private_text_msg(self,guild_id:str,content:str,msg_id:str):
|
async def send_channle_private_text_msg(
|
||||||
|
self, guild_id: str, content: str, msg_id: str
|
||||||
|
):
|
||||||
"""发送频道私聊消息"""
|
"""发送频道私聊消息"""
|
||||||
if not await self.check_access_token():
|
if not await self.check_access_token():
|
||||||
await self.get_access_token()
|
await self.get_access_token()
|
||||||
|
|
||||||
url = self.base_url + "/dms/" + guild_id + "/messages"
|
url = self.base_url + '/dms/' + guild_id + '/messages'
|
||||||
async with httpx.AsyncClient() as client:
|
async with httpx.AsyncClient() as client:
|
||||||
headers = {
|
headers = {
|
||||||
"Authorization": f"QQBot {self.access_token}",
|
'Authorization': f'QQBot {self.access_token}',
|
||||||
"Content-Type": "application/json",
|
'Content-Type': 'application/json',
|
||||||
}
|
}
|
||||||
params = {
|
params = {
|
||||||
"content": content,
|
'content': content,
|
||||||
"msg_type": 0,
|
'msg_type': 0,
|
||||||
"msg_id": msg_id,
|
'msg_id': msg_id,
|
||||||
}
|
}
|
||||||
response = await client.post(url,headers=headers,json=params)
|
response = await client.post(url, headers=headers, json=params)
|
||||||
if response.status_code == 200:
|
if response.status_code == 200:
|
||||||
return True
|
return True
|
||||||
else:
|
else:
|
||||||
|
|||||||
@@ -1,101 +1,100 @@
|
|||||||
from typing import Dict, Any, Optional
|
from typing import Dict, Any, Optional
|
||||||
|
|
||||||
|
|
||||||
class QQOfficialEvent(dict):
|
class QQOfficialEvent(dict):
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def from_payload(payload: Dict[str, Any]) -> Optional["QQOfficialEvent"]:
|
def from_payload(payload: Dict[str, Any]) -> Optional['QQOfficialEvent']:
|
||||||
try:
|
try:
|
||||||
event = QQOfficialEvent(payload)
|
event = QQOfficialEvent(payload)
|
||||||
return event
|
return event
|
||||||
except KeyError:
|
except KeyError:
|
||||||
return None
|
return None
|
||||||
|
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def t(self) -> str:
|
def t(self) -> str:
|
||||||
"""
|
"""
|
||||||
事件类型
|
事件类型
|
||||||
"""
|
"""
|
||||||
return self.get("t", "")
|
return self.get('t', '')
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def user_openid(self) -> str:
|
def user_openid(self) -> str:
|
||||||
"""
|
"""
|
||||||
用户openid
|
用户openid
|
||||||
"""
|
"""
|
||||||
return self.get("user_openid",{})
|
return self.get('user_openid', {})
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def timestamp(self) -> str:
|
def timestamp(self) -> str:
|
||||||
"""
|
"""
|
||||||
时间戳
|
时间戳
|
||||||
"""
|
"""
|
||||||
return self.get("timestamp",{})
|
return self.get('timestamp', {})
|
||||||
|
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def d_author_id(self) -> str:
|
def d_author_id(self) -> str:
|
||||||
"""
|
"""
|
||||||
作者id
|
作者id
|
||||||
"""
|
"""
|
||||||
return self.get("id",{})
|
return self.get('id', {})
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def content(self) -> str:
|
def content(self) -> str:
|
||||||
"""
|
"""
|
||||||
内容
|
内容
|
||||||
"""
|
"""
|
||||||
return self.get("content",'')
|
return self.get('content', '')
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def d_id(self) -> str:
|
def d_id(self) -> str:
|
||||||
"""
|
"""
|
||||||
d_id
|
d_id
|
||||||
"""
|
"""
|
||||||
return self.get("d_id",{})
|
return self.get('d_id', {})
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def id(self) -> str:
|
def id(self) -> str:
|
||||||
"""
|
"""
|
||||||
消息id,msg_id
|
消息id,msg_id
|
||||||
"""
|
"""
|
||||||
return self.get("id",{})
|
return self.get('id', {})
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def channel_id(self) -> str:
|
def channel_id(self) -> str:
|
||||||
"""
|
"""
|
||||||
频道id
|
频道id
|
||||||
"""
|
"""
|
||||||
return self.get("channel_id",{})
|
return self.get('channel_id', {})
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def username(self) -> str:
|
def username(self) -> str:
|
||||||
"""
|
"""
|
||||||
用户名
|
用户名
|
||||||
"""
|
"""
|
||||||
return self.get("username",{})
|
return self.get('username', {})
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def guild_id(self) -> str:
|
def guild_id(self) -> str:
|
||||||
"""
|
"""
|
||||||
频道id
|
频道id
|
||||||
"""
|
"""
|
||||||
return self.get("guild_id",{})
|
return self.get('guild_id', {})
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def member_openid(self) -> str:
|
def member_openid(self) -> str:
|
||||||
"""
|
"""
|
||||||
成员openid
|
成员openid
|
||||||
"""
|
"""
|
||||||
return self.get("openid",{})
|
return self.get('openid', {})
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def attachments(self) -> str:
|
def attachments(self) -> str:
|
||||||
"""
|
"""
|
||||||
附件url
|
附件url
|
||||||
"""
|
"""
|
||||||
url = self.get("image_attachments", "")
|
url = self.get('image_attachments', '')
|
||||||
if url and not url.startswith("https://"):
|
if url and not url.startswith('https://'):
|
||||||
url = "https://" + url
|
url = 'https://' + url
|
||||||
return url
|
return url
|
||||||
|
|
||||||
@property
|
@property
|
||||||
@@ -103,12 +102,11 @@ class QQOfficialEvent(dict):
|
|||||||
"""
|
"""
|
||||||
群组id
|
群组id
|
||||||
"""
|
"""
|
||||||
return self.get("group_openid",{})
|
return self.get('group_openid', {})
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def content_type(self) -> str:
|
def content_type(self) -> str:
|
||||||
"""
|
"""
|
||||||
文件类型
|
文件类型
|
||||||
"""
|
"""
|
||||||
return self.get("content_type","")
|
return self.get('content_type', '')
|
||||||
|
|
||||||
|
|||||||
@@ -1,10 +1,11 @@
|
|||||||
#!/usr/bin/env python
|
#!/usr/bin/env python
|
||||||
# -*- encoding:utf-8 -*-
|
# -*- encoding:utf-8 -*-
|
||||||
|
|
||||||
""" 对企业微信发送给企业后台的消息加解密示例代码.
|
"""对企业微信发送给企业后台的消息加解密示例代码.
|
||||||
@copyright: Copyright (c) 1998-2014 Tencent Inc.
|
@copyright: Copyright (c) 1998-2014 Tencent Inc.
|
||||||
|
|
||||||
"""
|
"""
|
||||||
|
|
||||||
# ------------------------------------------------------------------------
|
# ------------------------------------------------------------------------
|
||||||
import logging
|
import logging
|
||||||
import base64
|
import base64
|
||||||
@@ -49,7 +50,7 @@ class SHA1:
|
|||||||
sortlist = [token, timestamp, nonce, encrypt]
|
sortlist = [token, timestamp, nonce, encrypt]
|
||||||
sortlist.sort()
|
sortlist.sort()
|
||||||
sha = hashlib.sha1()
|
sha = hashlib.sha1()
|
||||||
sha.update("".join(sortlist).encode())
|
sha.update(''.join(sortlist).encode())
|
||||||
return ierror.WXBizMsgCrypt_OK, sha.hexdigest()
|
return ierror.WXBizMsgCrypt_OK, sha.hexdigest()
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger = logging.getLogger()
|
logger = logging.getLogger()
|
||||||
@@ -75,7 +76,7 @@ class XMLParse:
|
|||||||
"""
|
"""
|
||||||
try:
|
try:
|
||||||
xml_tree = ET.fromstring(xmltext)
|
xml_tree = ET.fromstring(xmltext)
|
||||||
encrypt = xml_tree.find("Encrypt")
|
encrypt = xml_tree.find('Encrypt')
|
||||||
return ierror.WXBizMsgCrypt_OK, encrypt.text
|
return ierror.WXBizMsgCrypt_OK, encrypt.text
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger = logging.getLogger()
|
logger = logging.getLogger()
|
||||||
@@ -100,13 +101,13 @@ class XMLParse:
|
|||||||
return resp_xml
|
return resp_xml
|
||||||
|
|
||||||
|
|
||||||
class PKCS7Encoder():
|
class PKCS7Encoder:
|
||||||
"""提供基于PKCS7算法的加解密接口"""
|
"""提供基于PKCS7算法的加解密接口"""
|
||||||
|
|
||||||
block_size = 32
|
block_size = 32
|
||||||
|
|
||||||
def encode(self, text):
|
def encode(self, text):
|
||||||
""" 对需要加密的明文进行填充补位
|
"""对需要加密的明文进行填充补位
|
||||||
@param text: 需要进行填充补位操作的明文
|
@param text: 需要进行填充补位操作的明文
|
||||||
@return: 补齐明文字符串
|
@return: 补齐明文字符串
|
||||||
"""
|
"""
|
||||||
@@ -134,7 +135,6 @@ class Prpcrypt(object):
|
|||||||
"""提供接收和推送给企业微信消息的加解密接口"""
|
"""提供接收和推送给企业微信消息的加解密接口"""
|
||||||
|
|
||||||
def __init__(self, key):
|
def __init__(self, key):
|
||||||
|
|
||||||
# self.key = base64.b64decode(key+"=")
|
# self.key = base64.b64decode(key+"=")
|
||||||
self.key = key
|
self.key = key
|
||||||
# 设置加解密模式为AES的CBC模式
|
# 设置加解密模式为AES的CBC模式
|
||||||
@@ -147,7 +147,12 @@ class Prpcrypt(object):
|
|||||||
"""
|
"""
|
||||||
# 16位随机字符串添加到明文开头
|
# 16位随机字符串添加到明文开头
|
||||||
text = text.encode()
|
text = text.encode()
|
||||||
text = self.get_random_str() + struct.pack("I", socket.htonl(len(text))) + text + receiveid.encode()
|
text = (
|
||||||
|
self.get_random_str()
|
||||||
|
+ struct.pack('I', socket.htonl(len(text)))
|
||||||
|
+ text
|
||||||
|
+ receiveid.encode()
|
||||||
|
)
|
||||||
|
|
||||||
# 使用自定义的填充方式对明文进行补位填充
|
# 使用自定义的填充方式对明文进行补位填充
|
||||||
pkcs7 = PKCS7Encoder()
|
pkcs7 = PKCS7Encoder()
|
||||||
@@ -183,9 +188,9 @@ class Prpcrypt(object):
|
|||||||
# plain_text = pkcs7.encode(plain_text)
|
# plain_text = pkcs7.encode(plain_text)
|
||||||
# 去除16位随机字符串
|
# 去除16位随机字符串
|
||||||
content = plain_text[16:-pad]
|
content = plain_text[16:-pad]
|
||||||
xml_len = socket.ntohl(struct.unpack("I", content[: 4])[0])
|
xml_len = socket.ntohl(struct.unpack('I', content[:4])[0])
|
||||||
xml_content = content[4: xml_len + 4]
|
xml_content = content[4 : xml_len + 4]
|
||||||
from_receiveid = content[xml_len + 4:]
|
from_receiveid = content[xml_len + 4 :]
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger = logging.getLogger()
|
logger = logging.getLogger()
|
||||||
logger.error(e)
|
logger.error(e)
|
||||||
@@ -196,7 +201,7 @@ class Prpcrypt(object):
|
|||||||
return 0, xml_content
|
return 0, xml_content
|
||||||
|
|
||||||
def get_random_str(self):
|
def get_random_str(self):
|
||||||
""" 随机生成16位字符串
|
"""随机生成16位字符串
|
||||||
@return: 16位字符串
|
@return: 16位字符串
|
||||||
"""
|
"""
|
||||||
return str(random.randint(1000000000000000, 9999999999999999)).encode()
|
return str(random.randint(1000000000000000, 9999999999999999)).encode()
|
||||||
@@ -206,10 +211,10 @@ class WXBizMsgCrypt(object):
|
|||||||
# 构造函数
|
# 构造函数
|
||||||
def __init__(self, sToken, sEncodingAESKey, sReceiveId):
|
def __init__(self, sToken, sEncodingAESKey, sReceiveId):
|
||||||
try:
|
try:
|
||||||
self.key = base64.b64decode(sEncodingAESKey + "=")
|
self.key = base64.b64decode(sEncodingAESKey + '=')
|
||||||
assert len(self.key) == 32
|
assert len(self.key) == 32
|
||||||
except:
|
except Exception:
|
||||||
throw_exception("[error]: EncodingAESKey unvalid !", FormatException)
|
throw_exception('[error]: EncodingAESKey unvalid !', FormatException)
|
||||||
# return ierror.WXBizMsgCrypt_IllegalAesKey,None
|
# return ierror.WXBizMsgCrypt_IllegalAesKey,None
|
||||||
self.m_sToken = sToken
|
self.m_sToken = sToken
|
||||||
self.m_sReceiveId = sReceiveId
|
self.m_sReceiveId = sReceiveId
|
||||||
|
|||||||
+158
-110
@@ -7,15 +7,22 @@ from quart import Quart
|
|||||||
import xml.etree.ElementTree as ET
|
import xml.etree.ElementTree as ET
|
||||||
from typing import Callable, Dict, Any
|
from typing import Callable, Dict, Any
|
||||||
from .wecomevent import WecomEvent
|
from .wecomevent import WecomEvent
|
||||||
from pkg.platform.types import events as platform_events, message as platform_message
|
from pkg.platform.types import message as platform_message
|
||||||
import aiofiles
|
import aiofiles
|
||||||
|
|
||||||
|
|
||||||
class WecomClient():
|
class WecomClient:
|
||||||
def __init__(self,corpid:str,secret:str,token:str,EncodingAESKey:str,contacts_secret:str):
|
def __init__(
|
||||||
|
self,
|
||||||
|
corpid: str,
|
||||||
|
secret: str,
|
||||||
|
token: str,
|
||||||
|
EncodingAESKey: str,
|
||||||
|
contacts_secret: str,
|
||||||
|
):
|
||||||
self.corpid = corpid
|
self.corpid = corpid
|
||||||
self.secret = secret
|
self.secret = secret
|
||||||
self.access_token_for_contacts =''
|
self.access_token_for_contacts = ''
|
||||||
self.token = token
|
self.token = token
|
||||||
self.aes = EncodingAESKey
|
self.aes = EncodingAESKey
|
||||||
self.base_url = 'https://qyapi.weixin.qq.com/cgi-bin'
|
self.base_url = 'https://qyapi.weixin.qq.com/cgi-bin'
|
||||||
@@ -23,19 +30,26 @@ class WecomClient():
|
|||||||
self.secret_for_contacts = contacts_secret
|
self.secret_for_contacts = contacts_secret
|
||||||
self.app = Quart(__name__)
|
self.app = Quart(__name__)
|
||||||
self.wxcpt = WXBizMsgCrypt(self.token, self.aes, self.corpid)
|
self.wxcpt = WXBizMsgCrypt(self.token, self.aes, self.corpid)
|
||||||
self.app.add_url_rule('/callback/command', 'handle_callback', self.handle_callback_request, methods=['GET', 'POST'])
|
self.app.add_url_rule(
|
||||||
|
'/callback/command',
|
||||||
|
'handle_callback',
|
||||||
|
self.handle_callback_request,
|
||||||
|
methods=['GET', 'POST'],
|
||||||
|
)
|
||||||
self._message_handlers = {
|
self._message_handlers = {
|
||||||
"example":[],
|
'example': [],
|
||||||
}
|
}
|
||||||
|
|
||||||
#access——token操作
|
# access——token操作
|
||||||
async def check_access_token(self):
|
async def check_access_token(self):
|
||||||
return bool(self.access_token and self.access_token.strip())
|
return bool(self.access_token and self.access_token.strip())
|
||||||
|
|
||||||
async def check_access_token_for_contacts(self):
|
async def check_access_token_for_contacts(self):
|
||||||
return bool(self.access_token_for_contacts and self.access_token_for_contacts.strip())
|
return bool(
|
||||||
|
self.access_token_for_contacts and self.access_token_for_contacts.strip()
|
||||||
|
)
|
||||||
|
|
||||||
async def get_access_token(self,secret):
|
async def get_access_token(self, secret):
|
||||||
url = f'https://qyapi.weixin.qq.com/cgi-bin/gettoken?corpid={self.corpid}&corpsecret={secret}'
|
url = f'https://qyapi.weixin.qq.com/cgi-bin/gettoken?corpid={self.corpid}&corpsecret={secret}'
|
||||||
async with httpx.AsyncClient() as client:
|
async with httpx.AsyncClient() as client:
|
||||||
response = await client.get(url)
|
response = await client.get(url)
|
||||||
@@ -43,146 +57,163 @@ class WecomClient():
|
|||||||
if 'access_token' in data:
|
if 'access_token' in data:
|
||||||
return data['access_token']
|
return data['access_token']
|
||||||
else:
|
else:
|
||||||
raise Exception(f"未获取access token: {data}")
|
raise Exception(f'未获取access token: {data}')
|
||||||
|
|
||||||
async def get_users(self):
|
async def get_users(self):
|
||||||
if not self.check_access_token_for_contacts():
|
if not self.check_access_token_for_contacts():
|
||||||
self.access_token_for_contacts = await self.get_access_token(self.secret_for_contacts)
|
self.access_token_for_contacts = await self.get_access_token(
|
||||||
|
self.secret_for_contacts
|
||||||
|
)
|
||||||
|
|
||||||
url = self.base_url+'/user/list_id?access_token='+self.access_token_for_contacts
|
url = (
|
||||||
|
self.base_url
|
||||||
|
+ '/user/list_id?access_token='
|
||||||
|
+ self.access_token_for_contacts
|
||||||
|
)
|
||||||
async with httpx.AsyncClient() as client:
|
async with httpx.AsyncClient() as client:
|
||||||
params = {
|
params = {
|
||||||
"cursor":"",
|
'cursor': '',
|
||||||
"limit":10000,
|
'limit': 10000,
|
||||||
}
|
}
|
||||||
response = await client.post(url,json=params)
|
response = await client.post(url, json=params)
|
||||||
data = response.json()
|
data = response.json()
|
||||||
if data['errcode'] == 0:
|
if data['errcode'] == 0:
|
||||||
dept_users = data['dept_user']
|
dept_users = data['dept_user']
|
||||||
userid = []
|
userid = []
|
||||||
for user in dept_users:
|
for user in dept_users:
|
||||||
userid.append(user["userid"])
|
userid.append(user['userid'])
|
||||||
return userid
|
return userid
|
||||||
else:
|
else:
|
||||||
raise Exception("未获取用户")
|
raise Exception('未获取用户')
|
||||||
|
|
||||||
async def send_to_all(self,content:str,agent_id:int):
|
async def send_to_all(self, content: str, agent_id: int):
|
||||||
if not self.check_access_token_for_contacts():
|
if not self.check_access_token_for_contacts():
|
||||||
self.access_token_for_contacts = await self.get_access_token(self.secret_for_contacts)
|
self.access_token_for_contacts = await self.get_access_token(
|
||||||
|
self.secret_for_contacts
|
||||||
|
)
|
||||||
|
|
||||||
url = self.base_url+'/message/send?access_token='+self.access_token_for_contacts
|
url = (
|
||||||
|
self.base_url
|
||||||
|
+ '/message/send?access_token='
|
||||||
|
+ self.access_token_for_contacts
|
||||||
|
)
|
||||||
user_ids = await self.get_users()
|
user_ids = await self.get_users()
|
||||||
user_ids_string = "|".join(user_ids)
|
user_ids_string = '|'.join(user_ids)
|
||||||
async with httpx.AsyncClient() as client:
|
async with httpx.AsyncClient() as client:
|
||||||
params = {
|
params = {
|
||||||
"touser" : user_ids_string,
|
'touser': user_ids_string,
|
||||||
"msgtype" : "text",
|
'msgtype': 'text',
|
||||||
"agentid" : agent_id,
|
'agentid': agent_id,
|
||||||
"text" : {
|
'text': {
|
||||||
"content" : content,
|
'content': content,
|
||||||
},
|
},
|
||||||
"safe":0,
|
'safe': 0,
|
||||||
"enable_id_trans": 0,
|
'enable_id_trans': 0,
|
||||||
"enable_duplicate_check": 0,
|
'enable_duplicate_check': 0,
|
||||||
"duplicate_check_interval": 1800
|
'duplicate_check_interval': 1800,
|
||||||
}
|
}
|
||||||
response = await client.post(url,json=params)
|
response = await client.post(url, json=params)
|
||||||
data = response.json()
|
data = response.json()
|
||||||
if data['errcode'] != 0:
|
if data['errcode'] != 0:
|
||||||
raise Exception("Failed to send message: "+str(data))
|
raise Exception('Failed to send message: ' + str(data))
|
||||||
|
|
||||||
async def send_image(self,user_id:str,agent_id:int,media_id:str):
|
async def send_image(self, user_id: str, agent_id: int, media_id: str):
|
||||||
if not await self.check_access_token():
|
if not await self.check_access_token():
|
||||||
self.access_token = await self.get_access_token(self.secret)
|
self.access_token = await self.get_access_token(self.secret)
|
||||||
url = self.base_url+'/media/upload?access_token='+self.access_token
|
url = self.base_url + '/media/upload?access_token=' + self.access_token
|
||||||
async with httpx.AsyncClient() as client:
|
async with httpx.AsyncClient() as client:
|
||||||
params = {
|
params = {
|
||||||
"touser" : user_id,
|
'touser': user_id,
|
||||||
"toparty" : "",
|
'toparty': '',
|
||||||
"totag":"",
|
'totag': '',
|
||||||
"agentid" : agent_id,
|
'agentid': agent_id,
|
||||||
"msgtype" : "image",
|
'msgtype': 'image',
|
||||||
"image" : {
|
'image': {
|
||||||
"media_id" : media_id,
|
'media_id': media_id,
|
||||||
},
|
},
|
||||||
"safe":0,
|
'safe': 0,
|
||||||
"enable_id_trans": 0,
|
'enable_id_trans': 0,
|
||||||
"enable_duplicate_check": 0,
|
'enable_duplicate_check': 0,
|
||||||
"duplicate_check_interval": 1800
|
'duplicate_check_interval': 1800,
|
||||||
}
|
}
|
||||||
try:
|
try:
|
||||||
response = await client.post(url,json=params)
|
response = await client.post(url, json=params)
|
||||||
data = response.json()
|
data = response.json()
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
raise Exception("Failed to send image: "+str(e))
|
raise Exception('Failed to send image: ' + str(e))
|
||||||
|
|
||||||
# 企业微信错误码40014和42001,代表accesstoken问题
|
# 企业微信错误码40014和42001,代表accesstoken问题
|
||||||
if data['errcode'] == 40014 or data['errcode'] == 42001:
|
if data['errcode'] == 40014 or data['errcode'] == 42001:
|
||||||
self.access_token = await self.get_access_token(self.secret)
|
self.access_token = await self.get_access_token(self.secret)
|
||||||
return await self.send_image(user_id,agent_id,media_id)
|
return await self.send_image(user_id, agent_id, media_id)
|
||||||
|
|
||||||
if data['errcode'] != 0:
|
if data['errcode'] != 0:
|
||||||
raise Exception("Failed to send image: "+str(data))
|
raise Exception('Failed to send image: ' + str(data))
|
||||||
|
|
||||||
async def send_private_msg(self,user_id:str, agent_id:int,content:str):
|
async def send_private_msg(self, user_id: str, agent_id: int, content: str):
|
||||||
if not await self.check_access_token():
|
if not await self.check_access_token():
|
||||||
self.access_token = await self.get_access_token(self.secret)
|
self.access_token = await self.get_access_token(self.secret)
|
||||||
|
|
||||||
url = self.base_url+'/message/send?access_token='+self.access_token
|
url = self.base_url + '/message/send?access_token=' + self.access_token
|
||||||
async with httpx.AsyncClient() as client:
|
async with httpx.AsyncClient() as client:
|
||||||
params={
|
params = {
|
||||||
"touser" : user_id,
|
'touser': user_id,
|
||||||
"msgtype" : "text",
|
'msgtype': 'text',
|
||||||
"agentid" : agent_id,
|
'agentid': agent_id,
|
||||||
"text" : {
|
'text': {
|
||||||
"content" : content,
|
'content': content,
|
||||||
},
|
},
|
||||||
"safe":0,
|
'safe': 0,
|
||||||
"enable_id_trans": 0,
|
'enable_id_trans': 0,
|
||||||
"enable_duplicate_check": 0,
|
'enable_duplicate_check': 0,
|
||||||
"duplicate_check_interval": 1800
|
'duplicate_check_interval': 1800,
|
||||||
}
|
}
|
||||||
response = await client.post(url,json=params)
|
response = await client.post(url, json=params)
|
||||||
data = response.json()
|
data = response.json()
|
||||||
if data['errcode'] == 40014 or data['errcode'] == 42001:
|
if data['errcode'] == 40014 or data['errcode'] == 42001:
|
||||||
self.access_token = await self.get_access_token(self.secret)
|
self.access_token = await self.get_access_token(self.secret)
|
||||||
return await self.send_private_msg(user_id,agent_id,content)
|
return await self.send_private_msg(user_id, agent_id, content)
|
||||||
if data['errcode'] != 0:
|
if data['errcode'] != 0:
|
||||||
raise Exception("Failed to send message: "+str(data))
|
raise Exception('Failed to send message: ' + str(data))
|
||||||
|
|
||||||
async def handle_callback_request(self):
|
async def handle_callback_request(self):
|
||||||
"""
|
"""
|
||||||
处理回调请求,包括 GET 验证和 POST 消息接收。
|
处理回调请求,包括 GET 验证和 POST 消息接收。
|
||||||
"""
|
"""
|
||||||
try:
|
try:
|
||||||
|
msg_signature = request.args.get('msg_signature')
|
||||||
|
timestamp = request.args.get('timestamp')
|
||||||
|
nonce = request.args.get('nonce')
|
||||||
|
|
||||||
msg_signature = request.args.get("msg_signature")
|
if request.method == 'GET':
|
||||||
timestamp = request.args.get("timestamp")
|
echostr = request.args.get('echostr')
|
||||||
nonce = request.args.get("nonce")
|
ret, reply_echo_str = self.wxcpt.VerifyURL(
|
||||||
|
msg_signature, timestamp, nonce, echostr
|
||||||
if request.method == "GET":
|
)
|
||||||
echostr = request.args.get("echostr")
|
|
||||||
ret, reply_echo_str = self.wxcpt.VerifyURL(msg_signature, timestamp, nonce, echostr)
|
|
||||||
if ret != 0:
|
if ret != 0:
|
||||||
raise Exception(f"验证失败,错误码: {ret}")
|
raise Exception(f'验证失败,错误码: {ret}')
|
||||||
return reply_echo_str
|
return reply_echo_str
|
||||||
|
|
||||||
elif request.method == "POST":
|
elif request.method == 'POST':
|
||||||
encrypt_msg = await request.data
|
encrypt_msg = await request.data
|
||||||
ret, xml_msg = self.wxcpt.DecryptMsg(encrypt_msg, msg_signature, timestamp, nonce)
|
ret, xml_msg = self.wxcpt.DecryptMsg(
|
||||||
|
encrypt_msg, msg_signature, timestamp, nonce
|
||||||
|
)
|
||||||
if ret != 0:
|
if ret != 0:
|
||||||
raise Exception(f"消息解密失败,错误码: {ret}")
|
raise Exception(f'消息解密失败,错误码: {ret}')
|
||||||
|
|
||||||
# 解析消息并处理
|
# 解析消息并处理
|
||||||
message_data = await self.get_message(xml_msg)
|
message_data = await self.get_message(xml_msg)
|
||||||
if message_data:
|
if message_data:
|
||||||
event = WecomEvent.from_payload(message_data) # 转换为 WecomEvent 对象
|
event = WecomEvent.from_payload(
|
||||||
|
message_data
|
||||||
|
) # 转换为 WecomEvent 对象
|
||||||
if event:
|
if event:
|
||||||
await self._handle_message(event)
|
await self._handle_message(event)
|
||||||
|
|
||||||
return "success"
|
return 'success'
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
return f"Error processing request: {str(e)}", 400
|
return f'Error processing request: {str(e)}', 400
|
||||||
|
|
||||||
async def run_task(self, host: str, port: int, *args, **kwargs):
|
async def run_task(self, host: str, port: int, *args, **kwargs):
|
||||||
"""
|
"""
|
||||||
@@ -194,11 +225,13 @@ class WecomClient():
|
|||||||
"""
|
"""
|
||||||
注册消息类型处理器。
|
注册消息类型处理器。
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def decorator(func: Callable[[WecomEvent], None]):
|
def decorator(func: Callable[[WecomEvent], None]):
|
||||||
if msg_type not in self._message_handlers:
|
if msg_type not in self._message_handlers:
|
||||||
self._message_handlers[msg_type] = []
|
self._message_handlers[msg_type] = []
|
||||||
self._message_handlers[msg_type].append(func)
|
self._message_handlers[msg_type].append(func)
|
||||||
return func
|
return func
|
||||||
|
|
||||||
return decorator
|
return decorator
|
||||||
|
|
||||||
async def _handle_message(self, event: WecomEvent):
|
async def _handle_message(self, event: WecomEvent):
|
||||||
@@ -216,17 +249,27 @@ class WecomClient():
|
|||||||
"""
|
"""
|
||||||
root = ET.fromstring(xml_msg)
|
root = ET.fromstring(xml_msg)
|
||||||
message_data = {
|
message_data = {
|
||||||
"ToUserName": root.find("ToUserName").text,
|
'ToUserName': root.find('ToUserName').text,
|
||||||
"FromUserName": root.find("FromUserName").text,
|
'FromUserName': root.find('FromUserName').text,
|
||||||
"CreateTime": int(root.find("CreateTime").text),
|
'CreateTime': int(root.find('CreateTime').text),
|
||||||
"MsgType": root.find("MsgType").text,
|
'MsgType': root.find('MsgType').text,
|
||||||
"Content": root.find("Content").text if root.find("Content") is not None else None,
|
'Content': root.find('Content').text
|
||||||
"MsgId": int(root.find("MsgId").text) if root.find("MsgId") is not None else None,
|
if root.find('Content') is not None
|
||||||
"AgentID": int(root.find("AgentID").text) if root.find("AgentID") is not None else None,
|
else None,
|
||||||
|
'MsgId': int(root.find('MsgId').text)
|
||||||
|
if root.find('MsgId') is not None
|
||||||
|
else None,
|
||||||
|
'AgentID': int(root.find('AgentID').text)
|
||||||
|
if root.find('AgentID') is not None
|
||||||
|
else None,
|
||||||
}
|
}
|
||||||
if message_data["MsgType"] == "image":
|
if message_data['MsgType'] == 'image':
|
||||||
message_data["MediaId"] = root.find("MediaId").text if root.find("MediaId") is not None else None
|
message_data['MediaId'] = (
|
||||||
message_data["PicUrl"] = root.find("PicUrl").text if root.find("PicUrl") is not None else None
|
root.find('MediaId').text if root.find('MediaId') is not None else None
|
||||||
|
)
|
||||||
|
message_data['PicUrl'] = (
|
||||||
|
root.find('PicUrl').text if root.find('PicUrl') is not None else None
|
||||||
|
)
|
||||||
|
|
||||||
return message_data
|
return message_data
|
||||||
|
|
||||||
@@ -236,11 +279,11 @@ class WecomClient():
|
|||||||
通过图片的magic numbers判断图片类型
|
通过图片的magic numbers判断图片类型
|
||||||
"""
|
"""
|
||||||
magic_numbers = {
|
magic_numbers = {
|
||||||
b'\xFF\xD8\xFF': 'jpg',
|
b'\xff\xd8\xff': 'jpg',
|
||||||
b'\x89\x50\x4E\x47': 'png',
|
b'\x89\x50\x4e\x47': 'png',
|
||||||
b'\x47\x49\x46': 'gif',
|
b'\x47\x49\x46': 'gif',
|
||||||
b'\x42\x4D': 'bmp',
|
b'\x42\x4d': 'bmp',
|
||||||
b'\x00\x00\x01\x00': 'ico'
|
b'\x00\x00\x01\x00': 'ico',
|
||||||
}
|
}
|
||||||
|
|
||||||
for magic, ext in magic_numbers.items():
|
for magic, ext in magic_numbers.items():
|
||||||
@@ -248,7 +291,6 @@ class WecomClient():
|
|||||||
return ext
|
return ext
|
||||||
return 'jpg' # 默认返回jpg
|
return 'jpg' # 默认返回jpg
|
||||||
|
|
||||||
|
|
||||||
async def upload_to_work(self, image: platform_message.Image):
|
async def upload_to_work(self, image: platform_message.Image):
|
||||||
"""
|
"""
|
||||||
获取 media_id
|
获取 media_id
|
||||||
@@ -256,9 +298,14 @@ class WecomClient():
|
|||||||
if not await self.check_access_token():
|
if not await self.check_access_token():
|
||||||
self.access_token = await self.get_access_token(self.secret)
|
self.access_token = await self.get_access_token(self.secret)
|
||||||
|
|
||||||
url = self.base_url + '/media/upload?access_token=' + self.access_token + '&type=file'
|
url = (
|
||||||
|
self.base_url
|
||||||
|
+ '/media/upload?access_token='
|
||||||
|
+ self.access_token
|
||||||
|
+ '&type=file'
|
||||||
|
)
|
||||||
file_bytes = None
|
file_bytes = None
|
||||||
file_name = "uploaded_file.txt"
|
file_name = 'uploaded_file.txt'
|
||||||
|
|
||||||
# 获取文件的二进制数据
|
# 获取文件的二进制数据
|
||||||
if image.path:
|
if image.path:
|
||||||
@@ -277,20 +324,22 @@ class WecomClient():
|
|||||||
padded_base64 = base64_data + '=' * padding
|
padded_base64 = base64_data + '=' * padding
|
||||||
file_bytes = base64.b64decode(padded_base64)
|
file_bytes = base64.b64decode(padded_base64)
|
||||||
except binascii.Error as e:
|
except binascii.Error as e:
|
||||||
raise ValueError(f"Invalid base64 string: {str(e)}")
|
raise ValueError(f'Invalid base64 string: {str(e)}')
|
||||||
else:
|
else:
|
||||||
raise ValueError("image对象出错")
|
raise ValueError('image对象出错')
|
||||||
|
|
||||||
# 设置 multipart/form-data 格式的文件
|
# 设置 multipart/form-data 格式的文件
|
||||||
boundary = "-------------------------acebdf13572468"
|
boundary = '-------------------------acebdf13572468'
|
||||||
headers = {
|
headers = {'Content-Type': f'multipart/form-data; boundary={boundary}'}
|
||||||
'Content-Type': f'multipart/form-data; boundary={boundary}'
|
|
||||||
}
|
|
||||||
body = (
|
body = (
|
||||||
f"--{boundary}\r\n"
|
(
|
||||||
f"Content-Disposition: form-data; name=\"media\"; filename=\"{file_name}\"; filelength={len(file_bytes)}\r\n"
|
f'--{boundary}\r\n'
|
||||||
f"Content-Type: application/octet-stream\r\n\r\n"
|
f'Content-Disposition: form-data; name="media"; filename="{file_name}"; filelength={len(file_bytes)}\r\n'
|
||||||
).encode('utf-8') + file_bytes + f"\r\n--{boundary}--\r\n".encode('utf-8')
|
f'Content-Type: application/octet-stream\r\n\r\n'
|
||||||
|
).encode('utf-8')
|
||||||
|
+ file_bytes
|
||||||
|
+ f'\r\n--{boundary}--\r\n'.encode('utf-8')
|
||||||
|
)
|
||||||
|
|
||||||
# 上传文件
|
# 上传文件
|
||||||
async with httpx.AsyncClient() as client:
|
async with httpx.AsyncClient() as client:
|
||||||
@@ -300,19 +349,18 @@ class WecomClient():
|
|||||||
self.access_token = await self.get_access_token(self.secret)
|
self.access_token = await self.get_access_token(self.secret)
|
||||||
media_id = await self.upload_to_work(image)
|
media_id = await self.upload_to_work(image)
|
||||||
if data.get('errcode', 0) != 0:
|
if data.get('errcode', 0) != 0:
|
||||||
raise Exception("failed to upload file")
|
raise Exception('failed to upload file')
|
||||||
|
|
||||||
media_id = data.get('media_id')
|
media_id = data.get('media_id')
|
||||||
return media_id
|
return media_id
|
||||||
|
|
||||||
async def download_image_to_bytes(self,url:str) -> bytes:
|
async def download_image_to_bytes(self, url: str) -> bytes:
|
||||||
async with httpx.AsyncClient() as client:
|
async with httpx.AsyncClient() as client:
|
||||||
response = await client.get(url)
|
response = await client.get(url)
|
||||||
response.raise_for_status()
|
response.raise_for_status()
|
||||||
return response.content
|
return response.content
|
||||||
|
|
||||||
#进行media_id的获取
|
# 进行media_id的获取
|
||||||
async def get_media_id(self, image: platform_message.Image):
|
async def get_media_id(self, image: platform_message.Image):
|
||||||
|
|
||||||
media_id = await self.upload_to_work(image=image)
|
media_id = await self.upload_to_work(image=image)
|
||||||
return media_id
|
return media_id
|
||||||
|
|||||||
@@ -9,7 +9,7 @@ class WecomEvent(dict):
|
|||||||
"""
|
"""
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def from_payload(payload: Dict[str, Any]) -> Optional["WecomEvent"]:
|
def from_payload(payload: Dict[str, Any]) -> Optional['WecomEvent']:
|
||||||
"""
|
"""
|
||||||
从企业微信事件数据构造 `WecomEvent` 对象。
|
从企业微信事件数据构造 `WecomEvent` 对象。
|
||||||
|
|
||||||
@@ -34,14 +34,14 @@ class WecomEvent(dict):
|
|||||||
Returns:
|
Returns:
|
||||||
str: 事件类型。
|
str: 事件类型。
|
||||||
"""
|
"""
|
||||||
return self.get("MsgType", "")
|
return self.get('MsgType', '')
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def picurl(self) -> str:
|
def picurl(self) -> str:
|
||||||
"""
|
"""
|
||||||
图片链接
|
图片链接
|
||||||
"""
|
"""
|
||||||
return self.get("PicUrl")
|
return self.get('PicUrl')
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def detail_type(self) -> str:
|
def detail_type(self) -> str:
|
||||||
@@ -53,8 +53,8 @@ class WecomEvent(dict):
|
|||||||
Returns:
|
Returns:
|
||||||
str: 事件详细类型。
|
str: 事件详细类型。
|
||||||
"""
|
"""
|
||||||
if self.type == "event":
|
if self.type == 'event':
|
||||||
return self.get("Event", "")
|
return self.get('Event', '')
|
||||||
return self.type
|
return self.type
|
||||||
|
|
||||||
@property
|
@property
|
||||||
@@ -65,7 +65,7 @@ class WecomEvent(dict):
|
|||||||
Returns:
|
Returns:
|
||||||
str: 事件名。
|
str: 事件名。
|
||||||
"""
|
"""
|
||||||
return f"{self.type}.{self.detail_type}"
|
return f'{self.type}.{self.detail_type}'
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def user_id(self) -> Optional[str]:
|
def user_id(self) -> Optional[str]:
|
||||||
@@ -75,7 +75,7 @@ class WecomEvent(dict):
|
|||||||
Returns:
|
Returns:
|
||||||
Optional[str]: 用户 ID。
|
Optional[str]: 用户 ID。
|
||||||
"""
|
"""
|
||||||
return self.get("FromUserName")
|
return self.get('FromUserName')
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def agent_id(self) -> Optional[int]:
|
def agent_id(self) -> Optional[int]:
|
||||||
@@ -85,7 +85,7 @@ class WecomEvent(dict):
|
|||||||
Returns:
|
Returns:
|
||||||
Optional[int]: 机器人 ID。
|
Optional[int]: 机器人 ID。
|
||||||
"""
|
"""
|
||||||
return self.get("AgentID")
|
return self.get('AgentID')
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def receiver_id(self) -> Optional[str]:
|
def receiver_id(self) -> Optional[str]:
|
||||||
@@ -95,7 +95,7 @@ class WecomEvent(dict):
|
|||||||
Returns:
|
Returns:
|
||||||
Optional[str]: 接收者 ID。
|
Optional[str]: 接收者 ID。
|
||||||
"""
|
"""
|
||||||
return self.get("ToUserName")
|
return self.get('ToUserName')
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def message_id(self) -> Optional[str]:
|
def message_id(self) -> Optional[str]:
|
||||||
@@ -105,7 +105,7 @@ class WecomEvent(dict):
|
|||||||
Returns:
|
Returns:
|
||||||
Optional[str]: 消息 ID。
|
Optional[str]: 消息 ID。
|
||||||
"""
|
"""
|
||||||
return self.get("MsgId")
|
return self.get('MsgId')
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def message(self) -> Optional[str]:
|
def message(self) -> Optional[str]:
|
||||||
@@ -115,7 +115,7 @@ class WecomEvent(dict):
|
|||||||
Returns:
|
Returns:
|
||||||
Optional[str]: 消息内容。
|
Optional[str]: 消息内容。
|
||||||
"""
|
"""
|
||||||
return self.get("Content")
|
return self.get('Content')
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def media_id(self) -> Optional[str]:
|
def media_id(self) -> Optional[str]:
|
||||||
@@ -125,7 +125,7 @@ class WecomEvent(dict):
|
|||||||
Returns:
|
Returns:
|
||||||
Optional[str]: 媒体文件 ID。
|
Optional[str]: 媒体文件 ID。
|
||||||
"""
|
"""
|
||||||
return self.get("MediaId")
|
return self.get('MediaId')
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def timestamp(self) -> Optional[int]:
|
def timestamp(self) -> Optional[int]:
|
||||||
@@ -135,7 +135,7 @@ class WecomEvent(dict):
|
|||||||
Returns:
|
Returns:
|
||||||
Optional[int]: 时间戳。
|
Optional[int]: 时间戳。
|
||||||
"""
|
"""
|
||||||
return self.get("CreateTime")
|
return self.get('CreateTime')
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def event_key(self) -> Optional[str]:
|
def event_key(self) -> Optional[str]:
|
||||||
@@ -145,7 +145,7 @@ class WecomEvent(dict):
|
|||||||
Returns:
|
Returns:
|
||||||
Optional[str]: 事件 Key。
|
Optional[str]: 事件 Key。
|
||||||
"""
|
"""
|
||||||
return self.get("EventKey")
|
return self.get('EventKey')
|
||||||
|
|
||||||
def __getattr__(self, key: str) -> Optional[Any]:
|
def __getattr__(self, key: str) -> Optional[Any]:
|
||||||
"""
|
"""
|
||||||
@@ -176,4 +176,4 @@ class WecomEvent(dict):
|
|||||||
Returns:
|
Returns:
|
||||||
str: 字符串表示。
|
str: 字符串表示。
|
||||||
"""
|
"""
|
||||||
return f"<WecomEvent {super().__repr__()}>"
|
return f'<WecomEvent {super().__repr__()}>'
|
||||||
|
|||||||
@@ -1,3 +1,4 @@
|
|||||||
|
import asyncio
|
||||||
# LangBot 终端启动入口
|
# LangBot 终端启动入口
|
||||||
# 在此层级解决依赖项检查。
|
# 在此层级解决依赖项检查。
|
||||||
# LangBot/main.py
|
# LangBot/main.py
|
||||||
@@ -14,9 +15,6 @@ asciiart = r"""
|
|||||||
"""
|
"""
|
||||||
|
|
||||||
|
|
||||||
import asyncio
|
|
||||||
|
|
||||||
|
|
||||||
async def main_entry(loop: asyncio.AbstractEventLoop):
|
async def main_entry(loop: asyncio.AbstractEventLoop):
|
||||||
print(asciiart)
|
print(asciiart)
|
||||||
|
|
||||||
@@ -29,11 +27,11 @@ async def main_entry(loop: asyncio.AbstractEventLoop):
|
|||||||
missing_deps = await deps.check_deps()
|
missing_deps = await deps.check_deps()
|
||||||
|
|
||||||
if missing_deps:
|
if missing_deps:
|
||||||
print("以下依赖包未安装,将自动安装,请完成后重启程序:")
|
print('以下依赖包未安装,将自动安装,请完成后重启程序:')
|
||||||
for dep in missing_deps:
|
for dep in missing_deps:
|
||||||
print("-", dep)
|
print('-', dep)
|
||||||
await deps.install_deps(missing_deps)
|
await deps.install_deps(missing_deps)
|
||||||
print("已自动安装缺失的依赖包,请重启程序。")
|
print('已自动安装缺失的依赖包,请重启程序。')
|
||||||
sys.exit(0)
|
sys.exit(0)
|
||||||
|
|
||||||
# check plugin deps
|
# check plugin deps
|
||||||
@@ -41,8 +39,10 @@ async def main_entry(loop: asyncio.AbstractEventLoop):
|
|||||||
|
|
||||||
# 检查pydantic版本,如果没有 pydantic.v1,则把 pydantic 映射为 v1
|
# 检查pydantic版本,如果没有 pydantic.v1,则把 pydantic 映射为 v1
|
||||||
import pydantic.version
|
import pydantic.version
|
||||||
|
|
||||||
if pydantic.version.VERSION < '2.0':
|
if pydantic.version.VERSION < '2.0':
|
||||||
import pydantic
|
import pydantic
|
||||||
|
|
||||||
sys.modules['pydantic.v1'] = pydantic
|
sys.modules['pydantic.v1'] = pydantic
|
||||||
|
|
||||||
# 检查配置文件
|
# 检查配置文件
|
||||||
@@ -52,11 +52,12 @@ async def main_entry(loop: asyncio.AbstractEventLoop):
|
|||||||
generated_files = await files.generate_files()
|
generated_files = await files.generate_files()
|
||||||
|
|
||||||
if generated_files:
|
if generated_files:
|
||||||
print("以下文件不存在,已自动生成:")
|
print('以下文件不存在,已自动生成:')
|
||||||
for file in generated_files:
|
for file in generated_files:
|
||||||
print("-", file)
|
print('-', file)
|
||||||
|
|
||||||
from pkg.core import boot
|
from pkg.core import boot
|
||||||
|
|
||||||
await boot.main(loop)
|
await boot.main(loop)
|
||||||
|
|
||||||
|
|
||||||
@@ -66,8 +67,8 @@ if __name__ == '__main__':
|
|||||||
|
|
||||||
# 必须大于 3.10.1
|
# 必须大于 3.10.1
|
||||||
if sys.version_info < (3, 10, 1):
|
if sys.version_info < (3, 10, 1):
|
||||||
print("需要 Python 3.10.1 及以上版本,当前 Python 版本为:", sys.version)
|
print('需要 Python 3.10.1 及以上版本,当前 Python 版本为:', sys.version)
|
||||||
input("按任意键退出...")
|
input('按任意键退出...')
|
||||||
exit(1)
|
exit(1)
|
||||||
|
|
||||||
# 检查本目录是否有main.py,且包含LangBot字符串
|
# 检查本目录是否有main.py,且包含LangBot字符串
|
||||||
@@ -78,11 +79,11 @@ if __name__ == '__main__':
|
|||||||
else:
|
else:
|
||||||
with open('main.py', 'r', encoding='utf-8') as f:
|
with open('main.py', 'r', encoding='utf-8') as f:
|
||||||
content = f.read()
|
content = f.read()
|
||||||
if "LangBot/main.py" not in content:
|
if 'LangBot/main.py' not in content:
|
||||||
invalid_pwd = True
|
invalid_pwd = True
|
||||||
if invalid_pwd:
|
if invalid_pwd:
|
||||||
print("请在 LangBot 项目根目录下以命令形式运行此程序。")
|
print('请在 LangBot 项目根目录下以命令形式运行此程序。')
|
||||||
input("按任意键退出...")
|
input('按任意键退出...')
|
||||||
exit(1)
|
exit(1)
|
||||||
|
|
||||||
loop = asyncio.new_event_loop()
|
loop = asyncio.new_event_loop()
|
||||||
|
|||||||
@@ -13,6 +13,7 @@ from ....core import app
|
|||||||
preregistered_groups: list[type[RouterGroup]] = []
|
preregistered_groups: list[type[RouterGroup]] = []
|
||||||
"""RouterGroup 的预注册列表"""
|
"""RouterGroup 的预注册列表"""
|
||||||
|
|
||||||
|
|
||||||
def group_class(name: str, path: str) -> None:
|
def group_class(name: str, path: str) -> None:
|
||||||
"""注册一个 RouterGroup"""
|
"""注册一个 RouterGroup"""
|
||||||
|
|
||||||
@@ -27,12 +28,12 @@ def group_class(name: str, path: str) -> None:
|
|||||||
|
|
||||||
class AuthType(enum.Enum):
|
class AuthType(enum.Enum):
|
||||||
"""认证类型"""
|
"""认证类型"""
|
||||||
|
|
||||||
NONE = 'none'
|
NONE = 'none'
|
||||||
USER_TOKEN = 'user-token'
|
USER_TOKEN = 'user-token'
|
||||||
|
|
||||||
|
|
||||||
class RouterGroup(abc.ABC):
|
class RouterGroup(abc.ABC):
|
||||||
|
|
||||||
name: str
|
name: str
|
||||||
|
|
||||||
path: str
|
path: str
|
||||||
@@ -49,17 +50,24 @@ class RouterGroup(abc.ABC):
|
|||||||
async def initialize(self) -> None:
|
async def initialize(self) -> None:
|
||||||
pass
|
pass
|
||||||
|
|
||||||
def route(self, rule: str, auth_type: AuthType = AuthType.USER_TOKEN, **options: typing.Any) -> typing.Callable[[RouteCallable], RouteCallable]: # decorator
|
def route(
|
||||||
|
self,
|
||||||
|
rule: str,
|
||||||
|
auth_type: AuthType = AuthType.USER_TOKEN,
|
||||||
|
**options: typing.Any,
|
||||||
|
) -> typing.Callable[[RouteCallable], RouteCallable]: # decorator
|
||||||
"""注册一个路由"""
|
"""注册一个路由"""
|
||||||
|
|
||||||
def decorator(f: RouteCallable) -> RouteCallable:
|
def decorator(f: RouteCallable) -> RouteCallable:
|
||||||
nonlocal rule
|
nonlocal rule
|
||||||
rule = self.path + rule
|
rule = self.path + rule
|
||||||
|
|
||||||
async def handler_error(*args, **kwargs):
|
async def handler_error(*args, **kwargs):
|
||||||
|
|
||||||
if auth_type == AuthType.USER_TOKEN:
|
if auth_type == AuthType.USER_TOKEN:
|
||||||
# 从Authorization头中获取token
|
# 从Authorization头中获取token
|
||||||
token = quart.request.headers.get('Authorization', '').replace('Bearer ', '')
|
token = quart.request.headers.get('Authorization', '').replace(
|
||||||
|
'Bearer ', ''
|
||||||
|
)
|
||||||
|
|
||||||
if not token:
|
if not token:
|
||||||
return self.http_status(401, -1, '未提供有效的用户令牌')
|
return self.http_status(401, -1, '未提供有效的用户令牌')
|
||||||
@@ -75,7 +83,7 @@ class RouterGroup(abc.ABC):
|
|||||||
|
|
||||||
try:
|
try:
|
||||||
return await f(*args, **kwargs)
|
return await f(*args, **kwargs)
|
||||||
except Exception as e: # 自动 500
|
except Exception: # 自动 500
|
||||||
traceback.print_exc()
|
traceback.print_exc()
|
||||||
# return self.http_status(500, -2, str(e))
|
# return self.http_status(500, -2, str(e))
|
||||||
return self.http_status(500, -2, 'internal server error')
|
return self.http_status(500, -2, 'internal server error')
|
||||||
@@ -91,19 +99,23 @@ class RouterGroup(abc.ABC):
|
|||||||
|
|
||||||
def success(self, data: typing.Any = None) -> quart.Response:
|
def success(self, data: typing.Any = None) -> quart.Response:
|
||||||
"""返回一个 200 响应"""
|
"""返回一个 200 响应"""
|
||||||
return quart.jsonify({
|
return quart.jsonify(
|
||||||
'code': 0,
|
{
|
||||||
'msg': 'ok',
|
'code': 0,
|
||||||
'data': data,
|
'msg': 'ok',
|
||||||
})
|
'data': data,
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
def fail(self, code: int, msg: str) -> quart.Response:
|
def fail(self, code: int, msg: str) -> quart.Response:
|
||||||
"""返回一个异常响应"""
|
"""返回一个异常响应"""
|
||||||
|
|
||||||
return quart.jsonify({
|
return quart.jsonify(
|
||||||
'code': code,
|
{
|
||||||
'msg': msg,
|
'code': code,
|
||||||
})
|
'msg': msg,
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
def http_status(self, status: int, code: int, msg: str) -> quart.Response:
|
def http_status(self, status: int, code: int, msg: str) -> quart.Response:
|
||||||
"""返回一个指定状态码的响应"""
|
"""返回一个指定状态码的响应"""
|
||||||
|
|||||||
@@ -1,32 +1,29 @@
|
|||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
import traceback
|
|
||||||
|
|
||||||
import quart
|
import quart
|
||||||
|
|
||||||
from .....core import app
|
|
||||||
from .. import group
|
from .. import group
|
||||||
|
|
||||||
|
|
||||||
@group.group_class('logs', '/api/v1/logs')
|
@group.group_class('logs', '/api/v1/logs')
|
||||||
class LogsRouterGroup(group.RouterGroup):
|
class LogsRouterGroup(group.RouterGroup):
|
||||||
|
|
||||||
async def initialize(self) -> None:
|
async def initialize(self) -> None:
|
||||||
@self.route('', methods=['GET'], auth_type=group.AuthType.USER_TOKEN)
|
@self.route('', methods=['GET'], auth_type=group.AuthType.USER_TOKEN)
|
||||||
async def _() -> str:
|
async def _() -> str:
|
||||||
|
|
||||||
start_page_number = int(quart.request.args.get('start_page_number', 0))
|
start_page_number = int(quart.request.args.get('start_page_number', 0))
|
||||||
start_offset = int(quart.request.args.get('start_offset', 0))
|
start_offset = int(quart.request.args.get('start_offset', 0))
|
||||||
|
|
||||||
logs_str, end_page_number, end_offset = self.ap.log_cache.get_log_by_pointer(
|
logs_str, end_page_number, end_offset = (
|
||||||
start_page_number=start_page_number,
|
self.ap.log_cache.get_log_by_pointer(
|
||||||
start_offset=start_offset
|
start_page_number=start_page_number, start_offset=start_offset
|
||||||
|
)
|
||||||
)
|
)
|
||||||
|
|
||||||
return self.success(
|
return self.success(
|
||||||
data={
|
data={
|
||||||
"logs": logs_str,
|
'logs': logs_str,
|
||||||
"end_page_number": end_page_number,
|
'end_page_number': end_page_number,
|
||||||
"end_offset": end_offset
|
'end_offset': end_offset,
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
|
|||||||
@@ -3,34 +3,31 @@ from __future__ import annotations
|
|||||||
import quart
|
import quart
|
||||||
|
|
||||||
from .. import group
|
from .. import group
|
||||||
from .....entity.persistence import pipeline
|
|
||||||
|
|
||||||
|
|
||||||
@group.group_class('pipelines', '/api/v1/pipelines')
|
@group.group_class('pipelines', '/api/v1/pipelines')
|
||||||
class PipelinesRouterGroup(group.RouterGroup):
|
class PipelinesRouterGroup(group.RouterGroup):
|
||||||
|
|
||||||
async def initialize(self) -> None:
|
async def initialize(self) -> None:
|
||||||
|
|
||||||
@self.route('', methods=['GET', 'POST'])
|
@self.route('', methods=['GET', 'POST'])
|
||||||
async def _() -> str:
|
async def _() -> str:
|
||||||
if quart.request.method == 'GET':
|
if quart.request.method == 'GET':
|
||||||
return self.success(data={
|
return self.success(
|
||||||
'pipelines': await self.ap.pipeline_service.get_pipelines()
|
data={'pipelines': await self.ap.pipeline_service.get_pipelines()}
|
||||||
})
|
)
|
||||||
elif quart.request.method == 'POST':
|
elif quart.request.method == 'POST':
|
||||||
json_data = await quart.request.json
|
json_data = await quart.request.json
|
||||||
|
|
||||||
pipeline_uuid = await self.ap.pipeline_service.create_pipeline(json_data)
|
pipeline_uuid = await self.ap.pipeline_service.create_pipeline(
|
||||||
|
json_data
|
||||||
|
)
|
||||||
|
|
||||||
return self.success(data={
|
return self.success(data={'uuid': pipeline_uuid})
|
||||||
'uuid': pipeline_uuid
|
|
||||||
})
|
|
||||||
|
|
||||||
@self.route('/_/metadata', methods=['GET'])
|
@self.route('/_/metadata', methods=['GET'])
|
||||||
async def _() -> str:
|
async def _() -> str:
|
||||||
return self.success(data={
|
return self.success(
|
||||||
'configs': await self.ap.pipeline_service.get_pipeline_metadata()
|
data={'configs': await self.ap.pipeline_service.get_pipeline_metadata()}
|
||||||
})
|
)
|
||||||
|
|
||||||
@self.route('/<pipeline_uuid>', methods=['GET', 'PUT', 'DELETE'])
|
@self.route('/<pipeline_uuid>', methods=['GET', 'PUT', 'DELETE'])
|
||||||
async def _(pipeline_uuid: str) -> str:
|
async def _(pipeline_uuid: str) -> str:
|
||||||
@@ -40,9 +37,7 @@ class PipelinesRouterGroup(group.RouterGroup):
|
|||||||
if pipeline is None:
|
if pipeline is None:
|
||||||
return self.http_status(404, -1, 'pipeline not found')
|
return self.http_status(404, -1, 'pipeline not found')
|
||||||
|
|
||||||
return self.success(data={
|
return self.success(data={'pipeline': pipeline})
|
||||||
'pipeline': pipeline
|
|
||||||
})
|
|
||||||
elif quart.request.method == 'PUT':
|
elif quart.request.method == 'PUT':
|
||||||
json_data = await quart.request.json
|
json_data = await quart.request.json
|
||||||
|
|
||||||
@@ -53,4 +48,3 @@ class PipelinesRouterGroup(group.RouterGroup):
|
|||||||
await self.ap.pipeline_service.delete_pipeline(pipeline_uuid)
|
await self.ap.pipeline_service.delete_pipeline(pipeline_uuid)
|
||||||
|
|
||||||
return self.success()
|
return self.success()
|
||||||
|
|
||||||
|
|||||||
@@ -5,29 +5,31 @@ from ... import group
|
|||||||
|
|
||||||
@group.group_class('adapters', '/api/v1/platform/adapters')
|
@group.group_class('adapters', '/api/v1/platform/adapters')
|
||||||
class AdaptersRouterGroup(group.RouterGroup):
|
class AdaptersRouterGroup(group.RouterGroup):
|
||||||
|
|
||||||
async def initialize(self) -> None:
|
async def initialize(self) -> None:
|
||||||
@self.route('', methods=['GET'])
|
@self.route('', methods=['GET'])
|
||||||
async def _() -> str:
|
async def _() -> str:
|
||||||
return self.success(data={
|
return self.success(
|
||||||
'adapters': self.ap.platform_mgr.get_available_adapters_info()
|
data={'adapters': self.ap.platform_mgr.get_available_adapters_info()}
|
||||||
})
|
)
|
||||||
|
|
||||||
@self.route('/<adapter_name>', methods=['GET'])
|
@self.route('/<adapter_name>', methods=['GET'])
|
||||||
async def _(adapter_name: str) -> str:
|
async def _(adapter_name: str) -> str:
|
||||||
adapter_info = self.ap.platform_mgr.get_available_adapter_info_by_name(adapter_name)
|
adapter_info = self.ap.platform_mgr.get_available_adapter_info_by_name(
|
||||||
|
adapter_name
|
||||||
|
)
|
||||||
|
|
||||||
if adapter_info is None:
|
if adapter_info is None:
|
||||||
return self.http_status(404, -1, 'adapter not found')
|
return self.http_status(404, -1, 'adapter not found')
|
||||||
|
|
||||||
return self.success(data={
|
return self.success(data={'adapter': adapter_info})
|
||||||
'adapter': adapter_info
|
|
||||||
})
|
|
||||||
|
|
||||||
@self.route('/<adapter_name>/icon', methods=['GET'])
|
@self.route('/<adapter_name>/icon', methods=['GET'])
|
||||||
async def _(adapter_name: str) -> quart.Response:
|
async def _(adapter_name: str) -> quart.Response:
|
||||||
|
adapter_manifest = (
|
||||||
adapter_manifest = self.ap.platform_mgr.get_available_adapter_manifest_by_name(adapter_name)
|
self.ap.platform_mgr.get_available_adapter_manifest_by_name(
|
||||||
|
adapter_name
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
if adapter_manifest is None:
|
if adapter_manifest is None:
|
||||||
return self.http_status(404, -1, 'adapter not found')
|
return self.http_status(404, -1, 'adapter not found')
|
||||||
|
|||||||
@@ -5,20 +5,15 @@ from ... import group
|
|||||||
|
|
||||||
@group.group_class('bots', '/api/v1/platform/bots')
|
@group.group_class('bots', '/api/v1/platform/bots')
|
||||||
class BotsRouterGroup(group.RouterGroup):
|
class BotsRouterGroup(group.RouterGroup):
|
||||||
|
|
||||||
async def initialize(self) -> None:
|
async def initialize(self) -> None:
|
||||||
@self.route('', methods=['GET', 'POST'])
|
@self.route('', methods=['GET', 'POST'])
|
||||||
async def _() -> str:
|
async def _() -> str:
|
||||||
if quart.request.method == 'GET':
|
if quart.request.method == 'GET':
|
||||||
return self.success(data={
|
return self.success(data={'bots': await self.ap.bot_service.get_bots()})
|
||||||
'bots': await self.ap.bot_service.get_bots()
|
|
||||||
})
|
|
||||||
elif quart.request.method == 'POST':
|
elif quart.request.method == 'POST':
|
||||||
json_data = await quart.request.json
|
json_data = await quart.request.json
|
||||||
bot_uuid = await self.ap.bot_service.create_bot(json_data)
|
bot_uuid = await self.ap.bot_service.create_bot(json_data)
|
||||||
return self.success(data={
|
return self.success(data={'uuid': bot_uuid})
|
||||||
'uuid': bot_uuid
|
|
||||||
})
|
|
||||||
|
|
||||||
@self.route('/<bot_uuid>', methods=['GET', 'PUT', 'DELETE'])
|
@self.route('/<bot_uuid>', methods=['GET', 'PUT', 'DELETE'])
|
||||||
async def _(bot_uuid: str) -> str:
|
async def _(bot_uuid: str) -> str:
|
||||||
@@ -26,9 +21,7 @@ class BotsRouterGroup(group.RouterGroup):
|
|||||||
bot = await self.ap.bot_service.get_bot(bot_uuid)
|
bot = await self.ap.bot_service.get_bot(bot_uuid)
|
||||||
if bot is None:
|
if bot is None:
|
||||||
return self.http_status(404, -1, 'bot not found')
|
return self.http_status(404, -1, 'bot not found')
|
||||||
return self.success(data={
|
return self.success(data={'bot': bot})
|
||||||
'bot': bot
|
|
||||||
})
|
|
||||||
elif quart.request.method == 'PUT':
|
elif quart.request.method == 'PUT':
|
||||||
json_data = await quart.request.json
|
json_data = await quart.request.json
|
||||||
await self.ap.bot_service.update_bot(bot_uuid, json_data)
|
await self.ap.bot_service.update_bot(bot_uuid, json_data)
|
||||||
|
|||||||
@@ -1,17 +1,14 @@
|
|||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
|
|
||||||
import traceback
|
|
||||||
|
|
||||||
import quart
|
import quart
|
||||||
|
|
||||||
from .....core import app, taskmgr
|
from .....core import taskmgr
|
||||||
from .. import group
|
from .. import group
|
||||||
|
|
||||||
|
|
||||||
@group.group_class('plugins', '/api/v1/plugins')
|
@group.group_class('plugins', '/api/v1/plugins')
|
||||||
class PluginsRouterGroup(group.RouterGroup):
|
class PluginsRouterGroup(group.RouterGroup):
|
||||||
|
|
||||||
async def initialize(self) -> None:
|
async def initialize(self) -> None:
|
||||||
@self.route('', methods=['GET'], auth_type=group.AuthType.USER_TOKEN)
|
@self.route('', methods=['GET'], auth_type=group.AuthType.USER_TOKEN)
|
||||||
async def _() -> str:
|
async def _() -> str:
|
||||||
@@ -19,63 +16,69 @@ class PluginsRouterGroup(group.RouterGroup):
|
|||||||
|
|
||||||
plugins_data = [plugin.model_dump() for plugin in plugins]
|
plugins_data = [plugin.model_dump() for plugin in plugins]
|
||||||
|
|
||||||
return self.success(data={
|
return self.success(data={'plugins': plugins_data})
|
||||||
'plugins': plugins_data
|
|
||||||
})
|
|
||||||
|
|
||||||
@self.route('/<author>/<plugin_name>/toggle', methods=['PUT'], auth_type=group.AuthType.USER_TOKEN)
|
@self.route(
|
||||||
|
'/<author>/<plugin_name>/toggle',
|
||||||
|
methods=['PUT'],
|
||||||
|
auth_type=group.AuthType.USER_TOKEN,
|
||||||
|
)
|
||||||
async def _(author: str, plugin_name: str) -> str:
|
async def _(author: str, plugin_name: str) -> str:
|
||||||
data = await quart.request.json
|
data = await quart.request.json
|
||||||
target_enabled = data.get('target_enabled')
|
target_enabled = data.get('target_enabled')
|
||||||
await self.ap.plugin_mgr.update_plugin_switch(plugin_name, target_enabled)
|
await self.ap.plugin_mgr.update_plugin_switch(plugin_name, target_enabled)
|
||||||
return self.success()
|
return self.success()
|
||||||
|
|
||||||
@self.route('/<author>/<plugin_name>/update', methods=['POST'], auth_type=group.AuthType.USER_TOKEN)
|
@self.route(
|
||||||
|
'/<author>/<plugin_name>/update',
|
||||||
|
methods=['POST'],
|
||||||
|
auth_type=group.AuthType.USER_TOKEN,
|
||||||
|
)
|
||||||
async def _(author: str, plugin_name: str) -> str:
|
async def _(author: str, plugin_name: str) -> str:
|
||||||
ctx = taskmgr.TaskContext.new()
|
ctx = taskmgr.TaskContext.new()
|
||||||
wrapper = self.ap.task_mgr.create_user_task(
|
wrapper = self.ap.task_mgr.create_user_task(
|
||||||
self.ap.plugin_mgr.update_plugin(plugin_name, task_context=ctx),
|
self.ap.plugin_mgr.update_plugin(plugin_name, task_context=ctx),
|
||||||
kind="plugin-operation",
|
kind='plugin-operation',
|
||||||
name=f"plugin-update-{plugin_name}",
|
name=f'plugin-update-{plugin_name}',
|
||||||
label=f"更新插件 {plugin_name}",
|
label=f'更新插件 {plugin_name}',
|
||||||
context=ctx
|
context=ctx,
|
||||||
)
|
)
|
||||||
return self.success(data={
|
return self.success(data={'task_id': wrapper.id})
|
||||||
'task_id': wrapper.id
|
|
||||||
})
|
|
||||||
|
|
||||||
@self.route('/<author>/<plugin_name>', methods=['GET', 'DELETE'], auth_type=group.AuthType.USER_TOKEN)
|
@self.route(
|
||||||
|
'/<author>/<plugin_name>',
|
||||||
|
methods=['GET', 'DELETE'],
|
||||||
|
auth_type=group.AuthType.USER_TOKEN,
|
||||||
|
)
|
||||||
async def _(author: str, plugin_name: str) -> str:
|
async def _(author: str, plugin_name: str) -> str:
|
||||||
if quart.request.method == 'GET':
|
if quart.request.method == 'GET':
|
||||||
plugin = self.ap.plugin_mgr.get_plugin(author, plugin_name)
|
plugin = self.ap.plugin_mgr.get_plugin(author, plugin_name)
|
||||||
if plugin is None:
|
if plugin is None:
|
||||||
return self.http_status(404, -1, 'plugin not found')
|
return self.http_status(404, -1, 'plugin not found')
|
||||||
return self.success(data={
|
return self.success(data={'plugin': plugin.model_dump()})
|
||||||
'plugin': plugin.model_dump()
|
|
||||||
})
|
|
||||||
elif quart.request.method == 'DELETE':
|
elif quart.request.method == 'DELETE':
|
||||||
ctx = taskmgr.TaskContext.new()
|
ctx = taskmgr.TaskContext.new()
|
||||||
wrapper = self.ap.task_mgr.create_user_task(
|
wrapper = self.ap.task_mgr.create_user_task(
|
||||||
self.ap.plugin_mgr.uninstall_plugin(plugin_name, task_context=ctx),
|
self.ap.plugin_mgr.uninstall_plugin(plugin_name, task_context=ctx),
|
||||||
kind="plugin-operation",
|
kind='plugin-operation',
|
||||||
name=f'plugin-remove-{plugin_name}',
|
name=f'plugin-remove-{plugin_name}',
|
||||||
label=f'删除插件 {plugin_name}',
|
label=f'删除插件 {plugin_name}',
|
||||||
context=ctx
|
context=ctx,
|
||||||
)
|
)
|
||||||
|
|
||||||
return self.success(data={
|
return self.success(data={'task_id': wrapper.id})
|
||||||
'task_id': wrapper.id
|
|
||||||
})
|
|
||||||
|
|
||||||
@self.route('/<author>/<plugin_name>/config', methods=['GET', 'PUT'], auth_type=group.AuthType.USER_TOKEN)
|
@self.route(
|
||||||
|
'/<author>/<plugin_name>/config',
|
||||||
|
methods=['GET', 'PUT'],
|
||||||
|
auth_type=group.AuthType.USER_TOKEN,
|
||||||
|
)
|
||||||
async def _(author: str, plugin_name: str) -> quart.Response:
|
async def _(author: str, plugin_name: str) -> quart.Response:
|
||||||
plugin = self.ap.plugin_mgr.get_plugin(author, plugin_name)
|
plugin = self.ap.plugin_mgr.get_plugin(author, plugin_name)
|
||||||
if plugin is None:
|
if plugin is None:
|
||||||
return self.http_status(404, -1, 'plugin not found')
|
return self.http_status(404, -1, 'plugin not found')
|
||||||
if quart.request.method == 'GET':
|
if quart.request.method == 'GET':
|
||||||
return self.success(data={
|
return self.success(data={'config': plugin.plugin_config})
|
||||||
'config': plugin.plugin_config
|
|
||||||
})
|
|
||||||
elif quart.request.method == 'PUT':
|
elif quart.request.method == 'PUT':
|
||||||
data = await quart.request.json
|
data = await quart.request.json
|
||||||
|
|
||||||
@@ -89,7 +92,9 @@ class PluginsRouterGroup(group.RouterGroup):
|
|||||||
await self.ap.plugin_mgr.reorder_plugins(data.get('plugins'))
|
await self.ap.plugin_mgr.reorder_plugins(data.get('plugins'))
|
||||||
return self.success()
|
return self.success()
|
||||||
|
|
||||||
@self.route('/install/github', methods=['POST'], auth_type=group.AuthType.USER_TOKEN)
|
@self.route(
|
||||||
|
'/install/github', methods=['POST'], auth_type=group.AuthType.USER_TOKEN
|
||||||
|
)
|
||||||
async def _() -> str:
|
async def _() -> str:
|
||||||
data = await quart.request.json
|
data = await quart.request.json
|
||||||
|
|
||||||
@@ -97,12 +102,10 @@ class PluginsRouterGroup(group.RouterGroup):
|
|||||||
short_source_str = data['source'][-8:]
|
short_source_str = data['source'][-8:]
|
||||||
wrapper = self.ap.task_mgr.create_user_task(
|
wrapper = self.ap.task_mgr.create_user_task(
|
||||||
self.ap.plugin_mgr.install_plugin(data['source'], task_context=ctx),
|
self.ap.plugin_mgr.install_plugin(data['source'], task_context=ctx),
|
||||||
kind="plugin-operation",
|
kind='plugin-operation',
|
||||||
name=f'plugin-install-github',
|
name='plugin-install-github',
|
||||||
label=f'安装插件 ...{short_source_str}',
|
label=f'安装插件 ...{short_source_str}',
|
||||||
context=ctx
|
context=ctx,
|
||||||
)
|
)
|
||||||
|
|
||||||
return self.success(data={
|
return self.success(data={'task_id': wrapper.id})
|
||||||
'task_id': wrapper.id
|
|
||||||
})
|
|
||||||
|
|||||||
@@ -1,28 +1,23 @@
|
|||||||
import quart
|
import quart
|
||||||
import uuid
|
|
||||||
|
|
||||||
from ... import group
|
from ... import group
|
||||||
from ......entity.persistence import model
|
|
||||||
|
|
||||||
|
|
||||||
@group.group_class('models/llm', '/api/v1/provider/models/llm')
|
@group.group_class('models/llm', '/api/v1/provider/models/llm')
|
||||||
class LLMModelsRouterGroup(group.RouterGroup):
|
class LLMModelsRouterGroup(group.RouterGroup):
|
||||||
|
|
||||||
async def initialize(self) -> None:
|
async def initialize(self) -> None:
|
||||||
@self.route('', methods=['GET', 'POST'])
|
@self.route('', methods=['GET', 'POST'])
|
||||||
async def _() -> str:
|
async def _() -> str:
|
||||||
if quart.request.method == 'GET':
|
if quart.request.method == 'GET':
|
||||||
return self.success(data={
|
return self.success(
|
||||||
'models': await self.ap.model_service.get_llm_models()
|
data={'models': await self.ap.model_service.get_llm_models()}
|
||||||
})
|
)
|
||||||
elif quart.request.method == 'POST':
|
elif quart.request.method == 'POST':
|
||||||
json_data = await quart.request.json
|
json_data = await quart.request.json
|
||||||
|
|
||||||
model_uuid = await self.ap.model_service.create_llm_model(json_data)
|
model_uuid = await self.ap.model_service.create_llm_model(json_data)
|
||||||
|
|
||||||
return self.success(data={
|
return self.success(data={'uuid': model_uuid})
|
||||||
'uuid': model_uuid
|
|
||||||
})
|
|
||||||
|
|
||||||
@self.route('/<model_uuid>', methods=['GET', 'DELETE'])
|
@self.route('/<model_uuid>', methods=['GET', 'DELETE'])
|
||||||
async def _(model_uuid: str) -> str:
|
async def _(model_uuid: str) -> str:
|
||||||
@@ -32,9 +27,7 @@ class LLMModelsRouterGroup(group.RouterGroup):
|
|||||||
if model is None:
|
if model is None:
|
||||||
return self.http_status(404, -1, 'model not found')
|
return self.http_status(404, -1, 'model not found')
|
||||||
|
|
||||||
return self.success(data={
|
return self.success(data={'model': model})
|
||||||
'model': model
|
|
||||||
})
|
|
||||||
# elif quart.request.method == 'PUT':
|
# elif quart.request.method == 'PUT':
|
||||||
# json_data = await quart.request.json
|
# json_data = await quart.request.json
|
||||||
|
|
||||||
|
|||||||
@@ -5,29 +5,31 @@ from ... import group
|
|||||||
|
|
||||||
@group.group_class('provider/requesters', '/api/v1/provider/requesters')
|
@group.group_class('provider/requesters', '/api/v1/provider/requesters')
|
||||||
class RequestersRouterGroup(group.RouterGroup):
|
class RequestersRouterGroup(group.RouterGroup):
|
||||||
|
|
||||||
async def initialize(self) -> None:
|
async def initialize(self) -> None:
|
||||||
@self.route('', methods=['GET'])
|
@self.route('', methods=['GET'])
|
||||||
async def _() -> quart.Response:
|
async def _() -> quart.Response:
|
||||||
return self.success(data={
|
return self.success(
|
||||||
'requesters': self.ap.model_mgr.get_available_requesters_info()
|
data={'requesters': self.ap.model_mgr.get_available_requesters_info()}
|
||||||
})
|
)
|
||||||
|
|
||||||
@self.route('/<requester_name>', methods=['GET'])
|
@self.route('/<requester_name>', methods=['GET'])
|
||||||
async def _(requester_name: str) -> quart.Response:
|
async def _(requester_name: str) -> quart.Response:
|
||||||
|
requester_info = self.ap.model_mgr.get_available_requester_info_by_name(
|
||||||
requester_info = self.ap.model_mgr.get_available_requester_info_by_name(requester_name)
|
requester_name
|
||||||
|
)
|
||||||
|
|
||||||
if requester_info is None:
|
if requester_info is None:
|
||||||
return self.http_status(404, -1, 'requester not found')
|
return self.http_status(404, -1, 'requester not found')
|
||||||
|
|
||||||
return self.success(data={
|
return self.success(data={'requester': requester_info})
|
||||||
'requester': requester_info
|
|
||||||
})
|
|
||||||
|
|
||||||
@self.route('/<requester_name>/icon', methods=['GET'])
|
@self.route('/<requester_name>/icon', methods=['GET'])
|
||||||
async def _(requester_name: str) -> quart.Response:
|
async def _(requester_name: str) -> quart.Response:
|
||||||
requester_manifest = self.ap.model_mgr.get_available_requester_manifest_by_name(requester_name)
|
requester_manifest = (
|
||||||
|
self.ap.model_mgr.get_available_requester_manifest_by_name(
|
||||||
|
requester_name
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
if requester_manifest is None:
|
if requester_manifest is None:
|
||||||
return self.http_status(404, -1, 'requester not found')
|
return self.http_status(404, -1, 'requester not found')
|
||||||
|
|||||||
@@ -1,23 +1,21 @@
|
|||||||
import quart
|
|
||||||
import asyncio
|
|
||||||
|
|
||||||
from .....core import app, taskmgr
|
|
||||||
from .. import group
|
from .. import group
|
||||||
|
|
||||||
|
|
||||||
@group.group_class('stats', '/api/v1/stats')
|
@group.group_class('stats', '/api/v1/stats')
|
||||||
class StatsRouterGroup(group.RouterGroup):
|
class StatsRouterGroup(group.RouterGroup):
|
||||||
|
|
||||||
async def initialize(self) -> None:
|
async def initialize(self) -> None:
|
||||||
@self.route('/basic', methods=['GET'], auth_type=group.AuthType.USER_TOKEN)
|
@self.route('/basic', methods=['GET'], auth_type=group.AuthType.USER_TOKEN)
|
||||||
async def _() -> str:
|
async def _() -> str:
|
||||||
|
|
||||||
conv_count = 0
|
conv_count = 0
|
||||||
for session in self.ap.sess_mgr.session_list:
|
for session in self.ap.sess_mgr.session_list:
|
||||||
conv_count += len(session.conversations if session.conversations is not None else [])
|
conv_count += len(
|
||||||
|
session.conversations if session.conversations is not None else []
|
||||||
|
)
|
||||||
|
|
||||||
return self.success(data={
|
return self.success(
|
||||||
'active_session_count': len(self.ap.sess_mgr.session_list),
|
data={
|
||||||
'conversation_count': conv_count,
|
'active_session_count': len(self.ap.sess_mgr.session_list),
|
||||||
'query_count': self.ap.query_pool.query_id_counter,
|
'conversation_count': conv_count,
|
||||||
})
|
'query_count': self.ap.query_pool.query_id_counter,
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|||||||
@@ -1,42 +1,41 @@
|
|||||||
import quart
|
import quart
|
||||||
import asyncio
|
|
||||||
|
|
||||||
from .....core import app, taskmgr
|
|
||||||
from .. import group
|
from .. import group
|
||||||
from .....utils import constants
|
from .....utils import constants
|
||||||
|
|
||||||
|
|
||||||
@group.group_class('system', '/api/v1/system')
|
@group.group_class('system', '/api/v1/system')
|
||||||
class SystemRouterGroup(group.RouterGroup):
|
class SystemRouterGroup(group.RouterGroup):
|
||||||
|
|
||||||
async def initialize(self) -> None:
|
async def initialize(self) -> None:
|
||||||
@self.route('/info', methods=['GET'], auth_type=group.AuthType.NONE)
|
@self.route('/info', methods=['GET'], auth_type=group.AuthType.NONE)
|
||||||
async def _() -> str:
|
async def _() -> str:
|
||||||
return self.success(
|
return self.success(
|
||||||
data={
|
data={
|
||||||
"version": constants.semantic_version,
|
'version': constants.semantic_version,
|
||||||
"debug": constants.debug_mode,
|
'debug': constants.debug_mode,
|
||||||
"enabled_platform_count": len(self.ap.platform_mgr.get_running_adapters())
|
'enabled_platform_count': len(
|
||||||
|
self.ap.platform_mgr.get_running_adapters()
|
||||||
|
),
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
|
|
||||||
@self.route('/tasks', methods=['GET'], auth_type=group.AuthType.USER_TOKEN)
|
@self.route('/tasks', methods=['GET'], auth_type=group.AuthType.USER_TOKEN)
|
||||||
async def _() -> str:
|
async def _() -> str:
|
||||||
task_type = quart.request.args.get("type")
|
task_type = quart.request.args.get('type')
|
||||||
|
|
||||||
if task_type == '':
|
if task_type == '':
|
||||||
task_type = None
|
task_type = None
|
||||||
|
|
||||||
return self.success(
|
return self.success(data=self.ap.task_mgr.get_tasks_dict(task_type))
|
||||||
data=self.ap.task_mgr.get_tasks_dict(task_type)
|
|
||||||
)
|
|
||||||
|
|
||||||
@self.route('/tasks/<task_id>', methods=['GET'], auth_type=group.AuthType.USER_TOKEN)
|
@self.route(
|
||||||
|
'/tasks/<task_id>', methods=['GET'], auth_type=group.AuthType.USER_TOKEN
|
||||||
|
)
|
||||||
async def _(task_id: str) -> str:
|
async def _(task_id: str) -> str:
|
||||||
task = self.ap.task_mgr.get_task_by_id(int(task_id))
|
task = self.ap.task_mgr.get_task_by_id(int(task_id))
|
||||||
|
|
||||||
if task is None:
|
if task is None:
|
||||||
return self.http_status(404, 404, "Task not found")
|
return self.http_status(404, 404, 'Task not found')
|
||||||
|
|
||||||
return self.success(data=task.to_dict())
|
return self.success(data=task.to_dict())
|
||||||
|
|
||||||
@@ -44,20 +43,20 @@ class SystemRouterGroup(group.RouterGroup):
|
|||||||
async def _() -> str:
|
async def _() -> str:
|
||||||
json_data = await quart.request.json
|
json_data = await quart.request.json
|
||||||
|
|
||||||
scope = json_data.get("scope")
|
scope = json_data.get('scope')
|
||||||
|
|
||||||
await self.ap.reload(
|
await self.ap.reload(scope=scope)
|
||||||
scope=scope
|
|
||||||
)
|
|
||||||
return self.success()
|
return self.success()
|
||||||
|
|
||||||
@self.route('/_debug/exec', methods=['POST'], auth_type=group.AuthType.USER_TOKEN)
|
@self.route(
|
||||||
|
'/_debug/exec', methods=['POST'], auth_type=group.AuthType.USER_TOKEN
|
||||||
|
)
|
||||||
async def _() -> str:
|
async def _() -> str:
|
||||||
if not constants.debug_mode:
|
if not constants.debug_mode:
|
||||||
return self.http_status(403, 403, "Forbidden")
|
return self.http_status(403, 403, 'Forbidden')
|
||||||
|
|
||||||
py_code = await quart.request.data
|
py_code = await quart.request.data
|
||||||
|
|
||||||
ap = self.ap
|
ap = self.ap
|
||||||
|
|
||||||
return self.success(data=exec(py_code, {"ap": ap}))
|
return self.success(data=exec(py_code, {'ap': ap}))
|
||||||
|
|||||||
@@ -1,21 +1,18 @@
|
|||||||
import quart
|
import quart
|
||||||
import jwt
|
|
||||||
import argon2
|
import argon2
|
||||||
|
|
||||||
from .. import group
|
from .. import group
|
||||||
from .....entity.persistence import user
|
|
||||||
|
|
||||||
|
|
||||||
@group.group_class('user', '/api/v1/user')
|
@group.group_class('user', '/api/v1/user')
|
||||||
class UserRouterGroup(group.RouterGroup):
|
class UserRouterGroup(group.RouterGroup):
|
||||||
|
|
||||||
async def initialize(self) -> None:
|
async def initialize(self) -> None:
|
||||||
@self.route('/init', methods=['GET', 'POST'], auth_type=group.AuthType.NONE)
|
@self.route('/init', methods=['GET', 'POST'], auth_type=group.AuthType.NONE)
|
||||||
async def _() -> str:
|
async def _() -> str:
|
||||||
if quart.request.method == 'GET':
|
if quart.request.method == 'GET':
|
||||||
return self.success(data={
|
return self.success(
|
||||||
'initialized': await self.ap.user_service.is_initialized()
|
data={'initialized': await self.ap.user_service.is_initialized()}
|
||||||
})
|
)
|
||||||
|
|
||||||
if await self.ap.user_service.is_initialized():
|
if await self.ap.user_service.is_initialized():
|
||||||
return self.fail(1, '系统已初始化')
|
return self.fail(1, '系统已初始化')
|
||||||
@@ -34,18 +31,18 @@ class UserRouterGroup(group.RouterGroup):
|
|||||||
json_data = await quart.request.json
|
json_data = await quart.request.json
|
||||||
|
|
||||||
try:
|
try:
|
||||||
token = await self.ap.user_service.authenticate(json_data['user'], json_data['password'])
|
token = await self.ap.user_service.authenticate(
|
||||||
|
json_data['user'], json_data['password']
|
||||||
|
)
|
||||||
except argon2.exceptions.VerifyMismatchError:
|
except argon2.exceptions.VerifyMismatchError:
|
||||||
return self.fail(1, '用户名或密码错误')
|
return self.fail(1, '用户名或密码错误')
|
||||||
|
|
||||||
return self.success(data={
|
return self.success(data={'token': token})
|
||||||
'token': token
|
|
||||||
})
|
|
||||||
|
|
||||||
@self.route('/check-token', methods=['GET'], auth_type=group.AuthType.USER_TOKEN)
|
@self.route(
|
||||||
|
'/check-token', methods=['GET'], auth_type=group.AuthType.USER_TOKEN
|
||||||
|
)
|
||||||
async def _(user_email: str) -> str:
|
async def _(user_email: str) -> str:
|
||||||
token = await self.ap.user_service.generate_jwt_token(user_email)
|
token = await self.ap.user_service.generate_jwt_token(user_email)
|
||||||
|
|
||||||
return self.success(data={
|
return self.success(data={'token': token})
|
||||||
'token': token
|
|
||||||
})
|
|
||||||
|
|||||||
@@ -7,15 +7,19 @@ import quart
|
|||||||
import quart_cors
|
import quart_cors
|
||||||
|
|
||||||
from ....core import app, entities as core_entities
|
from ....core import app, entities as core_entities
|
||||||
|
from ....utils import importutil
|
||||||
|
|
||||||
from .groups import logs, system, plugins, stats, user, pipelines
|
from . import groups
|
||||||
from .groups.provider import models, requesters
|
|
||||||
from .groups.platform import bots, adapters
|
|
||||||
from . import group
|
from . import group
|
||||||
|
from .groups import provider as groups_provider
|
||||||
|
from .groups import platform as groups_platform
|
||||||
|
|
||||||
|
importutil.import_modules_in_pkg(groups)
|
||||||
|
importutil.import_modules_in_pkg(groups_provider)
|
||||||
|
importutil.import_modules_in_pkg(groups_platform)
|
||||||
|
|
||||||
|
|
||||||
class HTTPController:
|
class HTTPController:
|
||||||
|
|
||||||
ap: app.Application
|
ap: app.Application
|
||||||
|
|
||||||
quart_app: quart.Quart
|
quart_app: quart.Quart
|
||||||
@@ -23,7 +27,7 @@ class HTTPController:
|
|||||||
def __init__(self, ap: app.Application) -> None:
|
def __init__(self, ap: app.Application) -> None:
|
||||||
self.ap = ap
|
self.ap = ap
|
||||||
self.quart_app = quart.Quart(__name__)
|
self.quart_app = quart.Quart(__name__)
|
||||||
quart_cors.cors(self.quart_app, allow_origin="*")
|
quart_cors.cors(self.quart_app, allow_origin='*')
|
||||||
|
|
||||||
async def initialize(self) -> None:
|
async def initialize(self) -> None:
|
||||||
await self.register_routes()
|
await self.register_routes()
|
||||||
@@ -37,11 +41,9 @@ class HTTPController:
|
|||||||
|
|
||||||
async def exception_handler(*args, **kwargs):
|
async def exception_handler(*args, **kwargs):
|
||||||
try:
|
try:
|
||||||
await self.quart_app.run_task(
|
await self.quart_app.run_task(*args, **kwargs)
|
||||||
*args, **kwargs
|
|
||||||
)
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
self.ap.logger.error(f"启动 HTTP 服务失败: {e}")
|
self.ap.logger.error(f'启动 HTTP 服务失败: {e}')
|
||||||
|
|
||||||
self.ap.task_mgr.create_task(
|
self.ap.task_mgr.create_task(
|
||||||
exception_handler(
|
exception_handler(
|
||||||
@@ -49,63 +51,62 @@ class HTTPController:
|
|||||||
port=self.ap.instance_config.data['api']['port'],
|
port=self.ap.instance_config.data['api']['port'],
|
||||||
shutdown_trigger=shutdown_trigger_placeholder,
|
shutdown_trigger=shutdown_trigger_placeholder,
|
||||||
),
|
),
|
||||||
name="http-api-quart",
|
name='http-api-quart',
|
||||||
scopes=[core_entities.LifecycleControlScope.APPLICATION],
|
scopes=[core_entities.LifecycleControlScope.APPLICATION],
|
||||||
)
|
)
|
||||||
|
|
||||||
# await asyncio.sleep(5)
|
# await asyncio.sleep(5)
|
||||||
|
|
||||||
async def register_routes(self) -> None:
|
async def register_routes(self) -> None:
|
||||||
|
@self.quart_app.route('/healthz')
|
||||||
@self.quart_app.route("/healthz")
|
|
||||||
async def healthz():
|
async def healthz():
|
||||||
return {"code": 0, "msg": "ok"}
|
return {'code': 0, 'msg': 'ok'}
|
||||||
|
|
||||||
for g in group.preregistered_groups:
|
for g in group.preregistered_groups:
|
||||||
ginst = g(self.ap, self.quart_app)
|
ginst = g(self.ap, self.quart_app)
|
||||||
await ginst.initialize()
|
await ginst.initialize()
|
||||||
|
|
||||||
frontend_path = "web/out"
|
frontend_path = 'web/out'
|
||||||
|
|
||||||
@self.quart_app.route("/")
|
@self.quart_app.route('/')
|
||||||
async def index():
|
async def index():
|
||||||
return await quart.send_from_directory(frontend_path, "index.html", mimetype="text/html")
|
return await quart.send_from_directory(
|
||||||
|
frontend_path, 'index.html', mimetype='text/html'
|
||||||
|
)
|
||||||
|
|
||||||
@self.quart_app.route("/<path:path>")
|
@self.quart_app.route('/<path:path>')
|
||||||
async def static_file(path: str):
|
async def static_file(path: str):
|
||||||
if not os.path.exists(os.path.join(frontend_path, path)):
|
if not os.path.exists(os.path.join(frontend_path, path)):
|
||||||
if os.path.exists(os.path.join(frontend_path, path+".html")):
|
if os.path.exists(os.path.join(frontend_path, path + '.html')):
|
||||||
path += '.html'
|
path += '.html'
|
||||||
else:
|
else:
|
||||||
return await quart.send_from_directory(frontend_path, '404.html')
|
return await quart.send_from_directory(frontend_path, '404.html')
|
||||||
|
|
||||||
mimetype = None
|
mimetype = None
|
||||||
|
|
||||||
if path.endswith(".html"):
|
if path.endswith('.html'):
|
||||||
mimetype = "text/html"
|
mimetype = 'text/html'
|
||||||
elif path.endswith(".js"):
|
elif path.endswith('.js'):
|
||||||
mimetype = "application/javascript"
|
mimetype = 'application/javascript'
|
||||||
elif path.endswith(".css"):
|
elif path.endswith('.css'):
|
||||||
mimetype = "text/css"
|
mimetype = 'text/css'
|
||||||
elif path.endswith(".png"):
|
elif path.endswith('.png'):
|
||||||
mimetype = "image/png"
|
mimetype = 'image/png'
|
||||||
elif path.endswith(".jpg"):
|
elif path.endswith('.jpg'):
|
||||||
mimetype = "image/jpeg"
|
mimetype = 'image/jpeg'
|
||||||
elif path.endswith(".jpeg"):
|
elif path.endswith('.jpeg'):
|
||||||
mimetype = "image/jpeg"
|
mimetype = 'image/jpeg'
|
||||||
elif path.endswith(".gif"):
|
elif path.endswith('.gif'):
|
||||||
mimetype = "image/gif"
|
mimetype = 'image/gif'
|
||||||
elif path.endswith(".svg"):
|
elif path.endswith('.svg'):
|
||||||
mimetype = "image/svg+xml"
|
mimetype = 'image/svg+xml'
|
||||||
elif path.endswith(".ico"):
|
elif path.endswith('.ico'):
|
||||||
mimetype = "image/x-icon"
|
mimetype = 'image/x-icon'
|
||||||
elif path.endswith(".json"):
|
elif path.endswith('.json'):
|
||||||
mimetype = "application/json"
|
mimetype = 'application/json'
|
||||||
elif path.endswith(".txt"):
|
elif path.endswith('.txt'):
|
||||||
mimetype = "text/plain"
|
mimetype = 'text/plain'
|
||||||
|
|
||||||
return await quart.send_from_directory(
|
return await quart.send_from_directory(
|
||||||
frontend_path,
|
frontend_path, path, mimetype=mimetype
|
||||||
path,
|
|
||||||
mimetype=mimetype
|
|
||||||
)
|
)
|
||||||
|
|||||||
@@ -1,7 +1,6 @@
|
|||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
import uuid
|
import uuid
|
||||||
import datetime
|
|
||||||
import sqlalchemy
|
import sqlalchemy
|
||||||
|
|
||||||
from ....core import app
|
from ....core import app
|
||||||
@@ -33,7 +32,9 @@ class BotService:
|
|||||||
async def get_bot(self, bot_uuid: str) -> dict | None:
|
async def get_bot(self, bot_uuid: str) -> dict | None:
|
||||||
"""获取机器人"""
|
"""获取机器人"""
|
||||||
result = await self.ap.persistence_mgr.execute_async(
|
result = await self.ap.persistence_mgr.execute_async(
|
||||||
sqlalchemy.select(persistence_bot.Bot).where(persistence_bot.Bot.uuid == bot_uuid)
|
sqlalchemy.select(persistence_bot.Bot).where(
|
||||||
|
persistence_bot.Bot.uuid == bot_uuid
|
||||||
|
)
|
||||||
)
|
)
|
||||||
|
|
||||||
bot = result.first()
|
bot = result.first()
|
||||||
@@ -50,7 +51,9 @@ class BotService:
|
|||||||
|
|
||||||
# checkout the default pipeline
|
# checkout the default pipeline
|
||||||
result = await self.ap.persistence_mgr.execute_async(
|
result = await self.ap.persistence_mgr.execute_async(
|
||||||
sqlalchemy.select(persistence_pipeline.LegacyPipeline).where(persistence_pipeline.LegacyPipeline.is_default == True)
|
sqlalchemy.select(persistence_pipeline.LegacyPipeline).where(
|
||||||
|
persistence_pipeline.LegacyPipeline.is_default == True
|
||||||
|
)
|
||||||
)
|
)
|
||||||
pipeline = result.first()
|
pipeline = result.first()
|
||||||
if pipeline is not None:
|
if pipeline is not None:
|
||||||
@@ -75,16 +78,21 @@ class BotService:
|
|||||||
# set use_pipeline_name
|
# set use_pipeline_name
|
||||||
if 'use_pipeline_uuid' in bot_data:
|
if 'use_pipeline_uuid' in bot_data:
|
||||||
result = await self.ap.persistence_mgr.execute_async(
|
result = await self.ap.persistence_mgr.execute_async(
|
||||||
sqlalchemy.select(persistence_pipeline.LegacyPipeline).where(persistence_pipeline.LegacyPipeline.uuid == bot_data['use_pipeline_uuid'])
|
sqlalchemy.select(persistence_pipeline.LegacyPipeline).where(
|
||||||
|
persistence_pipeline.LegacyPipeline.uuid
|
||||||
|
== bot_data['use_pipeline_uuid']
|
||||||
|
)
|
||||||
)
|
)
|
||||||
pipeline = result.first()
|
pipeline = result.first()
|
||||||
if pipeline is not None:
|
if pipeline is not None:
|
||||||
bot_data['use_pipeline_name'] = pipeline.name
|
bot_data['use_pipeline_name'] = pipeline.name
|
||||||
else:
|
else:
|
||||||
raise Exception("Pipeline not found")
|
raise Exception('Pipeline not found')
|
||||||
|
|
||||||
await self.ap.persistence_mgr.execute_async(
|
await self.ap.persistence_mgr.execute_async(
|
||||||
sqlalchemy.update(persistence_bot.Bot).values(bot_data).where(persistence_bot.Bot.uuid == bot_uuid)
|
sqlalchemy.update(persistence_bot.Bot)
|
||||||
|
.values(bot_data)
|
||||||
|
.where(persistence_bot.Bot.uuid == bot_uuid)
|
||||||
)
|
)
|
||||||
await self.ap.platform_mgr.remove_bot(bot_uuid)
|
await self.ap.platform_mgr.remove_bot(bot_uuid)
|
||||||
|
|
||||||
@@ -100,7 +108,7 @@ class BotService:
|
|||||||
"""删除机器人"""
|
"""删除机器人"""
|
||||||
await self.ap.platform_mgr.remove_bot(bot_uuid)
|
await self.ap.platform_mgr.remove_bot(bot_uuid)
|
||||||
await self.ap.persistence_mgr.execute_async(
|
await self.ap.persistence_mgr.execute_async(
|
||||||
sqlalchemy.delete(persistence_bot.Bot).where(persistence_bot.Bot.uuid == bot_uuid)
|
sqlalchemy.delete(persistence_bot.Bot).where(
|
||||||
|
persistence_bot.Bot.uuid == bot_uuid
|
||||||
|
)
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -1,7 +1,6 @@
|
|||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
import uuid
|
import uuid
|
||||||
import datetime
|
|
||||||
import sqlalchemy
|
import sqlalchemy
|
||||||
|
|
||||||
from ....core import app
|
from ....core import app
|
||||||
@@ -10,7 +9,6 @@ from ....entity.persistence import pipeline as persistence_pipeline
|
|||||||
|
|
||||||
|
|
||||||
class ModelsService:
|
class ModelsService:
|
||||||
|
|
||||||
ap: app.Application
|
ap: app.Application
|
||||||
|
|
||||||
def __init__(self, ap: app.Application) -> None:
|
def __init__(self, ap: app.Application) -> None:
|
||||||
@@ -28,13 +26,10 @@ class ModelsService:
|
|||||||
]
|
]
|
||||||
|
|
||||||
async def create_llm_model(self, model_data: dict) -> str:
|
async def create_llm_model(self, model_data: dict) -> str:
|
||||||
|
|
||||||
model_data['uuid'] = str(uuid.uuid4())
|
model_data['uuid'] = str(uuid.uuid4())
|
||||||
|
|
||||||
await self.ap.persistence_mgr.execute_async(
|
await self.ap.persistence_mgr.execute_async(
|
||||||
sqlalchemy.insert(persistence_model.LLMModel).values(
|
sqlalchemy.insert(persistence_model.LLMModel).values(**model_data)
|
||||||
**model_data
|
|
||||||
)
|
|
||||||
)
|
)
|
||||||
|
|
||||||
llm_model = await self.get_llm_model(model_data['uuid'])
|
llm_model = await self.get_llm_model(model_data['uuid'])
|
||||||
@@ -43,22 +38,24 @@ class ModelsService:
|
|||||||
|
|
||||||
# check if default pipeline has no model bound
|
# check if default pipeline has no model bound
|
||||||
result = await self.ap.persistence_mgr.execute_async(
|
result = await self.ap.persistence_mgr.execute_async(
|
||||||
sqlalchemy.select(persistence_pipeline.LegacyPipeline).where(persistence_pipeline.LegacyPipeline.is_default == True)
|
sqlalchemy.select(persistence_pipeline.LegacyPipeline).where(
|
||||||
|
persistence_pipeline.LegacyPipeline.is_default == True
|
||||||
|
)
|
||||||
)
|
)
|
||||||
pipeline = result.first()
|
pipeline = result.first()
|
||||||
if pipeline is not None and pipeline.config['ai']['local-agent']['model'] == '':
|
if pipeline is not None and pipeline.config['ai']['local-agent']['model'] == '':
|
||||||
pipeline_config = pipeline.config
|
pipeline_config = pipeline.config
|
||||||
pipeline_config['ai']['local-agent']['model'] = model_data['uuid']
|
pipeline_config['ai']['local-agent']['model'] = model_data['uuid']
|
||||||
pipeline_data = {
|
pipeline_data = {'config': pipeline_config}
|
||||||
"config": pipeline_config
|
|
||||||
}
|
|
||||||
await self.ap.pipeline_service.update_pipeline(pipeline.uuid, pipeline_data)
|
await self.ap.pipeline_service.update_pipeline(pipeline.uuid, pipeline_data)
|
||||||
|
|
||||||
return model_data['uuid']
|
return model_data['uuid']
|
||||||
|
|
||||||
async def get_llm_model(self, model_uuid: str) -> dict | None:
|
async def get_llm_model(self, model_uuid: str) -> dict | None:
|
||||||
result = await self.ap.persistence_mgr.execute_async(
|
result = await self.ap.persistence_mgr.execute_async(
|
||||||
sqlalchemy.select(persistence_model.LLMModel).where(persistence_model.LLMModel.uuid == model_uuid)
|
sqlalchemy.select(persistence_model.LLMModel).where(
|
||||||
|
persistence_model.LLMModel.uuid == model_uuid
|
||||||
|
)
|
||||||
)
|
)
|
||||||
|
|
||||||
model = result.first()
|
model = result.first()
|
||||||
@@ -66,14 +63,18 @@ class ModelsService:
|
|||||||
if model is None:
|
if model is None:
|
||||||
return None
|
return None
|
||||||
|
|
||||||
return self.ap.persistence_mgr.serialize_model(persistence_model.LLMModel, model)
|
return self.ap.persistence_mgr.serialize_model(
|
||||||
|
persistence_model.LLMModel, model
|
||||||
|
)
|
||||||
|
|
||||||
async def update_llm_model(self, model_uuid: str, model_data: dict) -> None:
|
async def update_llm_model(self, model_uuid: str, model_data: dict) -> None:
|
||||||
if 'uuid' in model_data:
|
if 'uuid' in model_data:
|
||||||
del model_data['uuid']
|
del model_data['uuid']
|
||||||
|
|
||||||
await self.ap.persistence_mgr.execute_async(
|
await self.ap.persistence_mgr.execute_async(
|
||||||
sqlalchemy.update(persistence_model.LLMModel).where(persistence_model.LLMModel.uuid == model_uuid).values(**model_data)
|
sqlalchemy.update(persistence_model.LLMModel)
|
||||||
|
.where(persistence_model.LLMModel.uuid == model_uuid)
|
||||||
|
.values(**model_data)
|
||||||
)
|
)
|
||||||
|
|
||||||
await self.ap.model_mgr.remove_llm_model(model_uuid)
|
await self.ap.model_mgr.remove_llm_model(model_uuid)
|
||||||
@@ -84,7 +85,9 @@ class ModelsService:
|
|||||||
|
|
||||||
async def delete_llm_model(self, model_uuid: str) -> None:
|
async def delete_llm_model(self, model_uuid: str) -> None:
|
||||||
await self.ap.persistence_mgr.execute_async(
|
await self.ap.persistence_mgr.execute_async(
|
||||||
sqlalchemy.delete(persistence_model.LLMModel).where(persistence_model.LLMModel.uuid == model_uuid)
|
sqlalchemy.delete(persistence_model.LLMModel).where(
|
||||||
|
persistence_model.LLMModel.uuid == model_uuid
|
||||||
|
)
|
||||||
)
|
)
|
||||||
|
|
||||||
await self.ap.model_mgr.remove_llm_model(model_uuid)
|
await self.ap.model_mgr.remove_llm_model(model_uuid)
|
||||||
|
|||||||
@@ -2,7 +2,6 @@ from __future__ import annotations
|
|||||||
|
|
||||||
import uuid
|
import uuid
|
||||||
import json
|
import json
|
||||||
import datetime
|
|
||||||
import sqlalchemy
|
import sqlalchemy
|
||||||
|
|
||||||
from ....core import app
|
from ....core import app
|
||||||
@@ -10,18 +9,18 @@ from ....entity.persistence import pipeline as persistence_pipeline
|
|||||||
|
|
||||||
|
|
||||||
default_stage_order = [
|
default_stage_order = [
|
||||||
"GroupRespondRuleCheckStage", # 群响应规则检查
|
'GroupRespondRuleCheckStage', # 群响应规则检查
|
||||||
"BanSessionCheckStage", # 封禁会话检查
|
'BanSessionCheckStage', # 封禁会话检查
|
||||||
"PreContentFilterStage", # 内容过滤前置阶段
|
'PreContentFilterStage', # 内容过滤前置阶段
|
||||||
"PreProcessor", # 预处理器
|
'PreProcessor', # 预处理器
|
||||||
"ConversationMessageTruncator", # 会话消息截断器
|
'ConversationMessageTruncator', # 会话消息截断器
|
||||||
"RequireRateLimitOccupancy", # 请求速率限制占用
|
'RequireRateLimitOccupancy', # 请求速率限制占用
|
||||||
"MessageProcessor", # 处理器
|
'MessageProcessor', # 处理器
|
||||||
"ReleaseRateLimitOccupancy", # 释放速率限制占用
|
'ReleaseRateLimitOccupancy', # 释放速率限制占用
|
||||||
"PostContentFilterStage", # 内容过滤后置阶段
|
'PostContentFilterStage', # 内容过滤后置阶段
|
||||||
"ResponseWrapper", # 响应包装器
|
'ResponseWrapper', # 响应包装器
|
||||||
"LongTextProcessStage", # 长文本处理
|
'LongTextProcessStage', # 长文本处理
|
||||||
"SendResponseBackStage", # 发送响应
|
'SendResponseBackStage', # 发送响应
|
||||||
]
|
]
|
||||||
|
|
||||||
|
|
||||||
@@ -36,7 +35,7 @@ class PipelineService:
|
|||||||
self.ap.pipeline_config_meta_trigger.data,
|
self.ap.pipeline_config_meta_trigger.data,
|
||||||
self.ap.pipeline_config_meta_safety.data,
|
self.ap.pipeline_config_meta_safety.data,
|
||||||
self.ap.pipeline_config_meta_ai.data,
|
self.ap.pipeline_config_meta_ai.data,
|
||||||
self.ap.pipeline_config_meta_output.data
|
self.ap.pipeline_config_meta_output.data,
|
||||||
]
|
]
|
||||||
|
|
||||||
async def get_pipelines(self) -> list[dict]:
|
async def get_pipelines(self) -> list[dict]:
|
||||||
@@ -46,13 +45,17 @@ class PipelineService:
|
|||||||
|
|
||||||
pipelines = result.all()
|
pipelines = result.all()
|
||||||
return [
|
return [
|
||||||
self.ap.persistence_mgr.serialize_model(persistence_pipeline.LegacyPipeline, pipeline)
|
self.ap.persistence_mgr.serialize_model(
|
||||||
|
persistence_pipeline.LegacyPipeline, pipeline
|
||||||
|
)
|
||||||
for pipeline in pipelines
|
for pipeline in pipelines
|
||||||
]
|
]
|
||||||
|
|
||||||
async def get_pipeline(self, pipeline_uuid: str) -> dict | None:
|
async def get_pipeline(self, pipeline_uuid: str) -> dict | None:
|
||||||
result = await self.ap.persistence_mgr.execute_async(
|
result = await self.ap.persistence_mgr.execute_async(
|
||||||
sqlalchemy.select(persistence_pipeline.LegacyPipeline).where(persistence_pipeline.LegacyPipeline.uuid == pipeline_uuid)
|
sqlalchemy.select(persistence_pipeline.LegacyPipeline).where(
|
||||||
|
persistence_pipeline.LegacyPipeline.uuid == pipeline_uuid
|
||||||
|
)
|
||||||
)
|
)
|
||||||
|
|
||||||
pipeline = result.first()
|
pipeline = result.first()
|
||||||
@@ -60,17 +63,23 @@ class PipelineService:
|
|||||||
if pipeline is None:
|
if pipeline is None:
|
||||||
return None
|
return None
|
||||||
|
|
||||||
return self.ap.persistence_mgr.serialize_model(persistence_pipeline.LegacyPipeline, pipeline)
|
return self.ap.persistence_mgr.serialize_model(
|
||||||
|
persistence_pipeline.LegacyPipeline, pipeline
|
||||||
|
)
|
||||||
|
|
||||||
async def create_pipeline(self, pipeline_data: dict, default: bool = False) -> str:
|
async def create_pipeline(self, pipeline_data: dict, default: bool = False) -> str:
|
||||||
pipeline_data['uuid'] = str(uuid.uuid4())
|
pipeline_data['uuid'] = str(uuid.uuid4())
|
||||||
pipeline_data['for_version'] = self.ap.ver_mgr.get_current_version()
|
pipeline_data['for_version'] = self.ap.ver_mgr.get_current_version()
|
||||||
pipeline_data['stages'] = default_stage_order.copy()
|
pipeline_data['stages'] = default_stage_order.copy()
|
||||||
pipeline_data['is_default'] = default
|
pipeline_data['is_default'] = default
|
||||||
pipeline_data['config'] = json.load(open('templates/default-pipeline-config.json', 'r', encoding='utf-8'))
|
pipeline_data['config'] = json.load(
|
||||||
|
open('templates/default-pipeline-config.json', 'r', encoding='utf-8')
|
||||||
|
)
|
||||||
|
|
||||||
await self.ap.persistence_mgr.execute_async(
|
await self.ap.persistence_mgr.execute_async(
|
||||||
sqlalchemy.insert(persistence_pipeline.LegacyPipeline).values(**pipeline_data)
|
sqlalchemy.insert(persistence_pipeline.LegacyPipeline).values(
|
||||||
|
**pipeline_data
|
||||||
|
)
|
||||||
)
|
)
|
||||||
|
|
||||||
pipeline = await self.get_pipeline(pipeline_data['uuid'])
|
pipeline = await self.get_pipeline(pipeline_data['uuid'])
|
||||||
@@ -90,7 +99,9 @@ class PipelineService:
|
|||||||
del pipeline_data['is_default']
|
del pipeline_data['is_default']
|
||||||
|
|
||||||
await self.ap.persistence_mgr.execute_async(
|
await self.ap.persistence_mgr.execute_async(
|
||||||
sqlalchemy.update(persistence_pipeline.LegacyPipeline).where(persistence_pipeline.LegacyPipeline.uuid == pipeline_uuid).values(**pipeline_data)
|
sqlalchemy.update(persistence_pipeline.LegacyPipeline)
|
||||||
|
.where(persistence_pipeline.LegacyPipeline.uuid == pipeline_uuid)
|
||||||
|
.values(**pipeline_data)
|
||||||
)
|
)
|
||||||
|
|
||||||
await self.ap.pipeline_mgr.remove_pipeline(pipeline_uuid)
|
await self.ap.pipeline_mgr.remove_pipeline(pipeline_uuid)
|
||||||
@@ -101,6 +112,8 @@ class PipelineService:
|
|||||||
|
|
||||||
async def delete_pipeline(self, pipeline_uuid: str) -> None:
|
async def delete_pipeline(self, pipeline_uuid: str) -> None:
|
||||||
await self.ap.persistence_mgr.execute_async(
|
await self.ap.persistence_mgr.execute_async(
|
||||||
sqlalchemy.delete(persistence_pipeline.LegacyPipeline).where(persistence_pipeline.LegacyPipeline.uuid == pipeline_uuid)
|
sqlalchemy.delete(persistence_pipeline.LegacyPipeline).where(
|
||||||
|
persistence_pipeline.LegacyPipeline.uuid == pipeline_uuid
|
||||||
|
)
|
||||||
)
|
)
|
||||||
await self.ap.pipeline_mgr.remove_pipeline(pipeline_uuid)
|
await self.ap.pipeline_mgr.remove_pipeline(pipeline_uuid)
|
||||||
|
|||||||
@@ -11,7 +11,6 @@ from ....utils import constants
|
|||||||
|
|
||||||
|
|
||||||
class UserService:
|
class UserService:
|
||||||
|
|
||||||
ap: app.Application
|
ap: app.Application
|
||||||
|
|
||||||
def __init__(self, ap: app.Application) -> None:
|
def __init__(self, ap: app.Application) -> None:
|
||||||
@@ -32,8 +31,7 @@ class UserService:
|
|||||||
|
|
||||||
await self.ap.persistence_mgr.execute_async(
|
await self.ap.persistence_mgr.execute_async(
|
||||||
sqlalchemy.insert(user.User).values(
|
sqlalchemy.insert(user.User).values(
|
||||||
user=user_email,
|
user=user_email, password=hashed_password
|
||||||
password=hashed_password
|
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -61,8 +59,8 @@ class UserService:
|
|||||||
|
|
||||||
payload = {
|
payload = {
|
||||||
'user': user_email,
|
'user': user_email,
|
||||||
'iss': 'LangBot-'+constants.edition,
|
'iss': 'LangBot-' + constants.edition,
|
||||||
'exp': datetime.datetime.now() + datetime.timedelta(seconds=jwt_expire)
|
'exp': datetime.datetime.now() + datetime.timedelta(seconds=jwt_expire),
|
||||||
}
|
}
|
||||||
|
|
||||||
return jwt.encode(payload, jwt_secret, algorithm='HS256')
|
return jwt.encode(payload, jwt_secret, algorithm='HS256')
|
||||||
|
|||||||
@@ -3,11 +3,9 @@ from __future__ import annotations
|
|||||||
import abc
|
import abc
|
||||||
import uuid
|
import uuid
|
||||||
import json
|
import json
|
||||||
import logging
|
|
||||||
import asyncio
|
import asyncio
|
||||||
|
|
||||||
import aiohttp
|
import aiohttp
|
||||||
import requests
|
|
||||||
|
|
||||||
from ...core import app, entities as core_entities
|
from ...core import app, entities as core_entities
|
||||||
|
|
||||||
@@ -38,22 +36,22 @@ class APIGroup(metaclass=abc.ABCMeta):
|
|||||||
"""
|
"""
|
||||||
执行请求
|
执行请求
|
||||||
"""
|
"""
|
||||||
self._runtime_info["account_id"] = "-1"
|
self._runtime_info['account_id'] = '-1'
|
||||||
|
|
||||||
url = self.prefix + path
|
url = self.prefix + path
|
||||||
data = json.dumps(data)
|
data = json.dumps(data)
|
||||||
headers["Content-Type"] = "application/json"
|
headers['Content-Type'] = 'application/json'
|
||||||
|
|
||||||
try:
|
try:
|
||||||
async with aiohttp.ClientSession() as session:
|
async with aiohttp.ClientSession() as session:
|
||||||
async with session.request(
|
async with session.request(
|
||||||
method, url, data=data, params=params, headers=headers, **kwargs
|
method, url, data=data, params=params, headers=headers, **kwargs
|
||||||
) as resp:
|
) as resp:
|
||||||
self.ap.logger.debug("data: %s", data)
|
self.ap.logger.debug('data: %s', data)
|
||||||
self.ap.logger.debug("ret: %s", await resp.text())
|
self.ap.logger.debug('ret: %s', await resp.text())
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
self.ap.logger.debug(f"上报失败: {e}")
|
self.ap.logger.debug(f'上报失败: {e}')
|
||||||
|
|
||||||
async def do(
|
async def do(
|
||||||
self,
|
self,
|
||||||
@@ -68,8 +66,8 @@ class APIGroup(metaclass=abc.ABCMeta):
|
|||||||
|
|
||||||
return self.ap.task_mgr.create_task(
|
return self.ap.task_mgr.create_task(
|
||||||
self._do(method, path, data, params, headers, **kwargs),
|
self._do(method, path, data, params, headers, **kwargs),
|
||||||
kind="telemetry-operation",
|
kind='telemetry-operation',
|
||||||
name=f"{method} {path}",
|
name=f'{method} {path}',
|
||||||
scopes=[core_entities.LifecycleControlScope.APPLICATION],
|
scopes=[core_entities.LifecycleControlScope.APPLICATION],
|
||||||
).task
|
).task
|
||||||
|
|
||||||
@@ -80,7 +78,7 @@ class APIGroup(metaclass=abc.ABCMeta):
|
|||||||
def basic_info(self):
|
def basic_info(self):
|
||||||
"""获取基本信息"""
|
"""获取基本信息"""
|
||||||
basic_info = APIGroup._basic_info.copy()
|
basic_info = APIGroup._basic_info.copy()
|
||||||
basic_info["rid"] = self.gen_rid()
|
basic_info['rid'] = self.gen_rid()
|
||||||
return basic_info
|
return basic_info
|
||||||
|
|
||||||
def runtime_info(self):
|
def runtime_info(self):
|
||||||
|
|||||||
@@ -9,7 +9,7 @@ class V2MainDataAPI(apigroup.APIGroup):
|
|||||||
|
|
||||||
def __init__(self, prefix: str, ap: app.Application):
|
def __init__(self, prefix: str, ap: app.Application):
|
||||||
self.ap = ap
|
self.ap = ap
|
||||||
super().__init__(prefix+"/main", ap)
|
super().__init__(prefix + '/main', ap)
|
||||||
|
|
||||||
async def do(self, *args, **kwargs):
|
async def do(self, *args, **kwargs):
|
||||||
if not self.ap.instance_config.data['telemetry']['report']:
|
if not self.ap.instance_config.data['telemetry']['report']:
|
||||||
@@ -25,17 +25,17 @@ class V2MainDataAPI(apigroup.APIGroup):
|
|||||||
):
|
):
|
||||||
"""提交更新记录"""
|
"""提交更新记录"""
|
||||||
return await self.do(
|
return await self.do(
|
||||||
"POST",
|
'POST',
|
||||||
"/update",
|
'/update',
|
||||||
data={
|
data={
|
||||||
"basic": self.basic_info(),
|
'basic': self.basic_info(),
|
||||||
"update_info": {
|
'update_info': {
|
||||||
"spent_seconds": spent_seconds,
|
'spent_seconds': spent_seconds,
|
||||||
"infer_reason": infer_reason,
|
'infer_reason': infer_reason,
|
||||||
"old_version": old_version,
|
'old_version': old_version,
|
||||||
"new_version": new_version,
|
'new_version': new_version,
|
||||||
}
|
},
|
||||||
}
|
},
|
||||||
)
|
)
|
||||||
|
|
||||||
async def post_announcement_showed(
|
async def post_announcement_showed(
|
||||||
@@ -44,12 +44,12 @@ class V2MainDataAPI(apigroup.APIGroup):
|
|||||||
):
|
):
|
||||||
"""提交公告已阅"""
|
"""提交公告已阅"""
|
||||||
return await self.do(
|
return await self.do(
|
||||||
"POST",
|
'POST',
|
||||||
"/announcement",
|
'/announcement',
|
||||||
data={
|
data={
|
||||||
"basic": self.basic_info(),
|
'basic': self.basic_info(),
|
||||||
"announcement_info": {
|
'announcement_info': {
|
||||||
"ids": ids,
|
'ids': ids,
|
||||||
}
|
},
|
||||||
}
|
},
|
||||||
)
|
)
|
||||||
|
|||||||
@@ -9,39 +9,33 @@ class V2PluginDataAPI(apigroup.APIGroup):
|
|||||||
|
|
||||||
def __init__(self, prefix: str, ap: app.Application):
|
def __init__(self, prefix: str, ap: app.Application):
|
||||||
self.ap = ap
|
self.ap = ap
|
||||||
super().__init__(prefix+"/plugin", ap)
|
super().__init__(prefix + '/plugin', ap)
|
||||||
|
|
||||||
async def do(self, *args, **kwargs):
|
async def do(self, *args, **kwargs):
|
||||||
if not self.ap.instance_config.data['telemetry']['report']:
|
if not self.ap.instance_config.data['telemetry']['report']:
|
||||||
return None
|
return None
|
||||||
return await super().do(*args, **kwargs)
|
return await super().do(*args, **kwargs)
|
||||||
|
|
||||||
async def post_install_record(
|
async def post_install_record(self, plugin: dict):
|
||||||
self,
|
|
||||||
plugin: dict
|
|
||||||
):
|
|
||||||
"""提交插件安装记录"""
|
"""提交插件安装记录"""
|
||||||
return await self.do(
|
return await self.do(
|
||||||
"POST",
|
'POST',
|
||||||
"/install",
|
'/install',
|
||||||
data={
|
data={
|
||||||
"basic": self.basic_info(),
|
'basic': self.basic_info(),
|
||||||
"plugin": plugin,
|
'plugin': plugin,
|
||||||
}
|
},
|
||||||
)
|
)
|
||||||
|
|
||||||
async def post_remove_record(
|
async def post_remove_record(self, plugin: dict):
|
||||||
self,
|
|
||||||
plugin: dict
|
|
||||||
):
|
|
||||||
"""提交插件卸载记录"""
|
"""提交插件卸载记录"""
|
||||||
return await self.do(
|
return await self.do(
|
||||||
"POST",
|
'POST',
|
||||||
"/remove",
|
'/remove',
|
||||||
data={
|
data={
|
||||||
"basic": self.basic_info(),
|
'basic': self.basic_info(),
|
||||||
"plugin": plugin,
|
'plugin': plugin,
|
||||||
}
|
},
|
||||||
)
|
)
|
||||||
|
|
||||||
async def post_update_record(
|
async def post_update_record(
|
||||||
@@ -52,14 +46,14 @@ class V2PluginDataAPI(apigroup.APIGroup):
|
|||||||
):
|
):
|
||||||
"""提交插件更新记录"""
|
"""提交插件更新记录"""
|
||||||
return await self.do(
|
return await self.do(
|
||||||
"POST",
|
'POST',
|
||||||
"/update",
|
'/update',
|
||||||
data={
|
data={
|
||||||
"basic": self.basic_info(),
|
'basic': self.basic_info(),
|
||||||
"plugin": plugin,
|
'plugin': plugin,
|
||||||
"update_info": {
|
'update_info': {
|
||||||
"old_version": old_version,
|
'old_version': old_version,
|
||||||
"new_version": new_version,
|
'new_version': new_version,
|
||||||
}
|
},
|
||||||
}
|
},
|
||||||
)
|
)
|
||||||
|
|||||||
@@ -9,7 +9,7 @@ class V2UsageDataAPI(apigroup.APIGroup):
|
|||||||
|
|
||||||
def __init__(self, prefix: str, ap: app.Application):
|
def __init__(self, prefix: str, ap: app.Application):
|
||||||
self.ap = ap
|
self.ap = ap
|
||||||
super().__init__(prefix+"/usage", ap)
|
super().__init__(prefix + '/usage', ap)
|
||||||
|
|
||||||
async def do(self, *args, **kwargs):
|
async def do(self, *args, **kwargs):
|
||||||
if not self.ap.instance_config.data['telemetry']['report']:
|
if not self.ap.instance_config.data['telemetry']['report']:
|
||||||
@@ -28,23 +28,23 @@ class V2UsageDataAPI(apigroup.APIGroup):
|
|||||||
):
|
):
|
||||||
"""提交请求记录"""
|
"""提交请求记录"""
|
||||||
return await self.do(
|
return await self.do(
|
||||||
"POST",
|
'POST',
|
||||||
"/query",
|
'/query',
|
||||||
data={
|
data={
|
||||||
"basic": self.basic_info(),
|
'basic': self.basic_info(),
|
||||||
"runtime": self.runtime_info(),
|
'runtime': self.runtime_info(),
|
||||||
"session_info": {
|
'session_info': {
|
||||||
"type": session_type,
|
'type': session_type,
|
||||||
"id": session_id,
|
'id': session_id,
|
||||||
},
|
},
|
||||||
"query_info": {
|
'query_info': {
|
||||||
"ability_provider": query_ability_provider,
|
'ability_provider': query_ability_provider,
|
||||||
"usage": usage,
|
'usage': usage,
|
||||||
"model_name": model_name,
|
'model_name': model_name,
|
||||||
"response_seconds": response_seconds,
|
'response_seconds': response_seconds,
|
||||||
"retry_times": retry_times,
|
'retry_times': retry_times,
|
||||||
}
|
},
|
||||||
}
|
},
|
||||||
)
|
)
|
||||||
|
|
||||||
async def post_event_record(
|
async def post_event_record(
|
||||||
@@ -54,16 +54,16 @@ class V2UsageDataAPI(apigroup.APIGroup):
|
|||||||
):
|
):
|
||||||
"""提交事件触发记录"""
|
"""提交事件触发记录"""
|
||||||
return await self.do(
|
return await self.do(
|
||||||
"POST",
|
'POST',
|
||||||
"/event",
|
'/event',
|
||||||
data={
|
data={
|
||||||
"basic": self.basic_info(),
|
'basic': self.basic_info(),
|
||||||
"runtime": self.runtime_info(),
|
'runtime': self.runtime_info(),
|
||||||
"plugins": plugins,
|
'plugins': plugins,
|
||||||
"event_info": {
|
'event_info': {
|
||||||
"name": event_name,
|
'name': event_name,
|
||||||
}
|
},
|
||||||
}
|
},
|
||||||
)
|
)
|
||||||
|
|
||||||
async def post_function_record(
|
async def post_function_record(
|
||||||
@@ -74,15 +74,14 @@ class V2UsageDataAPI(apigroup.APIGroup):
|
|||||||
):
|
):
|
||||||
"""提交内容函数使用记录"""
|
"""提交内容函数使用记录"""
|
||||||
return await self.do(
|
return await self.do(
|
||||||
"POST",
|
'POST',
|
||||||
"/function",
|
'/function',
|
||||||
data={
|
data={
|
||||||
"basic": self.basic_info(),
|
'basic': self.basic_info(),
|
||||||
"plugin": plugin,
|
'plugin': plugin,
|
||||||
"function_info": {
|
'function_info': {
|
||||||
"name": function_name,
|
'name': function_name,
|
||||||
"description": function_description,
|
'description': function_description,
|
||||||
}
|
},
|
||||||
}
|
},
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|||||||
@@ -21,10 +21,16 @@ class V2CenterAPI:
|
|||||||
plugin: plugin.V2PluginDataAPI = None
|
plugin: plugin.V2PluginDataAPI = None
|
||||||
"""插件 API 组"""
|
"""插件 API 组"""
|
||||||
|
|
||||||
def __init__(self, ap: app.Application, backend_url: str, basic_info: dict = None, runtime_info: dict = None):
|
def __init__(
|
||||||
|
self,
|
||||||
|
ap: app.Application,
|
||||||
|
backend_url: str,
|
||||||
|
basic_info: dict = None,
|
||||||
|
runtime_info: dict = None,
|
||||||
|
):
|
||||||
"""初始化"""
|
"""初始化"""
|
||||||
|
|
||||||
logging.debug("basic_info: %s, runtime_info: %s", basic_info, runtime_info)
|
logging.debug('basic_info: %s, runtime_info: %s', basic_info, runtime_info)
|
||||||
|
|
||||||
apigroup.APIGroup._basic_info = basic_info
|
apigroup.APIGroup._basic_info = basic_info
|
||||||
apigroup.APIGroup._runtime_info = runtime_info
|
apigroup.APIGroup._runtime_info = runtime_info
|
||||||
@@ -32,4 +38,3 @@ class V2CenterAPI:
|
|||||||
self.main = main.V2MainDataAPI(backend_url, ap)
|
self.main = main.V2MainDataAPI(backend_url, ap)
|
||||||
self.usage = usage.V2UsageDataAPI(backend_url, ap)
|
self.usage = usage.V2UsageDataAPI(backend_url, ap)
|
||||||
self.plugin = plugin.V2PluginDataAPI(backend_url, ap)
|
self.plugin = plugin.V2PluginDataAPI(backend_url, ap)
|
||||||
|
|
||||||
|
|||||||
+16
-12
@@ -16,6 +16,7 @@ identifier = {
|
|||||||
HOST_ID_FILE = os.path.expanduser('~/.langbot/host_id.json')
|
HOST_ID_FILE = os.path.expanduser('~/.langbot/host_id.json')
|
||||||
INSTANCE_ID_FILE = 'data/labels/instance_id.json'
|
INSTANCE_ID_FILE = 'data/labels/instance_id.json'
|
||||||
|
|
||||||
|
|
||||||
def init():
|
def init():
|
||||||
global identifier
|
global identifier
|
||||||
|
|
||||||
@@ -23,14 +24,11 @@ def init():
|
|||||||
os.mkdir(os.path.expanduser('~/.langbot'))
|
os.mkdir(os.path.expanduser('~/.langbot'))
|
||||||
|
|
||||||
if not os.path.exists(HOST_ID_FILE):
|
if not os.path.exists(HOST_ID_FILE):
|
||||||
new_host_id = 'host_'+str(uuid.uuid4())
|
new_host_id = 'host_' + str(uuid.uuid4())
|
||||||
new_host_create_ts = int(time.time())
|
new_host_create_ts = int(time.time())
|
||||||
|
|
||||||
with open(HOST_ID_FILE, 'w') as f:
|
with open(HOST_ID_FILE, 'w') as f:
|
||||||
json.dump({
|
json.dump({'host_id': new_host_id, 'host_create_ts': new_host_create_ts}, f)
|
||||||
'host_id': new_host_id,
|
|
||||||
'host_create_ts': new_host_create_ts
|
|
||||||
}, f)
|
|
||||||
|
|
||||||
identifier['host_id'] = new_host_id
|
identifier['host_id'] = new_host_id
|
||||||
identifier['host_create_ts'] = new_host_create_ts
|
identifier['host_create_ts'] = new_host_create_ts
|
||||||
@@ -52,19 +50,24 @@ def init():
|
|||||||
with open(INSTANCE_ID_FILE, 'r') as f:
|
with open(INSTANCE_ID_FILE, 'r') as f:
|
||||||
instance_id = json.load(f)
|
instance_id = json.load(f)
|
||||||
|
|
||||||
if instance_id['host_id'] != identifier['host_id']: # 如果实例 id 不是当前主机的,删除
|
if (
|
||||||
|
instance_id['host_id'] != identifier['host_id']
|
||||||
|
): # 如果实例 id 不是当前主机的,删除
|
||||||
os.remove(INSTANCE_ID_FILE)
|
os.remove(INSTANCE_ID_FILE)
|
||||||
|
|
||||||
if not os.path.exists(INSTANCE_ID_FILE):
|
if not os.path.exists(INSTANCE_ID_FILE):
|
||||||
new_instance_id = 'instance_'+str(uuid.uuid4())
|
new_instance_id = 'instance_' + str(uuid.uuid4())
|
||||||
new_instance_create_ts = int(time.time())
|
new_instance_create_ts = int(time.time())
|
||||||
|
|
||||||
with open(INSTANCE_ID_FILE, 'w') as f:
|
with open(INSTANCE_ID_FILE, 'w') as f:
|
||||||
json.dump({
|
json.dump(
|
||||||
'host_id': identifier['host_id'],
|
{
|
||||||
'instance_id': new_instance_id,
|
'host_id': identifier['host_id'],
|
||||||
'instance_create_ts': new_instance_create_ts
|
'instance_id': new_instance_id,
|
||||||
}, f)
|
'instance_create_ts': new_instance_create_ts,
|
||||||
|
},
|
||||||
|
f,
|
||||||
|
)
|
||||||
|
|
||||||
identifier['instance_id'] = new_instance_id
|
identifier['instance_id'] = new_instance_id
|
||||||
identifier['instance_create_ts'] = new_instance_create_ts
|
identifier['instance_create_ts'] = new_instance_create_ts
|
||||||
@@ -80,6 +83,7 @@ def init():
|
|||||||
identifier['instance_id'] = loaded_instance_id
|
identifier['instance_id'] = loaded_instance_id
|
||||||
identifier['instance_create_ts'] = loaded_instance_create_ts
|
identifier['instance_create_ts'] = loaded_instance_create_ts
|
||||||
|
|
||||||
|
|
||||||
def print_out():
|
def print_out():
|
||||||
global identifier
|
global identifier
|
||||||
print(identifier)
|
print(identifier)
|
||||||
|
|||||||
+28
-29
@@ -3,17 +3,17 @@ from __future__ import annotations
|
|||||||
import typing
|
import typing
|
||||||
|
|
||||||
from ..core import app, entities as core_entities
|
from ..core import app, entities as core_entities
|
||||||
from ..provider import entities as llm_entities
|
|
||||||
from . import entities, operator, errors
|
from . import entities, operator, errors
|
||||||
from ..config import manager as cfg_mgr
|
from ..utils import importutil
|
||||||
|
|
||||||
# 引入所有算子以便注册
|
# 引入所有算子以便注册
|
||||||
from .operators import func, plugin, reset, list as list_cmd, last, next, delc, resend, prompt, cmd, help, version, update, ollama, model
|
from . import operators
|
||||||
|
|
||||||
|
importutil.import_modules_in_pkg(operators)
|
||||||
|
|
||||||
|
|
||||||
class CommandManager:
|
class CommandManager:
|
||||||
"""命令管理器
|
"""命令管理器"""
|
||||||
"""
|
|
||||||
|
|
||||||
ap: app.Application
|
ap: app.Application
|
||||||
|
|
||||||
@@ -26,7 +26,6 @@ class CommandManager:
|
|||||||
self.ap = ap
|
self.ap = ap
|
||||||
|
|
||||||
async def initialize(self):
|
async def initialize(self):
|
||||||
|
|
||||||
# 设置各个类的路径
|
# 设置各个类的路径
|
||||||
def set_path(cls: operator.CommandOperator, ancestors: list[str]):
|
def set_path(cls: operator.CommandOperator, ancestors: list[str]):
|
||||||
cls.path = '.'.join(ancestors + [cls.name])
|
cls.path = '.'.join(ancestors + [cls.name])
|
||||||
@@ -41,14 +40,18 @@ class CommandManager:
|
|||||||
# 应用命令权限配置
|
# 应用命令权限配置
|
||||||
for cls in operator.preregistered_operators:
|
for cls in operator.preregistered_operators:
|
||||||
if cls.path in self.ap.instance_config.data['command']['privilege']:
|
if cls.path in self.ap.instance_config.data['command']['privilege']:
|
||||||
cls.lowest_privilege = self.ap.instance_config.data['command']['privilege'][cls.path]
|
cls.lowest_privilege = self.ap.instance_config.data['command'][
|
||||||
|
'privilege'
|
||||||
|
][cls.path]
|
||||||
|
|
||||||
# 实例化所有类
|
# 实例化所有类
|
||||||
self.cmd_list = [cls(self.ap) for cls in operator.preregistered_operators]
|
self.cmd_list = [cls(self.ap) for cls in operator.preregistered_operators]
|
||||||
|
|
||||||
# 设置所有类的子节点
|
# 设置所有类的子节点
|
||||||
for cmd in self.cmd_list:
|
for cmd in self.cmd_list:
|
||||||
cmd.children = [child for child in self.cmd_list if child.parent_class == cmd.__class__]
|
cmd.children = [
|
||||||
|
child for child in self.cmd_list if child.parent_class == cmd.__class__
|
||||||
|
]
|
||||||
|
|
||||||
# 初始化所有类
|
# 初始化所有类
|
||||||
for cmd in self.cmd_list:
|
for cmd in self.cmd_list:
|
||||||
@@ -58,27 +61,25 @@ class CommandManager:
|
|||||||
self,
|
self,
|
||||||
context: entities.ExecuteContext,
|
context: entities.ExecuteContext,
|
||||||
operator_list: list[operator.CommandOperator],
|
operator_list: list[operator.CommandOperator],
|
||||||
operator: operator.CommandOperator = None
|
operator: operator.CommandOperator = None,
|
||||||
) -> typing.AsyncGenerator[entities.CommandReturn, None]:
|
) -> typing.AsyncGenerator[entities.CommandReturn, None]:
|
||||||
"""执行命令
|
"""执行命令"""
|
||||||
"""
|
|
||||||
|
|
||||||
found = False
|
found = False
|
||||||
if len(context.crt_params) > 0: # 查找下一个参数是否对应此节点的某个子节点名
|
if len(context.crt_params) > 0: # 查找下一个参数是否对应此节点的某个子节点名
|
||||||
for oper in operator_list:
|
for oper in operator_list:
|
||||||
if (context.crt_params[0] == oper.name \
|
if (
|
||||||
or context.crt_params[0] in oper.alias) \
|
context.crt_params[0] == oper.name
|
||||||
and (oper.parent_class is None or oper.parent_class == operator.__class__):
|
or context.crt_params[0] in oper.alias
|
||||||
|
) and (
|
||||||
|
oper.parent_class is None or oper.parent_class == operator.__class__
|
||||||
|
):
|
||||||
found = True
|
found = True
|
||||||
|
|
||||||
context.crt_command = context.crt_params[0]
|
context.crt_command = context.crt_params[0]
|
||||||
context.crt_params = context.crt_params[1:]
|
context.crt_params = context.crt_params[1:]
|
||||||
|
|
||||||
async for ret in self._execute(
|
async for ret in self._execute(context, oper.children, oper):
|
||||||
context,
|
|
||||||
oper.children,
|
|
||||||
oper
|
|
||||||
):
|
|
||||||
yield ret
|
yield ret
|
||||||
break
|
break
|
||||||
|
|
||||||
@@ -96,19 +97,20 @@ class CommandManager:
|
|||||||
async for ret in operator.execute(context):
|
async for ret in operator.execute(context):
|
||||||
yield ret
|
yield ret
|
||||||
|
|
||||||
|
|
||||||
async def execute(
|
async def execute(
|
||||||
self,
|
self,
|
||||||
command_text: str,
|
command_text: str,
|
||||||
query: core_entities.Query,
|
query: core_entities.Query,
|
||||||
session: core_entities.Session
|
session: core_entities.Session,
|
||||||
) -> typing.AsyncGenerator[entities.CommandReturn, None]:
|
) -> typing.AsyncGenerator[entities.CommandReturn, None]:
|
||||||
"""执行命令
|
"""执行命令"""
|
||||||
"""
|
|
||||||
|
|
||||||
privilege = 1
|
privilege = 1
|
||||||
|
|
||||||
if f'{query.launcher_type.value}_{query.launcher_id}' in self.ap.instance_config.data['admins']:
|
if (
|
||||||
|
f'{query.launcher_type.value}_{query.launcher_id}'
|
||||||
|
in self.ap.instance_config.data['admins']
|
||||||
|
):
|
||||||
privilege = 2
|
privilege = 2
|
||||||
|
|
||||||
ctx = entities.ExecuteContext(
|
ctx = entities.ExecuteContext(
|
||||||
@@ -119,11 +121,8 @@ class CommandManager:
|
|||||||
crt_command='',
|
crt_command='',
|
||||||
params=command_text.split(' '),
|
params=command_text.split(' '),
|
||||||
crt_params=command_text.split(' '),
|
crt_params=command_text.split(' '),
|
||||||
privilege=privilege
|
privilege=privilege,
|
||||||
)
|
)
|
||||||
|
|
||||||
async for ret in self._execute(
|
async for ret in self._execute(ctx, self.cmd_list):
|
||||||
ctx,
|
|
||||||
self.cmd_list
|
|
||||||
):
|
|
||||||
yield ret
|
yield ret
|
||||||
|
|||||||
@@ -4,14 +4,13 @@ import typing
|
|||||||
|
|
||||||
import pydantic.v1 as pydantic
|
import pydantic.v1 as pydantic
|
||||||
|
|
||||||
from ..core import app, entities as core_entities
|
from ..core import entities as core_entities
|
||||||
from . import errors, operator
|
from . import errors
|
||||||
from ..platform.types import message as platform_message
|
from ..platform.types import message as platform_message
|
||||||
|
|
||||||
|
|
||||||
class CommandReturn(pydantic.BaseModel):
|
class CommandReturn(pydantic.BaseModel):
|
||||||
"""命令返回值
|
"""命令返回值"""
|
||||||
"""
|
|
||||||
|
|
||||||
text: typing.Optional[str] = None
|
text: typing.Optional[str] = None
|
||||||
"""文本
|
"""文本
|
||||||
@@ -24,7 +23,7 @@ class CommandReturn(pydantic.BaseModel):
|
|||||||
"""图片链接
|
"""图片链接
|
||||||
"""
|
"""
|
||||||
|
|
||||||
error: typing.Optional[errors.CommandError]= None
|
error: typing.Optional[errors.CommandError] = None
|
||||||
"""错误
|
"""错误
|
||||||
"""
|
"""
|
||||||
|
|
||||||
@@ -33,8 +32,7 @@ class CommandReturn(pydantic.BaseModel):
|
|||||||
|
|
||||||
|
|
||||||
class ExecuteContext(pydantic.BaseModel):
|
class ExecuteContext(pydantic.BaseModel):
|
||||||
"""单次命令执行上下文
|
"""单次命令执行上下文"""
|
||||||
"""
|
|
||||||
|
|
||||||
query: core_entities.Query
|
query: core_entities.Query
|
||||||
"""本次消息的请求对象"""
|
"""本次消息的请求对象"""
|
||||||
|
|||||||
+4
-11
@@ -1,7 +1,4 @@
|
|||||||
|
|
||||||
|
|
||||||
class CommandError(Exception):
|
class CommandError(Exception):
|
||||||
|
|
||||||
def __init__(self, message: str = None):
|
def __init__(self, message: str = None):
|
||||||
self.message = message
|
self.message = message
|
||||||
|
|
||||||
@@ -10,24 +7,20 @@ class CommandError(Exception):
|
|||||||
|
|
||||||
|
|
||||||
class CommandNotFoundError(CommandError):
|
class CommandNotFoundError(CommandError):
|
||||||
|
|
||||||
def __init__(self, message: str = None):
|
def __init__(self, message: str = None):
|
||||||
super().__init__("未知命令: "+message)
|
super().__init__('未知命令: ' + message)
|
||||||
|
|
||||||
|
|
||||||
class CommandPrivilegeError(CommandError):
|
class CommandPrivilegeError(CommandError):
|
||||||
|
|
||||||
def __init__(self, message: str = None):
|
def __init__(self, message: str = None):
|
||||||
super().__init__("权限不足: "+message)
|
super().__init__('权限不足: ' + message)
|
||||||
|
|
||||||
|
|
||||||
class ParamNotEnoughError(CommandError):
|
class ParamNotEnoughError(CommandError):
|
||||||
|
|
||||||
def __init__(self, message: str = None):
|
def __init__(self, message: str = None):
|
||||||
super().__init__("参数不足: "+message)
|
super().__init__('参数不足: ' + message)
|
||||||
|
|
||||||
|
|
||||||
class CommandOperationError(CommandError):
|
class CommandOperationError(CommandError):
|
||||||
|
|
||||||
def __init__(self, message: str = None):
|
def __init__(self, message: str = None):
|
||||||
super().__init__("操作失败: "+message)
|
super().__init__('操作失败: ' + message)
|
||||||
|
|||||||
@@ -3,7 +3,7 @@ from __future__ import annotations
|
|||||||
import typing
|
import typing
|
||||||
import abc
|
import abc
|
||||||
|
|
||||||
from ..core import app, entities as core_entities
|
from ..core import app
|
||||||
from . import entities
|
from . import entities
|
||||||
|
|
||||||
|
|
||||||
@@ -13,11 +13,11 @@ preregistered_operators: list[typing.Type[CommandOperator]] = []
|
|||||||
|
|
||||||
def operator_class(
|
def operator_class(
|
||||||
name: str,
|
name: str,
|
||||||
help: str = "",
|
help: str = '',
|
||||||
usage: str = None,
|
usage: str = None,
|
||||||
alias: list[str] = [],
|
alias: list[str] = [],
|
||||||
privilege: int=1, # 1为普通用户,2为管理员
|
privilege: int = 1, # 1为普通用户,2为管理员
|
||||||
parent_class: typing.Type[CommandOperator] = None
|
parent_class: typing.Type[CommandOperator] = None,
|
||||||
) -> typing.Callable[[typing.Type[CommandOperator]], typing.Type[CommandOperator]]:
|
) -> typing.Callable[[typing.Type[CommandOperator]], typing.Type[CommandOperator]]:
|
||||||
"""命令类装饰器
|
"""命令类装饰器
|
||||||
|
|
||||||
@@ -96,8 +96,7 @@ class CommandOperator(metaclass=abc.ABCMeta):
|
|||||||
|
|
||||||
@abc.abstractmethod
|
@abc.abstractmethod
|
||||||
async def execute(
|
async def execute(
|
||||||
self,
|
self, context: entities.ExecuteContext
|
||||||
context: entities.ExecuteContext
|
|
||||||
) -> typing.AsyncGenerator[entities.CommandReturn, None]:
|
) -> typing.AsyncGenerator[entities.CommandReturn, None]:
|
||||||
"""实现此方法以执行命令
|
"""实现此方法以执行命令
|
||||||
|
|
||||||
|
|||||||
@@ -2,32 +2,25 @@ from __future__ import annotations
|
|||||||
|
|
||||||
import typing
|
import typing
|
||||||
|
|
||||||
from .. import operator, entities, cmdmgr, errors
|
from .. import operator, entities, errors
|
||||||
|
|
||||||
|
|
||||||
@operator.operator_class(
|
@operator.operator_class(name='cmd', help='显示命令列表', usage='!cmd\n!cmd <命令名称>')
|
||||||
name="cmd",
|
|
||||||
help='显示命令列表',
|
|
||||||
usage='!cmd\n!cmd <命令名称>'
|
|
||||||
)
|
|
||||||
class CmdOperator(operator.CommandOperator):
|
class CmdOperator(operator.CommandOperator):
|
||||||
"""命令列表
|
"""命令列表"""
|
||||||
"""
|
|
||||||
|
|
||||||
async def execute(
|
async def execute(
|
||||||
self,
|
self, context: entities.ExecuteContext
|
||||||
context: entities.ExecuteContext
|
|
||||||
) -> typing.AsyncGenerator[entities.CommandReturn, None]:
|
) -> typing.AsyncGenerator[entities.CommandReturn, None]:
|
||||||
"""执行
|
"""执行"""
|
||||||
"""
|
|
||||||
if len(context.crt_params) == 0:
|
if len(context.crt_params) == 0:
|
||||||
reply_str = "当前所有命令: \n\n"
|
reply_str = '当前所有命令: \n\n'
|
||||||
|
|
||||||
for cmd in self.ap.cmd_mgr.cmd_list:
|
for cmd in self.ap.cmd_mgr.cmd_list:
|
||||||
if cmd.parent_class is None:
|
if cmd.parent_class is None:
|
||||||
reply_str += f"{cmd.name}: {cmd.help}\n"
|
reply_str += f'{cmd.name}: {cmd.help}\n'
|
||||||
|
|
||||||
reply_str += "\n使用 !cmd <命令名称> 查看命令的详细帮助"
|
reply_str += '\n使用 !cmd <命令名称> 查看命令的详细帮助'
|
||||||
|
|
||||||
yield entities.CommandReturn(text=reply_str.strip())
|
yield entities.CommandReturn(text=reply_str.strip())
|
||||||
|
|
||||||
@@ -37,14 +30,18 @@ class CmdOperator(operator.CommandOperator):
|
|||||||
cmd = None
|
cmd = None
|
||||||
|
|
||||||
for _cmd in self.ap.cmd_mgr.cmd_list:
|
for _cmd in self.ap.cmd_mgr.cmd_list:
|
||||||
if (cmd_name == _cmd.name or cmd_name in _cmd.alias) and (_cmd.parent_class is None):
|
if (cmd_name == _cmd.name or cmd_name in _cmd.alias) and (
|
||||||
|
_cmd.parent_class is None
|
||||||
|
):
|
||||||
cmd = _cmd
|
cmd = _cmd
|
||||||
break
|
break
|
||||||
|
|
||||||
if cmd is None:
|
if cmd is None:
|
||||||
yield entities.CommandReturn(error=errors.CommandNotFoundError(cmd_name))
|
yield entities.CommandReturn(
|
||||||
|
error=errors.CommandNotFoundError(cmd_name)
|
||||||
|
)
|
||||||
else:
|
else:
|
||||||
reply_str = f"{cmd.name}: {cmd.help}\n\n"
|
reply_str = f'{cmd.name}: {cmd.help}\n\n'
|
||||||
reply_str += f"使用方法: \n{cmd.usage}"
|
reply_str += f'使用方法: \n{cmd.usage}'
|
||||||
|
|
||||||
yield entities.CommandReturn(text=reply_str.strip())
|
yield entities.CommandReturn(text=reply_str.strip())
|
||||||
|
|||||||
@@ -1,62 +1,60 @@
|
|||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
import typing
|
import typing
|
||||||
import datetime
|
|
||||||
|
|
||||||
from .. import operator, entities, cmdmgr, errors
|
from .. import operator, entities, errors
|
||||||
|
|
||||||
|
|
||||||
@operator.operator_class(
|
@operator.operator_class(
|
||||||
name="del",
|
name='del', help='删除当前会话的历史记录', usage='!del <序号>\n!del all'
|
||||||
help="删除当前会话的历史记录",
|
|
||||||
usage='!del <序号>\n!del all'
|
|
||||||
)
|
)
|
||||||
class DelOperator(operator.CommandOperator):
|
class DelOperator(operator.CommandOperator):
|
||||||
|
|
||||||
async def execute(
|
async def execute(
|
||||||
self,
|
self, context: entities.ExecuteContext
|
||||||
context: entities.ExecuteContext
|
|
||||||
) -> typing.AsyncGenerator[entities.CommandReturn, None]:
|
) -> typing.AsyncGenerator[entities.CommandReturn, None]:
|
||||||
|
|
||||||
if context.session.conversations:
|
if context.session.conversations:
|
||||||
delete_index = 0
|
delete_index = 0
|
||||||
if len(context.crt_params) > 0:
|
if len(context.crt_params) > 0:
|
||||||
try:
|
try:
|
||||||
delete_index = int(context.crt_params[0])
|
delete_index = int(context.crt_params[0])
|
||||||
except:
|
except Exception:
|
||||||
yield entities.CommandReturn(error=errors.CommandOperationError('索引必须是整数'))
|
yield entities.CommandReturn(
|
||||||
|
error=errors.CommandOperationError('索引必须是整数')
|
||||||
|
)
|
||||||
return
|
return
|
||||||
|
|
||||||
if delete_index < 0 or delete_index >= len(context.session.conversations):
|
if delete_index < 0 or delete_index >= len(context.session.conversations):
|
||||||
yield entities.CommandReturn(error=errors.CommandOperationError('索引超出范围'))
|
yield entities.CommandReturn(
|
||||||
|
error=errors.CommandOperationError('索引超出范围')
|
||||||
|
)
|
||||||
return
|
return
|
||||||
|
|
||||||
# 倒序
|
# 倒序
|
||||||
to_delete_index = len(context.session.conversations)-1-delete_index
|
to_delete_index = len(context.session.conversations) - 1 - delete_index
|
||||||
|
|
||||||
if context.session.conversations[to_delete_index] == context.session.using_conversation:
|
if (
|
||||||
|
context.session.conversations[to_delete_index]
|
||||||
|
== context.session.using_conversation
|
||||||
|
):
|
||||||
context.session.using_conversation = None
|
context.session.using_conversation = None
|
||||||
|
|
||||||
del context.session.conversations[to_delete_index]
|
del context.session.conversations[to_delete_index]
|
||||||
|
|
||||||
yield entities.CommandReturn(text=f"已删除对话: {delete_index}")
|
yield entities.CommandReturn(text=f'已删除对话: {delete_index}')
|
||||||
else:
|
else:
|
||||||
yield entities.CommandReturn(error=errors.CommandOperationError('当前没有对话'))
|
yield entities.CommandReturn(
|
||||||
|
error=errors.CommandOperationError('当前没有对话')
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
@operator.operator_class(
|
@operator.operator_class(
|
||||||
name="all",
|
name='all', help='删除此会话的所有历史记录', parent_class=DelOperator
|
||||||
help="删除此会话的所有历史记录",
|
|
||||||
parent_class=DelOperator
|
|
||||||
)
|
)
|
||||||
class DelAllOperator(operator.CommandOperator):
|
class DelAllOperator(operator.CommandOperator):
|
||||||
|
|
||||||
async def execute(
|
async def execute(
|
||||||
self,
|
self, context: entities.ExecuteContext
|
||||||
context: entities.ExecuteContext
|
|
||||||
) -> typing.AsyncGenerator[entities.CommandReturn, None]:
|
) -> typing.AsyncGenerator[entities.CommandReturn, None]:
|
||||||
|
|
||||||
context.session.conversations = []
|
context.session.conversations = []
|
||||||
context.session.using_conversation = None
|
context.session.using_conversation = None
|
||||||
|
|
||||||
yield entities.CommandReturn(text="已删除所有对话")
|
yield entities.CommandReturn(text='已删除所有对话')
|
||||||
|
|||||||
@@ -1,16 +1,15 @@
|
|||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
from typing import AsyncGenerator
|
from typing import AsyncGenerator
|
||||||
|
|
||||||
from .. import operator, entities, cmdmgr
|
from .. import operator, entities
|
||||||
from ...plugin import context as plugin_context
|
|
||||||
|
|
||||||
|
|
||||||
@operator.operator_class(name="func", help="查看所有已注册的内容函数", usage='!func')
|
@operator.operator_class(name='func', help='查看所有已注册的内容函数', usage='!func')
|
||||||
class FuncOperator(operator.CommandOperator):
|
class FuncOperator(operator.CommandOperator):
|
||||||
async def execute(
|
async def execute(
|
||||||
self, context: entities.ExecuteContext
|
self, context: entities.ExecuteContext
|
||||||
) -> AsyncGenerator[entities.CommandReturn, None]:
|
) -> AsyncGenerator[entities.CommandReturn, None]:
|
||||||
reply_str = "当前已启用的内容函数: \n\n"
|
reply_str = '当前已启用的内容函数: \n\n'
|
||||||
|
|
||||||
index = 1
|
index = 1
|
||||||
|
|
||||||
@@ -19,7 +18,7 @@ class FuncOperator(operator.CommandOperator):
|
|||||||
)
|
)
|
||||||
|
|
||||||
for func in all_functions:
|
for func in all_functions:
|
||||||
reply_str += "{}. {}:\n{}\n\n".format(
|
reply_str += '{}. {}:\n{}\n\n'.format(
|
||||||
index,
|
index,
|
||||||
func.name,
|
func.name,
|
||||||
func.description,
|
func.description,
|
||||||
|
|||||||
@@ -2,19 +2,13 @@ from __future__ import annotations
|
|||||||
|
|
||||||
import typing
|
import typing
|
||||||
|
|
||||||
from .. import operator, entities, cmdmgr, errors
|
from .. import operator, entities
|
||||||
|
|
||||||
|
|
||||||
@operator.operator_class(
|
@operator.operator_class(name='help', help='显示帮助', usage='!help\n!help <命令名称>')
|
||||||
name='help',
|
|
||||||
help='显示帮助',
|
|
||||||
usage='!help\n!help <命令名称>'
|
|
||||||
)
|
|
||||||
class HelpOperator(operator.CommandOperator):
|
class HelpOperator(operator.CommandOperator):
|
||||||
|
|
||||||
async def execute(
|
async def execute(
|
||||||
self,
|
self, context: entities.ExecuteContext
|
||||||
context: entities.ExecuteContext
|
|
||||||
) -> typing.AsyncGenerator[entities.CommandReturn, None]:
|
) -> typing.AsyncGenerator[entities.CommandReturn, None]:
|
||||||
help = 'LangBot - 大语言模型原生即时通信机器人平台\n链接:https://langbot.app'
|
help = 'LangBot - 大语言模型原生即时通信机器人平台\n链接:https://langbot.app'
|
||||||
|
|
||||||
|
|||||||
@@ -1,36 +1,43 @@
|
|||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
import typing
|
import typing
|
||||||
import datetime
|
|
||||||
|
|
||||||
|
|
||||||
from .. import operator, entities, cmdmgr, errors
|
from .. import operator, entities, errors
|
||||||
|
|
||||||
|
|
||||||
@operator.operator_class(
|
@operator.operator_class(name='last', help='切换到前一个对话', usage='!last')
|
||||||
name="last",
|
|
||||||
help="切换到前一个对话",
|
|
||||||
usage='!last'
|
|
||||||
)
|
|
||||||
class LastOperator(operator.CommandOperator):
|
class LastOperator(operator.CommandOperator):
|
||||||
|
|
||||||
async def execute(
|
async def execute(
|
||||||
self,
|
self, context: entities.ExecuteContext
|
||||||
context: entities.ExecuteContext
|
|
||||||
) -> typing.AsyncGenerator[entities.CommandReturn, None]:
|
) -> typing.AsyncGenerator[entities.CommandReturn, None]:
|
||||||
|
|
||||||
if context.session.conversations:
|
if context.session.conversations:
|
||||||
# 找到当前会话的上一个会话
|
# 找到当前会话的上一个会话
|
||||||
for index in range(len(context.session.conversations)-1, -1, -1):
|
for index in range(len(context.session.conversations) - 1, -1, -1):
|
||||||
if context.session.conversations[index] == context.session.using_conversation:
|
if (
|
||||||
|
context.session.conversations[index]
|
||||||
|
== context.session.using_conversation
|
||||||
|
):
|
||||||
if index == 0:
|
if index == 0:
|
||||||
yield entities.CommandReturn(error=errors.CommandOperationError('已经是第一个对话了'))
|
yield entities.CommandReturn(
|
||||||
|
error=errors.CommandOperationError('已经是第一个对话了')
|
||||||
|
)
|
||||||
return
|
return
|
||||||
else:
|
else:
|
||||||
context.session.using_conversation = context.session.conversations[index-1]
|
context.session.using_conversation = (
|
||||||
time_str = context.session.using_conversation.create_time.strftime("%Y-%m-%d %H:%M:%S")
|
context.session.conversations[index - 1]
|
||||||
|
)
|
||||||
|
time_str = (
|
||||||
|
context.session.using_conversation.create_time.strftime(
|
||||||
|
'%Y-%m-%d %H:%M:%S'
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
yield entities.CommandReturn(text=f"已切换到上一个对话: {index} {time_str}: {context.session.using_conversation.messages[0].readable_str()}")
|
yield entities.CommandReturn(
|
||||||
|
text=f'已切换到上一个对话: {index} {time_str}: {context.session.using_conversation.messages[0].readable_str()}'
|
||||||
|
)
|
||||||
return
|
return
|
||||||
else:
|
else:
|
||||||
yield entities.CommandReturn(error=errors.CommandOperationError('当前没有对话'))
|
yield entities.CommandReturn(
|
||||||
|
error=errors.CommandOperationError('当前没有对话')
|
||||||
|
)
|
||||||
|
|||||||
@@ -1,30 +1,26 @@
|
|||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
import typing
|
import typing
|
||||||
import datetime
|
|
||||||
|
|
||||||
from .. import operator, entities, cmdmgr, errors
|
from .. import operator, entities, errors
|
||||||
|
|
||||||
|
|
||||||
@operator.operator_class(
|
@operator.operator_class(
|
||||||
name="list",
|
name='list', help='列出此会话中的所有历史对话', usage='!list\n!list <页码>'
|
||||||
help="列出此会话中的所有历史对话",
|
|
||||||
usage='!list\n!list <页码>'
|
|
||||||
)
|
)
|
||||||
class ListOperator(operator.CommandOperator):
|
class ListOperator(operator.CommandOperator):
|
||||||
|
|
||||||
async def execute(
|
async def execute(
|
||||||
self,
|
self, context: entities.ExecuteContext
|
||||||
context: entities.ExecuteContext
|
|
||||||
) -> typing.AsyncGenerator[entities.CommandReturn, None]:
|
) -> typing.AsyncGenerator[entities.CommandReturn, None]:
|
||||||
|
|
||||||
page = 0
|
page = 0
|
||||||
|
|
||||||
if len(context.crt_params) > 0:
|
if len(context.crt_params) > 0:
|
||||||
try:
|
try:
|
||||||
page = int(context.crt_params[0]-1)
|
page = int(context.crt_params[0] - 1)
|
||||||
except:
|
except Exception:
|
||||||
yield entities.CommandReturn(error=errors.CommandOperationError('页码应为整数'))
|
yield entities.CommandReturn(
|
||||||
|
error=errors.CommandOperationError('页码应为整数')
|
||||||
|
)
|
||||||
return
|
return
|
||||||
|
|
||||||
record_per_page = 10
|
record_per_page = 10
|
||||||
@@ -36,21 +32,21 @@ class ListOperator(operator.CommandOperator):
|
|||||||
using_conv_index = 0
|
using_conv_index = 0
|
||||||
|
|
||||||
for conv in context.session.conversations[::-1]:
|
for conv in context.session.conversations[::-1]:
|
||||||
time_str = conv.create_time.strftime("%Y-%m-%d %H:%M:%S")
|
time_str = conv.create_time.strftime('%Y-%m-%d %H:%M:%S')
|
||||||
|
|
||||||
if conv == context.session.using_conversation:
|
if conv == context.session.using_conversation:
|
||||||
using_conv_index = index
|
using_conv_index = index
|
||||||
|
|
||||||
if index >= page * record_per_page and index < (page + 1) * record_per_page:
|
if index >= page * record_per_page and index < (page + 1) * record_per_page:
|
||||||
content += f"{index} {time_str}: {conv.messages[0].readable_str() if len(conv.messages) > 0 else '无内容'}\n"
|
content += f'{index} {time_str}: {conv.messages[0].readable_str() if len(conv.messages) > 0 else "无内容"}\n'
|
||||||
index += 1
|
index += 1
|
||||||
|
|
||||||
if content == '':
|
if content == '':
|
||||||
content = '无'
|
content = '无'
|
||||||
else:
|
else:
|
||||||
if context.session.using_conversation is None:
|
if context.session.using_conversation is None:
|
||||||
content += "\n当前处于新会话"
|
content += '\n当前处于新会话'
|
||||||
else:
|
else:
|
||||||
content += f"\n当前会话: {using_conv_index} {context.session.using_conversation.create_time.strftime('%Y-%m-%d %H:%M:%S')}: {context.session.using_conversation.messages[0].readable_str() if len(context.session.using_conversation.messages) > 0 else '无内容'}"
|
content += f'\n当前会话: {using_conv_index} {context.session.using_conversation.create_time.strftime("%Y-%m-%d %H:%M:%S")}: {context.session.using_conversation.messages[0].readable_str() if len(context.session.using_conversation.messages) > 0 else "无内容"}'
|
||||||
|
|
||||||
yield entities.CommandReturn(text=f"第 {page + 1} 页 (时间倒序):\n{content}")
|
yield entities.CommandReturn(text=f'第 {page + 1} 页 (时间倒序):\n{content}')
|
||||||
|
|||||||
@@ -2,42 +2,44 @@ from __future__ import annotations
|
|||||||
|
|
||||||
import typing
|
import typing
|
||||||
|
|
||||||
from .. import operator, entities, cmdmgr, errors
|
from .. import operator, entities, errors
|
||||||
|
|
||||||
|
|
||||||
@operator.operator_class(
|
@operator.operator_class(
|
||||||
name="model",
|
name='model',
|
||||||
help='显示和切换模型列表',
|
help='显示和切换模型列表',
|
||||||
usage='!model\n!model show <模型名>\n!model set <模型名>',
|
usage='!model\n!model show <模型名>\n!model set <模型名>',
|
||||||
privilege=2
|
privilege=2,
|
||||||
)
|
)
|
||||||
class ModelOperator(operator.CommandOperator):
|
class ModelOperator(operator.CommandOperator):
|
||||||
"""Model命令"""
|
"""Model命令"""
|
||||||
|
|
||||||
async def execute(self, context: entities.ExecuteContext) -> typing.AsyncGenerator[entities.CommandReturn, None]:
|
async def execute(
|
||||||
|
self, context: entities.ExecuteContext
|
||||||
|
) -> typing.AsyncGenerator[entities.CommandReturn, None]:
|
||||||
content = '模型列表:\n'
|
content = '模型列表:\n'
|
||||||
|
|
||||||
model_list = self.ap.model_mgr.model_list
|
model_list = self.ap.model_mgr.model_list
|
||||||
|
|
||||||
for model in model_list:
|
for model in model_list:
|
||||||
content += f"\n名称: {model.name}\n"
|
content += f'\n名称: {model.name}\n'
|
||||||
content += f"请求器: {model.requester.name}\n"
|
content += f'请求器: {model.requester.name}\n'
|
||||||
|
|
||||||
content += f"\n当前对话使用模型: {context.query.use_model.name}\n"
|
content += f'\n当前对话使用模型: {context.query.use_model.name}\n'
|
||||||
content += f"新对话默认使用模型: {self.ap.provider_cfg.data.get('model')}\n"
|
content += f'新对话默认使用模型: {self.ap.provider_cfg.data.get("model")}\n'
|
||||||
|
|
||||||
yield entities.CommandReturn(text=content.strip())
|
yield entities.CommandReturn(text=content.strip())
|
||||||
|
|
||||||
|
|
||||||
@operator.operator_class(
|
@operator.operator_class(
|
||||||
name="show",
|
name='show', help='显示模型详情', privilege=2, parent_class=ModelOperator
|
||||||
help='显示模型详情',
|
|
||||||
privilege=2,
|
|
||||||
parent_class=ModelOperator
|
|
||||||
)
|
)
|
||||||
class ModelShowOperator(operator.CommandOperator):
|
class ModelShowOperator(operator.CommandOperator):
|
||||||
"""Model Show命令"""
|
"""Model Show命令"""
|
||||||
|
|
||||||
async def execute(self, context: entities.ExecuteContext) -> typing.AsyncGenerator[entities.CommandReturn, None]:
|
async def execute(
|
||||||
|
self, context: entities.ExecuteContext
|
||||||
|
) -> typing.AsyncGenerator[entities.CommandReturn, None]:
|
||||||
model_name = context.crt_params[0]
|
model_name = context.crt_params[0]
|
||||||
|
|
||||||
model = None
|
model = None
|
||||||
@@ -47,29 +49,31 @@ class ModelShowOperator(operator.CommandOperator):
|
|||||||
break
|
break
|
||||||
|
|
||||||
if model is None:
|
if model is None:
|
||||||
yield entities.CommandReturn(error=errors.CommandError(f"未找到模型 {model_name}"))
|
yield entities.CommandReturn(
|
||||||
|
error=errors.CommandError(f'未找到模型 {model_name}')
|
||||||
|
)
|
||||||
else:
|
else:
|
||||||
content = f"模型详情\n"
|
content = '模型详情\n'
|
||||||
content += f"名称: {model.name}\n"
|
content += f'名称: {model.name}\n'
|
||||||
if model.model_name is not None:
|
if model.model_name is not None:
|
||||||
content += f"请求模型名称: {model.model_name}\n"
|
content += f'请求模型名称: {model.model_name}\n'
|
||||||
content += f"请求器: {model.requester.name}\n"
|
content += f'请求器: {model.requester.name}\n'
|
||||||
content += f"密钥组: {model.token_mgr.name}\n"
|
content += f'密钥组: {model.token_mgr.name}\n'
|
||||||
content += f"支持视觉: {model.vision_supported}\n"
|
content += f'支持视觉: {model.vision_supported}\n'
|
||||||
content += f"支持工具: {model.tool_call_supported}\n"
|
content += f'支持工具: {model.tool_call_supported}\n'
|
||||||
|
|
||||||
yield entities.CommandReturn(text=content.strip())
|
yield entities.CommandReturn(text=content.strip())
|
||||||
|
|
||||||
|
|
||||||
@operator.operator_class(
|
@operator.operator_class(
|
||||||
name="set",
|
name='set', help='设置默认使用模型', privilege=2, parent_class=ModelOperator
|
||||||
help='设置默认使用模型',
|
|
||||||
privilege=2,
|
|
||||||
parent_class=ModelOperator
|
|
||||||
)
|
)
|
||||||
class ModelSetOperator(operator.CommandOperator):
|
class ModelSetOperator(operator.CommandOperator):
|
||||||
"""Model Set命令"""
|
"""Model Set命令"""
|
||||||
|
|
||||||
async def execute(self, context: entities.ExecuteContext) -> typing.AsyncGenerator[entities.CommandReturn, None]:
|
async def execute(
|
||||||
|
self, context: entities.ExecuteContext
|
||||||
|
) -> typing.AsyncGenerator[entities.CommandReturn, None]:
|
||||||
model_name = context.crt_params[0]
|
model_name = context.crt_params[0]
|
||||||
|
|
||||||
model = None
|
model = None
|
||||||
@@ -79,8 +83,12 @@ class ModelSetOperator(operator.CommandOperator):
|
|||||||
break
|
break
|
||||||
|
|
||||||
if model is None:
|
if model is None:
|
||||||
yield entities.CommandReturn(error=errors.CommandError(f"未找到模型 {model_name}"))
|
yield entities.CommandReturn(
|
||||||
|
error=errors.CommandError(f'未找到模型 {model_name}')
|
||||||
|
)
|
||||||
else:
|
else:
|
||||||
self.ap.provider_cfg.data['model'] = model_name
|
self.ap.provider_cfg.data['model'] = model_name
|
||||||
await self.ap.provider_cfg.dump_config()
|
await self.ap.provider_cfg.dump_config()
|
||||||
yield entities.CommandReturn(text=f"已设置当前使用模型为 {model_name},重置会话以生效")
|
yield entities.CommandReturn(
|
||||||
|
text=f'已设置当前使用模型为 {model_name},重置会话以生效'
|
||||||
|
)
|
||||||
|
|||||||
@@ -1,35 +1,42 @@
|
|||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
import typing
|
import typing
|
||||||
import datetime
|
|
||||||
|
|
||||||
from .. import operator, entities, cmdmgr, errors
|
from .. import operator, entities, errors
|
||||||
|
|
||||||
|
|
||||||
@operator.operator_class(
|
@operator.operator_class(name='next', help='切换到后一个对话', usage='!next')
|
||||||
name="next",
|
|
||||||
help="切换到后一个对话",
|
|
||||||
usage='!next'
|
|
||||||
)
|
|
||||||
class NextOperator(operator.CommandOperator):
|
class NextOperator(operator.CommandOperator):
|
||||||
|
|
||||||
async def execute(
|
async def execute(
|
||||||
self,
|
self, context: entities.ExecuteContext
|
||||||
context: entities.ExecuteContext
|
|
||||||
) -> typing.AsyncGenerator[entities.CommandReturn, None]:
|
) -> typing.AsyncGenerator[entities.CommandReturn, None]:
|
||||||
|
|
||||||
if context.session.conversations:
|
if context.session.conversations:
|
||||||
# 找到当前会话的下一个会话
|
# 找到当前会话的下一个会话
|
||||||
for index in range(len(context.session.conversations)):
|
for index in range(len(context.session.conversations)):
|
||||||
if context.session.conversations[index] == context.session.using_conversation:
|
if (
|
||||||
if index == len(context.session.conversations)-1:
|
context.session.conversations[index]
|
||||||
yield entities.CommandReturn(error=errors.CommandOperationError('已经是最后一个对话了'))
|
== context.session.using_conversation
|
||||||
|
):
|
||||||
|
if index == len(context.session.conversations) - 1:
|
||||||
|
yield entities.CommandReturn(
|
||||||
|
error=errors.CommandOperationError('已经是最后一个对话了')
|
||||||
|
)
|
||||||
return
|
return
|
||||||
else:
|
else:
|
||||||
context.session.using_conversation = context.session.conversations[index+1]
|
context.session.using_conversation = (
|
||||||
time_str = context.session.using_conversation.create_time.strftime("%Y-%m-%d %H:%M:%S")
|
context.session.conversations[index + 1]
|
||||||
|
)
|
||||||
|
time_str = (
|
||||||
|
context.session.using_conversation.create_time.strftime(
|
||||||
|
'%Y-%m-%d %H:%M:%S'
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
yield entities.CommandReturn(text=f"已切换到后一个对话: {index} {time_str}: {context.session.using_conversation.messages[0].content}")
|
yield entities.CommandReturn(
|
||||||
|
text=f'已切换到后一个对话: {index} {time_str}: {context.session.using_conversation.messages[0].content}'
|
||||||
|
)
|
||||||
return
|
return
|
||||||
else:
|
else:
|
||||||
yield entities.CommandReturn(error=errors.CommandOperationError('当前没有对话'))
|
yield entities.CommandReturn(
|
||||||
|
error=errors.CommandOperationError('当前没有对话')
|
||||||
|
)
|
||||||
|
|||||||
@@ -2,31 +2,32 @@ from __future__ import annotations
|
|||||||
|
|
||||||
import json
|
import json
|
||||||
import typing
|
import typing
|
||||||
import traceback
|
|
||||||
|
|
||||||
import ollama
|
import ollama
|
||||||
from .. import operator, entities, errors
|
from .. import operator, entities, errors
|
||||||
|
|
||||||
|
|
||||||
@operator.operator_class(
|
@operator.operator_class(
|
||||||
name="ollama",
|
name='ollama',
|
||||||
help="ollama平台操作",
|
help='ollama平台操作',
|
||||||
usage="!ollama\n!ollama show <模型名>\n!ollama pull <模型名>\n!ollama del <模型名>"
|
usage='!ollama\n!ollama show <模型名>\n!ollama pull <模型名>\n!ollama del <模型名>',
|
||||||
)
|
)
|
||||||
class OllamaOperator(operator.CommandOperator):
|
class OllamaOperator(operator.CommandOperator):
|
||||||
async def execute(
|
async def execute(
|
||||||
self, context: entities.ExecuteContext
|
self, context: entities.ExecuteContext
|
||||||
) -> typing.AsyncGenerator[entities.CommandReturn, None]:
|
) -> typing.AsyncGenerator[entities.CommandReturn, None]:
|
||||||
try:
|
try:
|
||||||
content: str = '模型列表:\n'
|
content: str = '模型列表:\n'
|
||||||
model_list: list = ollama.list().get('models', [])
|
model_list: list = ollama.list().get('models', [])
|
||||||
for model in model_list:
|
for model in model_list:
|
||||||
content += f"名称: {model['name']}\n"
|
content += f'名称: {model["name"]}\n'
|
||||||
content += f"修改时间: {model['modified_at']}\n"
|
content += f'修改时间: {model["modified_at"]}\n'
|
||||||
content += f"大小: {bytes_to_mb(model['size'])}MB\n\n"
|
content += f'大小: {bytes_to_mb(model["size"])}MB\n\n'
|
||||||
yield entities.CommandReturn(text=f"{content.strip()}")
|
yield entities.CommandReturn(text=f'{content.strip()}')
|
||||||
except ollama.ResponseError as e:
|
except ollama.ResponseError:
|
||||||
yield entities.CommandReturn(error=errors.CommandError(f"无法获取模型列表,请确认 Ollama 服务正常"))
|
yield entities.CommandReturn(
|
||||||
|
error=errors.CommandError('无法获取模型列表,请确认 Ollama 服务正常')
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
def bytes_to_mb(num_bytes):
|
def bytes_to_mb(num_bytes):
|
||||||
@@ -35,14 +36,11 @@ def bytes_to_mb(num_bytes):
|
|||||||
|
|
||||||
|
|
||||||
@operator.operator_class(
|
@operator.operator_class(
|
||||||
name="show",
|
name='show', help='ollama模型详情', privilege=2, parent_class=OllamaOperator
|
||||||
help="ollama模型详情",
|
|
||||||
privilege=2,
|
|
||||||
parent_class=OllamaOperator
|
|
||||||
)
|
)
|
||||||
class OllamaShowOperator(operator.CommandOperator):
|
class OllamaShowOperator(operator.CommandOperator):
|
||||||
async def execute(
|
async def execute(
|
||||||
self, context: entities.ExecuteContext
|
self, context: entities.ExecuteContext
|
||||||
) -> typing.AsyncGenerator[entities.CommandReturn, None]:
|
) -> typing.AsyncGenerator[entities.CommandReturn, None]:
|
||||||
content: str = '模型详情:\n'
|
content: str = '模型详情:\n'
|
||||||
try:
|
try:
|
||||||
@@ -53,31 +51,36 @@ class OllamaShowOperator(operator.CommandOperator):
|
|||||||
for key in ['license', 'modelfile']:
|
for key in ['license', 'modelfile']:
|
||||||
show[key] = ignore_show
|
show[key] = ignore_show
|
||||||
|
|
||||||
for key in ['tokenizer.chat_template.rag', 'tokenizer.chat_template.tool_use']:
|
for key in [
|
||||||
|
'tokenizer.chat_template.rag',
|
||||||
|
'tokenizer.chat_template.tool_use',
|
||||||
|
]:
|
||||||
model_info[key] = ignore_show
|
model_info[key] = ignore_show
|
||||||
|
|
||||||
content += json.dumps(show, indent=4)
|
content += json.dumps(show, indent=4)
|
||||||
yield entities.CommandReturn(text=content.strip())
|
yield entities.CommandReturn(text=content.strip())
|
||||||
except ollama.ResponseError as e:
|
except ollama.ResponseError:
|
||||||
yield entities.CommandReturn(error=errors.CommandError(f"无法获取模型详情,请确认 Ollama 服务正常"))
|
yield entities.CommandReturn(
|
||||||
|
error=errors.CommandError('无法获取模型详情,请确认 Ollama 服务正常')
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
@operator.operator_class(
|
@operator.operator_class(
|
||||||
name="pull",
|
name='pull', help='ollama模型拉取', privilege=2, parent_class=OllamaOperator
|
||||||
help="ollama模型拉取",
|
|
||||||
privilege=2,
|
|
||||||
parent_class=OllamaOperator
|
|
||||||
)
|
)
|
||||||
class OllamaPullOperator(operator.CommandOperator):
|
class OllamaPullOperator(operator.CommandOperator):
|
||||||
async def execute(
|
async def execute(
|
||||||
self, context: entities.ExecuteContext
|
self, context: entities.ExecuteContext
|
||||||
) -> typing.AsyncGenerator[entities.CommandReturn, None]:
|
) -> typing.AsyncGenerator[entities.CommandReturn, None]:
|
||||||
try:
|
try:
|
||||||
model_list: list = ollama.list().get('models', [])
|
model_list: list = ollama.list().get('models', [])
|
||||||
if context.crt_params[0] in [model['name'] for model in model_list]:
|
if context.crt_params[0] in [model['name'] for model in model_list]:
|
||||||
yield entities.CommandReturn(text="模型已存在")
|
yield entities.CommandReturn(text='模型已存在')
|
||||||
return
|
return
|
||||||
except ollama.ResponseError as e:
|
except ollama.ResponseError:
|
||||||
yield entities.CommandReturn(error=errors.CommandError(f"无法获取模型列表,请确认 Ollama 服务正常"))
|
yield entities.CommandReturn(
|
||||||
|
error=errors.CommandError('无法获取模型列表,请确认 Ollama 服务正常')
|
||||||
|
)
|
||||||
return
|
return
|
||||||
|
|
||||||
on_progress: bool = False
|
on_progress: bool = False
|
||||||
@@ -99,23 +102,21 @@ class OllamaPullOperator(operator.CommandOperator):
|
|||||||
if percentage_completed > progress_count:
|
if percentage_completed > progress_count:
|
||||||
progress_count += 10
|
progress_count += 10
|
||||||
yield entities.CommandReturn(
|
yield entities.CommandReturn(
|
||||||
text=f"下载进度: {completed}/{total} ({percentage_completed:.2f}%)")
|
text=f'下载进度: {completed}/{total} ({percentage_completed:.2f}%)'
|
||||||
|
)
|
||||||
except ollama.ResponseError as e:
|
except ollama.ResponseError as e:
|
||||||
yield entities.CommandReturn(text=f"拉取失败: {e.error}")
|
yield entities.CommandReturn(text=f'拉取失败: {e.error}')
|
||||||
|
|
||||||
|
|
||||||
@operator.operator_class(
|
@operator.operator_class(
|
||||||
name="del",
|
name='del', help='ollama模型删除', privilege=2, parent_class=OllamaOperator
|
||||||
help="ollama模型删除",
|
|
||||||
privilege=2,
|
|
||||||
parent_class=OllamaOperator
|
|
||||||
)
|
)
|
||||||
class OllamaDelOperator(operator.CommandOperator):
|
class OllamaDelOperator(operator.CommandOperator):
|
||||||
async def execute(
|
async def execute(
|
||||||
self, context: entities.ExecuteContext
|
self, context: entities.ExecuteContext
|
||||||
) -> typing.AsyncGenerator[entities.CommandReturn, None]:
|
) -> typing.AsyncGenerator[entities.CommandReturn, None]:
|
||||||
try:
|
try:
|
||||||
ret: str = ollama.delete(model=context.crt_params[0])['status']
|
ret: str = ollama.delete(model=context.crt_params[0])['status']
|
||||||
except ollama.ResponseError as e:
|
except ollama.ResponseError as e:
|
||||||
ret = f"{e.error}"
|
ret = f'{e.error}'
|
||||||
yield entities.CommandReturn(text=ret)
|
yield entities.CommandReturn(text=ret)
|
||||||
|
|||||||
+101
-94
@@ -2,31 +2,30 @@ from __future__ import annotations
|
|||||||
import typing
|
import typing
|
||||||
import traceback
|
import traceback
|
||||||
|
|
||||||
from .. import operator, entities, cmdmgr, errors
|
from .. import operator, entities, errors
|
||||||
from ...core import app
|
|
||||||
|
|
||||||
|
|
||||||
@operator.operator_class(
|
@operator.operator_class(
|
||||||
name="plugin",
|
name='plugin',
|
||||||
help="插件操作",
|
help='插件操作',
|
||||||
usage="!plugin\n!plugin get <插件仓库地址>\n!plugin update\n!plugin del <插件名>\n!plugin on <插件名>\n!plugin off <插件名>"
|
usage='!plugin\n!plugin get <插件仓库地址>\n!plugin update\n!plugin del <插件名>\n!plugin on <插件名>\n!plugin off <插件名>',
|
||||||
)
|
)
|
||||||
class PluginOperator(operator.CommandOperator):
|
class PluginOperator(operator.CommandOperator):
|
||||||
|
|
||||||
async def execute(
|
async def execute(
|
||||||
self,
|
self, context: entities.ExecuteContext
|
||||||
context: entities.ExecuteContext
|
|
||||||
) -> typing.AsyncGenerator[entities.CommandReturn, None]:
|
) -> typing.AsyncGenerator[entities.CommandReturn, None]:
|
||||||
|
|
||||||
plugin_list = self.ap.plugin_mgr.plugins()
|
plugin_list = self.ap.plugin_mgr.plugins()
|
||||||
reply_str = "所有插件({}):\n".format(len(plugin_list))
|
reply_str = '所有插件({}):\n'.format(len(plugin_list))
|
||||||
idx = 0
|
idx = 0
|
||||||
for plugin in plugin_list:
|
for plugin in plugin_list:
|
||||||
reply_str += "\n#{} {} {}\n{}\nv{}\n作者: {}\n"\
|
reply_str += '\n#{} {} {}\n{}\nv{}\n作者: {}\n'.format(
|
||||||
.format((idx+1), plugin.plugin_name,
|
(idx + 1),
|
||||||
"[已禁用]" if not plugin.enabled else "",
|
plugin.plugin_name,
|
||||||
plugin.plugin_description,
|
'[已禁用]' if not plugin.enabled else '',
|
||||||
plugin.plugin_version, plugin.plugin_author)
|
plugin.plugin_description,
|
||||||
|
plugin.plugin_version,
|
||||||
|
plugin.plugin_author,
|
||||||
|
)
|
||||||
|
|
||||||
idx += 1
|
idx += 1
|
||||||
|
|
||||||
@@ -34,48 +33,42 @@ class PluginOperator(operator.CommandOperator):
|
|||||||
|
|
||||||
|
|
||||||
@operator.operator_class(
|
@operator.operator_class(
|
||||||
name="get",
|
name='get', help='安装插件', privilege=2, parent_class=PluginOperator
|
||||||
help="安装插件",
|
|
||||||
privilege=2,
|
|
||||||
parent_class=PluginOperator
|
|
||||||
)
|
)
|
||||||
class PluginGetOperator(operator.CommandOperator):
|
class PluginGetOperator(operator.CommandOperator):
|
||||||
|
|
||||||
async def execute(
|
async def execute(
|
||||||
self,
|
self, context: entities.ExecuteContext
|
||||||
context: entities.ExecuteContext
|
|
||||||
) -> typing.AsyncGenerator[entities.CommandReturn, None]:
|
) -> typing.AsyncGenerator[entities.CommandReturn, None]:
|
||||||
|
|
||||||
if len(context.crt_params) == 0:
|
if len(context.crt_params) == 0:
|
||||||
yield entities.CommandReturn(error=errors.ParamNotEnoughError('请提供插件仓库地址'))
|
yield entities.CommandReturn(
|
||||||
|
error=errors.ParamNotEnoughError('请提供插件仓库地址')
|
||||||
|
)
|
||||||
else:
|
else:
|
||||||
repo = context.crt_params[0]
|
repo = context.crt_params[0]
|
||||||
|
|
||||||
yield entities.CommandReturn(text="正在安装插件...")
|
yield entities.CommandReturn(text='正在安装插件...')
|
||||||
|
|
||||||
try:
|
try:
|
||||||
await self.ap.plugin_mgr.install_plugin(repo)
|
await self.ap.plugin_mgr.install_plugin(repo)
|
||||||
yield entities.CommandReturn(text="插件安装成功,请重启程序以加载插件")
|
yield entities.CommandReturn(text='插件安装成功,请重启程序以加载插件')
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
traceback.print_exc()
|
traceback.print_exc()
|
||||||
yield entities.CommandReturn(error=errors.CommandError("插件安装失败: "+str(e)))
|
yield entities.CommandReturn(
|
||||||
|
error=errors.CommandError('插件安装失败: ' + str(e))
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
@operator.operator_class(
|
@operator.operator_class(
|
||||||
name="update",
|
name='update', help='更新插件', privilege=2, parent_class=PluginOperator
|
||||||
help="更新插件",
|
|
||||||
privilege=2,
|
|
||||||
parent_class=PluginOperator
|
|
||||||
)
|
)
|
||||||
class PluginUpdateOperator(operator.CommandOperator):
|
class PluginUpdateOperator(operator.CommandOperator):
|
||||||
|
|
||||||
async def execute(
|
async def execute(
|
||||||
self,
|
self, context: entities.ExecuteContext
|
||||||
context: entities.ExecuteContext
|
|
||||||
) -> typing.AsyncGenerator[entities.CommandReturn, None]:
|
) -> typing.AsyncGenerator[entities.CommandReturn, None]:
|
||||||
|
|
||||||
if len(context.crt_params) == 0:
|
if len(context.crt_params) == 0:
|
||||||
yield entities.CommandReturn(error=errors.ParamNotEnoughError('请提供插件名称'))
|
yield entities.CommandReturn(
|
||||||
|
error=errors.ParamNotEnoughError('请提供插件名称')
|
||||||
|
)
|
||||||
else:
|
else:
|
||||||
plugin_name = context.crt_params[0]
|
plugin_name = context.crt_params[0]
|
||||||
|
|
||||||
@@ -83,36 +76,34 @@ class PluginUpdateOperator(operator.CommandOperator):
|
|||||||
plugin_container = self.ap.plugin_mgr.get_plugin_by_name(plugin_name)
|
plugin_container = self.ap.plugin_mgr.get_plugin_by_name(plugin_name)
|
||||||
|
|
||||||
if plugin_container is not None:
|
if plugin_container is not None:
|
||||||
yield entities.CommandReturn(text="正在更新插件...")
|
yield entities.CommandReturn(text='正在更新插件...')
|
||||||
await self.ap.plugin_mgr.update_plugin(plugin_name)
|
await self.ap.plugin_mgr.update_plugin(plugin_name)
|
||||||
yield entities.CommandReturn(text="插件更新成功,请重启程序以加载插件")
|
yield entities.CommandReturn(
|
||||||
|
text='插件更新成功,请重启程序以加载插件'
|
||||||
|
)
|
||||||
else:
|
else:
|
||||||
yield entities.CommandReturn(error=errors.CommandError("插件更新失败: 未找到插件"))
|
yield entities.CommandReturn(
|
||||||
|
error=errors.CommandError('插件更新失败: 未找到插件')
|
||||||
|
)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
traceback.print_exc()
|
traceback.print_exc()
|
||||||
yield entities.CommandReturn(error=errors.CommandError("插件更新失败: "+str(e)))
|
yield entities.CommandReturn(
|
||||||
|
error=errors.CommandError('插件更新失败: ' + str(e))
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
@operator.operator_class(
|
@operator.operator_class(
|
||||||
name="all",
|
name='all', help='更新所有插件', privilege=2, parent_class=PluginUpdateOperator
|
||||||
help="更新所有插件",
|
|
||||||
privilege=2,
|
|
||||||
parent_class=PluginUpdateOperator
|
|
||||||
)
|
)
|
||||||
class PluginUpdateAllOperator(operator.CommandOperator):
|
class PluginUpdateAllOperator(operator.CommandOperator):
|
||||||
|
|
||||||
async def execute(
|
async def execute(
|
||||||
self,
|
self, context: entities.ExecuteContext
|
||||||
context: entities.ExecuteContext
|
|
||||||
) -> typing.AsyncGenerator[entities.CommandReturn, None]:
|
) -> typing.AsyncGenerator[entities.CommandReturn, None]:
|
||||||
|
|
||||||
try:
|
try:
|
||||||
plugins = [
|
plugins = [p.plugin_name for p in self.ap.plugin_mgr.plugins()]
|
||||||
p.plugin_name
|
|
||||||
for p in self.ap.plugin_mgr.plugins()
|
|
||||||
]
|
|
||||||
|
|
||||||
if plugins:
|
if plugins:
|
||||||
yield entities.CommandReturn(text="正在更新插件...")
|
yield entities.CommandReturn(text='正在更新插件...')
|
||||||
updated = []
|
updated = []
|
||||||
try:
|
try:
|
||||||
for plugin_name in plugins:
|
for plugin_name in plugins:
|
||||||
@@ -120,30 +111,32 @@ class PluginUpdateAllOperator(operator.CommandOperator):
|
|||||||
updated.append(plugin_name)
|
updated.append(plugin_name)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
traceback.print_exc()
|
traceback.print_exc()
|
||||||
yield entities.CommandReturn(error=errors.CommandError("插件更新失败: "+str(e)))
|
yield entities.CommandReturn(
|
||||||
yield entities.CommandReturn(text="已更新插件: {}".format(", ".join(updated)))
|
error=errors.CommandError('插件更新失败: ' + str(e))
|
||||||
|
)
|
||||||
|
yield entities.CommandReturn(
|
||||||
|
text='已更新插件: {}'.format(', '.join(updated))
|
||||||
|
)
|
||||||
else:
|
else:
|
||||||
yield entities.CommandReturn(text="没有可更新的插件")
|
yield entities.CommandReturn(text='没有可更新的插件')
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
traceback.print_exc()
|
traceback.print_exc()
|
||||||
yield entities.CommandReturn(error=errors.CommandError("插件更新失败: "+str(e)))
|
yield entities.CommandReturn(
|
||||||
|
error=errors.CommandError('插件更新失败: ' + str(e))
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
@operator.operator_class(
|
@operator.operator_class(
|
||||||
name="del",
|
name='del', help='删除插件', privilege=2, parent_class=PluginOperator
|
||||||
help="删除插件",
|
|
||||||
privilege=2,
|
|
||||||
parent_class=PluginOperator
|
|
||||||
)
|
)
|
||||||
class PluginDelOperator(operator.CommandOperator):
|
class PluginDelOperator(operator.CommandOperator):
|
||||||
|
|
||||||
async def execute(
|
async def execute(
|
||||||
self,
|
self, context: entities.ExecuteContext
|
||||||
context: entities.ExecuteContext
|
|
||||||
) -> typing.AsyncGenerator[entities.CommandReturn, None]:
|
) -> typing.AsyncGenerator[entities.CommandReturn, None]:
|
||||||
|
|
||||||
if len(context.crt_params) == 0:
|
if len(context.crt_params) == 0:
|
||||||
yield entities.CommandReturn(error=errors.ParamNotEnoughError('请提供插件名称'))
|
yield entities.CommandReturn(
|
||||||
|
error=errors.ParamNotEnoughError('请提供插件名称')
|
||||||
|
)
|
||||||
else:
|
else:
|
||||||
plugin_name = context.crt_params[0]
|
plugin_name = context.crt_params[0]
|
||||||
|
|
||||||
@@ -151,67 +144,81 @@ class PluginDelOperator(operator.CommandOperator):
|
|||||||
plugin_container = self.ap.plugin_mgr.get_plugin_by_name(plugin_name)
|
plugin_container = self.ap.plugin_mgr.get_plugin_by_name(plugin_name)
|
||||||
|
|
||||||
if plugin_container is not None:
|
if plugin_container is not None:
|
||||||
yield entities.CommandReturn(text="正在删除插件...")
|
yield entities.CommandReturn(text='正在删除插件...')
|
||||||
await self.ap.plugin_mgr.uninstall_plugin(plugin_name)
|
await self.ap.plugin_mgr.uninstall_plugin(plugin_name)
|
||||||
yield entities.CommandReturn(text="插件删除成功,请重启程序以加载插件")
|
yield entities.CommandReturn(
|
||||||
|
text='插件删除成功,请重启程序以加载插件'
|
||||||
|
)
|
||||||
else:
|
else:
|
||||||
yield entities.CommandReturn(error=errors.CommandError("插件删除失败: 未找到插件"))
|
yield entities.CommandReturn(
|
||||||
|
error=errors.CommandError('插件删除失败: 未找到插件')
|
||||||
|
)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
traceback.print_exc()
|
traceback.print_exc()
|
||||||
yield entities.CommandReturn(error=errors.CommandError("插件删除失败: "+str(e)))
|
yield entities.CommandReturn(
|
||||||
|
error=errors.CommandError('插件删除失败: ' + str(e))
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
@operator.operator_class(
|
@operator.operator_class(
|
||||||
name="on",
|
name='on', help='启用插件', privilege=2, parent_class=PluginOperator
|
||||||
help="启用插件",
|
|
||||||
privilege=2,
|
|
||||||
parent_class=PluginOperator
|
|
||||||
)
|
)
|
||||||
class PluginEnableOperator(operator.CommandOperator):
|
class PluginEnableOperator(operator.CommandOperator):
|
||||||
|
|
||||||
async def execute(
|
async def execute(
|
||||||
self,
|
self, context: entities.ExecuteContext
|
||||||
context: entities.ExecuteContext
|
|
||||||
) -> typing.AsyncGenerator[entities.CommandReturn, None]:
|
) -> typing.AsyncGenerator[entities.CommandReturn, None]:
|
||||||
|
|
||||||
if len(context.crt_params) == 0:
|
if len(context.crt_params) == 0:
|
||||||
yield entities.CommandReturn(error=errors.ParamNotEnoughError('请提供插件名称'))
|
yield entities.CommandReturn(
|
||||||
|
error=errors.ParamNotEnoughError('请提供插件名称')
|
||||||
|
)
|
||||||
else:
|
else:
|
||||||
plugin_name = context.crt_params[0]
|
plugin_name = context.crt_params[0]
|
||||||
|
|
||||||
try:
|
try:
|
||||||
if await self.ap.plugin_mgr.update_plugin_switch(plugin_name, True):
|
if await self.ap.plugin_mgr.update_plugin_switch(plugin_name, True):
|
||||||
yield entities.CommandReturn(text="已启用插件: {}".format(plugin_name))
|
yield entities.CommandReturn(
|
||||||
|
text='已启用插件: {}'.format(plugin_name)
|
||||||
|
)
|
||||||
else:
|
else:
|
||||||
yield entities.CommandReturn(error=errors.CommandError("插件状态修改失败: 未找到插件 {}".format(plugin_name)))
|
yield entities.CommandReturn(
|
||||||
|
error=errors.CommandError(
|
||||||
|
'插件状态修改失败: 未找到插件 {}'.format(plugin_name)
|
||||||
|
)
|
||||||
|
)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
traceback.print_exc()
|
traceback.print_exc()
|
||||||
yield entities.CommandReturn(error=errors.CommandError("插件状态修改失败: "+str(e)))
|
yield entities.CommandReturn(
|
||||||
|
error=errors.CommandError('插件状态修改失败: ' + str(e))
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
@operator.operator_class(
|
@operator.operator_class(
|
||||||
name="off",
|
name='off', help='禁用插件', privilege=2, parent_class=PluginOperator
|
||||||
help="禁用插件",
|
|
||||||
privilege=2,
|
|
||||||
parent_class=PluginOperator
|
|
||||||
)
|
)
|
||||||
class PluginDisableOperator(operator.CommandOperator):
|
class PluginDisableOperator(operator.CommandOperator):
|
||||||
|
|
||||||
async def execute(
|
async def execute(
|
||||||
self,
|
self, context: entities.ExecuteContext
|
||||||
context: entities.ExecuteContext
|
|
||||||
) -> typing.AsyncGenerator[entities.CommandReturn, None]:
|
) -> typing.AsyncGenerator[entities.CommandReturn, None]:
|
||||||
|
|
||||||
if len(context.crt_params) == 0:
|
if len(context.crt_params) == 0:
|
||||||
yield entities.CommandReturn(error=errors.ParamNotEnoughError('请提供插件名称'))
|
yield entities.CommandReturn(
|
||||||
|
error=errors.ParamNotEnoughError('请提供插件名称')
|
||||||
|
)
|
||||||
else:
|
else:
|
||||||
plugin_name = context.crt_params[0]
|
plugin_name = context.crt_params[0]
|
||||||
|
|
||||||
try:
|
try:
|
||||||
if await self.ap.plugin_mgr.update_plugin_switch(plugin_name, False):
|
if await self.ap.plugin_mgr.update_plugin_switch(plugin_name, False):
|
||||||
yield entities.CommandReturn(text="已禁用插件: {}".format(plugin_name))
|
yield entities.CommandReturn(
|
||||||
|
text='已禁用插件: {}'.format(plugin_name)
|
||||||
|
)
|
||||||
else:
|
else:
|
||||||
yield entities.CommandReturn(error=errors.CommandError("插件状态修改失败: 未找到插件 {}".format(plugin_name)))
|
yield entities.CommandReturn(
|
||||||
|
error=errors.CommandError(
|
||||||
|
'插件状态修改失败: 未找到插件 {}'.format(plugin_name)
|
||||||
|
)
|
||||||
|
)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
traceback.print_exc()
|
traceback.print_exc()
|
||||||
yield entities.CommandReturn(error=errors.CommandError("插件状态修改失败: "+str(e)))
|
yield entities.CommandReturn(
|
||||||
|
error=errors.CommandError('插件状态修改失败: ' + str(e))
|
||||||
|
)
|
||||||
|
|||||||
@@ -2,28 +2,23 @@ from __future__ import annotations
|
|||||||
|
|
||||||
import typing
|
import typing
|
||||||
|
|
||||||
from .. import operator, entities, cmdmgr, errors
|
from .. import operator, entities, errors
|
||||||
|
|
||||||
|
|
||||||
@operator.operator_class(
|
@operator.operator_class(name='prompt', help='查看当前对话的前文', usage='!prompt')
|
||||||
name="prompt",
|
|
||||||
help="查看当前对话的前文",
|
|
||||||
usage='!prompt'
|
|
||||||
)
|
|
||||||
class PromptOperator(operator.CommandOperator):
|
class PromptOperator(operator.CommandOperator):
|
||||||
|
|
||||||
async def execute(
|
async def execute(
|
||||||
self,
|
self, context: entities.ExecuteContext
|
||||||
context: entities.ExecuteContext
|
|
||||||
) -> typing.AsyncGenerator[entities.CommandReturn, None]:
|
) -> typing.AsyncGenerator[entities.CommandReturn, None]:
|
||||||
"""执行
|
"""执行"""
|
||||||
"""
|
|
||||||
if context.session.using_conversation is None:
|
if context.session.using_conversation is None:
|
||||||
yield entities.CommandReturn(error=errors.CommandOperationError('当前没有对话'))
|
yield entities.CommandReturn(
|
||||||
|
error=errors.CommandOperationError('当前没有对话')
|
||||||
|
)
|
||||||
else:
|
else:
|
||||||
reply_str = '当前对话所有内容:\n\n'
|
reply_str = '当前对话所有内容:\n\n'
|
||||||
|
|
||||||
for msg in context.session.using_conversation.messages:
|
for msg in context.session.using_conversation.messages:
|
||||||
reply_str += f"{msg.role}: {msg.content}\n"
|
reply_str += f'{msg.role}: {msg.content}\n'
|
||||||
|
|
||||||
yield entities.CommandReturn(text=reply_str)
|
yield entities.CommandReturn(text=reply_str)
|
||||||
@@ -2,23 +2,19 @@ from __future__ import annotations
|
|||||||
|
|
||||||
import typing
|
import typing
|
||||||
|
|
||||||
from .. import operator, entities, cmdmgr, errors
|
from .. import operator, entities, errors
|
||||||
|
|
||||||
|
|
||||||
@operator.operator_class(
|
@operator.operator_class(
|
||||||
name="resend",
|
name='resend', help='重发当前会话的最后一条消息', usage='!resend'
|
||||||
help="重发当前会话的最后一条消息",
|
|
||||||
usage='!resend'
|
|
||||||
)
|
)
|
||||||
class ResendOperator(operator.CommandOperator):
|
class ResendOperator(operator.CommandOperator):
|
||||||
|
|
||||||
async def execute(
|
async def execute(
|
||||||
self,
|
self, context: entities.ExecuteContext
|
||||||
context: entities.ExecuteContext
|
|
||||||
) -> typing.AsyncGenerator[entities.CommandReturn, None]:
|
) -> typing.AsyncGenerator[entities.CommandReturn, None]:
|
||||||
# 回滚到最后一条用户message前
|
# 回滚到最后一条用户message前
|
||||||
if context.session.using_conversation is None:
|
if context.session.using_conversation is None:
|
||||||
yield entities.CommandReturn(error=errors.CommandError("当前没有对话"))
|
yield entities.CommandReturn(error=errors.CommandError('当前没有对话'))
|
||||||
else:
|
else:
|
||||||
conv_msg = context.session.using_conversation.messages
|
conv_msg = context.session.using_conversation.messages
|
||||||
|
|
||||||
@@ -31,4 +27,4 @@ class ResendOperator(operator.CommandOperator):
|
|||||||
conv_msg.pop()
|
conv_msg.pop()
|
||||||
|
|
||||||
# 不重发了,提示用户已删除就行了
|
# 不重发了,提示用户已删除就行了
|
||||||
yield entities.CommandReturn(text="已删除最后一次请求记录")
|
yield entities.CommandReturn(text='已删除最后一次请求记录')
|
||||||
|
|||||||
@@ -2,22 +2,15 @@ from __future__ import annotations
|
|||||||
|
|
||||||
import typing
|
import typing
|
||||||
|
|
||||||
from .. import operator, entities, cmdmgr, errors
|
from .. import operator, entities
|
||||||
|
|
||||||
|
|
||||||
@operator.operator_class(
|
@operator.operator_class(name='reset', help='重置当前会话', usage='!reset')
|
||||||
name="reset",
|
|
||||||
help="重置当前会话",
|
|
||||||
usage='!reset'
|
|
||||||
)
|
|
||||||
class ResetOperator(operator.CommandOperator):
|
class ResetOperator(operator.CommandOperator):
|
||||||
|
|
||||||
async def execute(
|
async def execute(
|
||||||
self,
|
self, context: entities.ExecuteContext
|
||||||
context: entities.ExecuteContext
|
|
||||||
) -> typing.AsyncGenerator[entities.CommandReturn, None]:
|
) -> typing.AsyncGenerator[entities.CommandReturn, None]:
|
||||||
"""执行
|
"""执行"""
|
||||||
"""
|
|
||||||
context.session.using_conversation = None
|
context.session.using_conversation = None
|
||||||
|
|
||||||
yield entities.CommandReturn(text="已重置当前会话")
|
yield entities.CommandReturn(text='已重置当前会话')
|
||||||
|
|||||||
@@ -3,28 +3,22 @@ from __future__ import annotations
|
|||||||
import typing
|
import typing
|
||||||
import traceback
|
import traceback
|
||||||
|
|
||||||
from .. import operator, entities, cmdmgr, errors
|
from .. import operator, entities, errors
|
||||||
|
|
||||||
|
|
||||||
@operator.operator_class(
|
@operator.operator_class(name='update', help='更新程序', usage='!update', privilege=2)
|
||||||
name="update",
|
|
||||||
help="更新程序",
|
|
||||||
usage='!update',
|
|
||||||
privilege=2
|
|
||||||
)
|
|
||||||
class UpdateCommand(operator.CommandOperator):
|
class UpdateCommand(operator.CommandOperator):
|
||||||
|
|
||||||
async def execute(
|
async def execute(
|
||||||
self,
|
self, context: entities.ExecuteContext
|
||||||
context: entities.ExecuteContext
|
|
||||||
) -> typing.AsyncGenerator[entities.CommandReturn, None]:
|
) -> typing.AsyncGenerator[entities.CommandReturn, None]:
|
||||||
|
|
||||||
try:
|
try:
|
||||||
yield entities.CommandReturn(text="正在进行更新...")
|
yield entities.CommandReturn(text='正在进行更新...')
|
||||||
if await self.ap.ver_mgr.update_all():
|
if await self.ap.ver_mgr.update_all():
|
||||||
yield entities.CommandReturn(text="更新完成,请重启程序以应用更新")
|
yield entities.CommandReturn(text='更新完成,请重启程序以应用更新')
|
||||||
else:
|
else:
|
||||||
yield entities.CommandReturn(text="当前已是最新版本")
|
yield entities.CommandReturn(text='当前已是最新版本')
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
traceback.print_exc()
|
traceback.print_exc()
|
||||||
yield entities.CommandReturn(error=errors.CommandError("更新失败: "+str(e)))
|
yield entities.CommandReturn(
|
||||||
|
error=errors.CommandError('更新失败: ' + str(e))
|
||||||
|
)
|
||||||
|
|||||||
@@ -2,26 +2,20 @@ from __future__ import annotations
|
|||||||
|
|
||||||
import typing
|
import typing
|
||||||
|
|
||||||
from .. import operator, cmdmgr, entities, errors
|
from .. import operator, entities
|
||||||
|
|
||||||
|
|
||||||
@operator.operator_class(
|
@operator.operator_class(name='version', help='显示版本信息', usage='!version')
|
||||||
name="version",
|
|
||||||
help="显示版本信息",
|
|
||||||
usage='!version'
|
|
||||||
)
|
|
||||||
class VersionCommand(operator.CommandOperator):
|
class VersionCommand(operator.CommandOperator):
|
||||||
|
|
||||||
async def execute(
|
async def execute(
|
||||||
self,
|
self, context: entities.ExecuteContext
|
||||||
context: entities.ExecuteContext
|
|
||||||
) -> typing.AsyncGenerator[entities.CommandReturn, None]:
|
) -> typing.AsyncGenerator[entities.CommandReturn, None]:
|
||||||
reply_str = f"当前版本: \n{self.ap.ver_mgr.get_current_version()}"
|
reply_str = f'当前版本: \n{self.ap.ver_mgr.get_current_version()}'
|
||||||
|
|
||||||
try:
|
try:
|
||||||
if await self.ap.ver_mgr.is_new_version_available():
|
if await self.ap.ver_mgr.is_new_version_available():
|
||||||
reply_str += "\n\n有新版本可用。"
|
reply_str += '\n\n有新版本可用。'
|
||||||
except:
|
except Exception:
|
||||||
pass
|
pass
|
||||||
|
|
||||||
yield entities.CommandReturn(text=reply_str.strip())
|
yield entities.CommandReturn(text=reply_str.strip())
|
||||||
+12
-11
@@ -9,7 +9,10 @@ class JSONConfigFile(file_model.ConfigFile):
|
|||||||
"""JSON配置文件"""
|
"""JSON配置文件"""
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self, config_file_name: str, template_file_name: str = None, template_data: dict = None
|
self,
|
||||||
|
config_file_name: str,
|
||||||
|
template_file_name: str = None,
|
||||||
|
template_data: dict = None,
|
||||||
) -> None:
|
) -> None:
|
||||||
self.config_file_name = config_file_name
|
self.config_file_name = config_file_name
|
||||||
self.template_file_name = template_file_name
|
self.template_file_name = template_file_name
|
||||||
@@ -22,28 +25,26 @@ class JSONConfigFile(file_model.ConfigFile):
|
|||||||
if self.template_file_name is not None:
|
if self.template_file_name is not None:
|
||||||
shutil.copyfile(self.template_file_name, self.config_file_name)
|
shutil.copyfile(self.template_file_name, self.config_file_name)
|
||||||
elif self.template_data is not None:
|
elif self.template_data is not None:
|
||||||
with open(self.config_file_name, "w", encoding="utf-8") as f:
|
with open(self.config_file_name, 'w', encoding='utf-8') as f:
|
||||||
json.dump(self.template_data, f, indent=4, ensure_ascii=False)
|
json.dump(self.template_data, f, indent=4, ensure_ascii=False)
|
||||||
else:
|
else:
|
||||||
raise ValueError("template_file_name or template_data must be provided")
|
raise ValueError('template_file_name or template_data must be provided')
|
||||||
|
|
||||||
async def load(self, completion: bool=True) -> dict:
|
|
||||||
|
|
||||||
|
async def load(self, completion: bool = True) -> dict:
|
||||||
if not self.exists():
|
if not self.exists():
|
||||||
await self.create()
|
await self.create()
|
||||||
|
|
||||||
if self.template_file_name is not None:
|
if self.template_file_name is not None:
|
||||||
with open(self.template_file_name, "r", encoding="utf-8") as f:
|
with open(self.template_file_name, 'r', encoding='utf-8') as f:
|
||||||
self.template_data = json.load(f)
|
self.template_data = json.load(f)
|
||||||
|
|
||||||
with open(self.config_file_name, "r", encoding="utf-8") as f:
|
with open(self.config_file_name, 'r', encoding='utf-8') as f:
|
||||||
try:
|
try:
|
||||||
cfg = json.load(f)
|
cfg = json.load(f)
|
||||||
except json.JSONDecodeError as e:
|
except json.JSONDecodeError as e:
|
||||||
raise Exception(f"配置文件 {self.config_file_name} 语法错误: {e}")
|
raise Exception(f'配置文件 {self.config_file_name} 语法错误: {e}')
|
||||||
|
|
||||||
if completion:
|
if completion:
|
||||||
|
|
||||||
for key in self.template_data:
|
for key in self.template_data:
|
||||||
if key not in cfg:
|
if key not in cfg:
|
||||||
cfg[key] = self.template_data[key]
|
cfg[key] = self.template_data[key]
|
||||||
@@ -51,9 +52,9 @@ class JSONConfigFile(file_model.ConfigFile):
|
|||||||
return cfg
|
return cfg
|
||||||
|
|
||||||
async def save(self, cfg: dict):
|
async def save(self, cfg: dict):
|
||||||
with open(self.config_file_name, "w", encoding="utf-8") as f:
|
with open(self.config_file_name, 'w', encoding='utf-8') as f:
|
||||||
json.dump(cfg, f, indent=4, ensure_ascii=False)
|
json.dump(cfg, f, indent=4, ensure_ascii=False)
|
||||||
|
|
||||||
def save_sync(self, cfg: dict):
|
def save_sync(self, cfg: dict):
|
||||||
with open(self.config_file_name, "w", encoding="utf-8") as f:
|
with open(self.config_file_name, 'w', encoding='utf-8') as f:
|
||||||
json.dump(cfg, f, indent=4, ensure_ascii=False)
|
json.dump(cfg, f, indent=4, ensure_ascii=False)
|
||||||
|
|||||||
@@ -25,7 +25,7 @@ class PythonModuleConfigFile(file_model.ConfigFile):
|
|||||||
async def create(self):
|
async def create(self):
|
||||||
shutil.copyfile(self.template_file_name, self.config_file_name)
|
shutil.copyfile(self.template_file_name, self.config_file_name)
|
||||||
|
|
||||||
async def load(self, completion: bool=True) -> dict:
|
async def load(self, completion: bool = True) -> dict:
|
||||||
module_name = os.path.splitext(os.path.basename(self.config_file_name))[0]
|
module_name = os.path.splitext(os.path.basename(self.config_file_name))[0]
|
||||||
module = importlib.import_module(module_name)
|
module = importlib.import_module(module_name)
|
||||||
|
|
||||||
|
|||||||
+12
-11
@@ -9,7 +9,10 @@ class YAMLConfigFile(file_model.ConfigFile):
|
|||||||
"""YAML配置文件"""
|
"""YAML配置文件"""
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self, config_file_name: str, template_file_name: str = None, template_data: dict = None
|
self,
|
||||||
|
config_file_name: str,
|
||||||
|
template_file_name: str = None,
|
||||||
|
template_data: dict = None,
|
||||||
) -> None:
|
) -> None:
|
||||||
self.config_file_name = config_file_name
|
self.config_file_name = config_file_name
|
||||||
self.template_file_name = template_file_name
|
self.template_file_name = template_file_name
|
||||||
@@ -22,28 +25,26 @@ class YAMLConfigFile(file_model.ConfigFile):
|
|||||||
if self.template_file_name is not None:
|
if self.template_file_name is not None:
|
||||||
shutil.copyfile(self.template_file_name, self.config_file_name)
|
shutil.copyfile(self.template_file_name, self.config_file_name)
|
||||||
elif self.template_data is not None:
|
elif self.template_data is not None:
|
||||||
with open(self.config_file_name, "w", encoding="utf-8") as f:
|
with open(self.config_file_name, 'w', encoding='utf-8') as f:
|
||||||
yaml.dump(self.template_data, f, indent=4, allow_unicode=True)
|
yaml.dump(self.template_data, f, indent=4, allow_unicode=True)
|
||||||
else:
|
else:
|
||||||
raise ValueError("template_file_name or template_data must be provided")
|
raise ValueError('template_file_name or template_data must be provided')
|
||||||
|
|
||||||
async def load(self, completion: bool=True) -> dict:
|
|
||||||
|
|
||||||
|
async def load(self, completion: bool = True) -> dict:
|
||||||
if not self.exists():
|
if not self.exists():
|
||||||
await self.create()
|
await self.create()
|
||||||
|
|
||||||
if self.template_file_name is not None:
|
if self.template_file_name is not None:
|
||||||
with open(self.template_file_name, "r", encoding="utf-8") as f:
|
with open(self.template_file_name, 'r', encoding='utf-8') as f:
|
||||||
self.template_data = yaml.load(f, Loader=yaml.FullLoader)
|
self.template_data = yaml.load(f, Loader=yaml.FullLoader)
|
||||||
|
|
||||||
with open(self.config_file_name, "r", encoding="utf-8") as f:
|
with open(self.config_file_name, 'r', encoding='utf-8') as f:
|
||||||
try:
|
try:
|
||||||
cfg = yaml.load(f, Loader=yaml.FullLoader)
|
cfg = yaml.load(f, Loader=yaml.FullLoader)
|
||||||
except yaml.YAMLError as e:
|
except yaml.YAMLError as e:
|
||||||
raise Exception(f"配置文件 {self.config_file_name} 语法错误: {e}")
|
raise Exception(f'配置文件 {self.config_file_name} 语法错误: {e}')
|
||||||
|
|
||||||
if completion:
|
if completion:
|
||||||
|
|
||||||
for key in self.template_data:
|
for key in self.template_data:
|
||||||
if key not in cfg:
|
if key not in cfg:
|
||||||
cfg[key] = self.template_data[key]
|
cfg[key] = self.template_data[key]
|
||||||
@@ -51,9 +52,9 @@ class YAMLConfigFile(file_model.ConfigFile):
|
|||||||
return cfg
|
return cfg
|
||||||
|
|
||||||
async def save(self, cfg: dict):
|
async def save(self, cfg: dict):
|
||||||
with open(self.config_file_name, "w", encoding="utf-8") as f:
|
with open(self.config_file_name, 'w', encoding='utf-8') as f:
|
||||||
yaml.dump(cfg, f, indent=4, allow_unicode=True)
|
yaml.dump(cfg, f, indent=4, allow_unicode=True)
|
||||||
|
|
||||||
def save_sync(self, cfg: dict):
|
def save_sync(self, cfg: dict):
|
||||||
with open(self.config_file_name, "w", encoding="utf-8") as f:
|
with open(self.config_file_name, 'w', encoding='utf-8') as f:
|
||||||
yaml.dump(cfg, f, indent=4, allow_unicode=True)
|
yaml.dump(cfg, f, indent=4, allow_unicode=True)
|
||||||
+19
-18
@@ -31,7 +31,7 @@ class ConfigManager:
|
|||||||
self.file = cfg_file
|
self.file = cfg_file
|
||||||
self.data = {}
|
self.data = {}
|
||||||
|
|
||||||
async def load_config(self, completion: bool=True):
|
async def load_config(self, completion: bool = True):
|
||||||
self.data = await self.file.load(completion=completion)
|
self.data = await self.file.load(completion=completion)
|
||||||
|
|
||||||
async def dump_config(self):
|
async def dump_config(self):
|
||||||
@@ -41,7 +41,9 @@ class ConfigManager:
|
|||||||
self.file.save_sync(self.data)
|
self.file.save_sync(self.data)
|
||||||
|
|
||||||
|
|
||||||
async def load_python_module_config(config_name: str, template_name: str, completion: bool=True) -> ConfigManager:
|
async def load_python_module_config(
|
||||||
|
config_name: str, template_name: str, completion: bool = True
|
||||||
|
) -> ConfigManager:
|
||||||
"""加载Python模块配置文件
|
"""加载Python模块配置文件
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
@@ -52,10 +54,7 @@ async def load_python_module_config(config_name: str, template_name: str, comple
|
|||||||
Returns:
|
Returns:
|
||||||
ConfigManager: 配置文件管理器
|
ConfigManager: 配置文件管理器
|
||||||
"""
|
"""
|
||||||
cfg_inst = pymodule.PythonModuleConfigFile(
|
cfg_inst = pymodule.PythonModuleConfigFile(config_name, template_name)
|
||||||
config_name,
|
|
||||||
template_name
|
|
||||||
)
|
|
||||||
|
|
||||||
cfg_mgr = ConfigManager(cfg_inst)
|
cfg_mgr = ConfigManager(cfg_inst)
|
||||||
await cfg_mgr.load_config(completion=completion)
|
await cfg_mgr.load_config(completion=completion)
|
||||||
@@ -63,7 +62,12 @@ async def load_python_module_config(config_name: str, template_name: str, comple
|
|||||||
return cfg_mgr
|
return cfg_mgr
|
||||||
|
|
||||||
|
|
||||||
async def load_json_config(config_name: str, template_name: str=None, template_data: dict=None, completion: bool=True) -> ConfigManager:
|
async def load_json_config(
|
||||||
|
config_name: str,
|
||||||
|
template_name: str = None,
|
||||||
|
template_data: dict = None,
|
||||||
|
completion: bool = True,
|
||||||
|
) -> ConfigManager:
|
||||||
"""加载JSON配置文件
|
"""加载JSON配置文件
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
@@ -72,11 +76,7 @@ async def load_json_config(config_name: str, template_name: str=None, template_d
|
|||||||
template_data (dict): 模板数据
|
template_data (dict): 模板数据
|
||||||
completion (bool): 是否自动补全内存中的配置文件
|
completion (bool): 是否自动补全内存中的配置文件
|
||||||
"""
|
"""
|
||||||
cfg_inst = json_file.JSONConfigFile(
|
cfg_inst = json_file.JSONConfigFile(config_name, template_name, template_data)
|
||||||
config_name,
|
|
||||||
template_name,
|
|
||||||
template_data
|
|
||||||
)
|
|
||||||
|
|
||||||
cfg_mgr = ConfigManager(cfg_inst)
|
cfg_mgr = ConfigManager(cfg_inst)
|
||||||
await cfg_mgr.load_config(completion=completion)
|
await cfg_mgr.load_config(completion=completion)
|
||||||
@@ -84,7 +84,12 @@ async def load_json_config(config_name: str, template_name: str=None, template_d
|
|||||||
return cfg_mgr
|
return cfg_mgr
|
||||||
|
|
||||||
|
|
||||||
async def load_yaml_config(config_name: str, template_name: str=None, template_data: dict=None, completion: bool=True) -> ConfigManager:
|
async def load_yaml_config(
|
||||||
|
config_name: str,
|
||||||
|
template_name: str = None,
|
||||||
|
template_data: dict = None,
|
||||||
|
completion: bool = True,
|
||||||
|
) -> ConfigManager:
|
||||||
"""加载YAML配置文件
|
"""加载YAML配置文件
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
@@ -96,11 +101,7 @@ async def load_yaml_config(config_name: str, template_name: str=None, template_d
|
|||||||
Returns:
|
Returns:
|
||||||
ConfigManager: 配置文件管理器
|
ConfigManager: 配置文件管理器
|
||||||
"""
|
"""
|
||||||
cfg_inst = yaml_file.YAMLConfigFile(
|
cfg_inst = yaml_file.YAMLConfigFile(config_name, template_name, template_data)
|
||||||
config_name,
|
|
||||||
template_name,
|
|
||||||
template_data
|
|
||||||
)
|
|
||||||
|
|
||||||
cfg_mgr = ConfigManager(cfg_inst)
|
cfg_mgr = ConfigManager(cfg_inst)
|
||||||
await cfg_mgr.load_config(completion=completion)
|
await cfg_mgr.load_config(completion=completion)
|
||||||
|
|||||||
+1
-1
@@ -22,7 +22,7 @@ class ConfigFile(metaclass=abc.ABCMeta):
|
|||||||
pass
|
pass
|
||||||
|
|
||||||
@abc.abstractmethod
|
@abc.abstractmethod
|
||||||
async def load(self, completion: bool=True) -> dict:
|
async def load(self, completion: bool = True) -> dict:
|
||||||
pass
|
pass
|
||||||
|
|
||||||
@abc.abstractmethod
|
@abc.abstractmethod
|
||||||
|
|||||||
+44
-18
@@ -2,9 +2,7 @@ from __future__ import annotations
|
|||||||
|
|
||||||
import logging
|
import logging
|
||||||
import asyncio
|
import asyncio
|
||||||
import threading
|
|
||||||
import traceback
|
import traceback
|
||||||
import enum
|
|
||||||
import sys
|
import sys
|
||||||
import os
|
import os
|
||||||
|
|
||||||
@@ -29,7 +27,6 @@ from ..discover import engine as discover_engine
|
|||||||
from ..utils import logcache, ip
|
from ..utils import logcache, ip
|
||||||
from . import taskmgr
|
from . import taskmgr
|
||||||
from . import entities as core_entities
|
from . import entities as core_entities
|
||||||
from .bootutils import config
|
|
||||||
|
|
||||||
|
|
||||||
class Application:
|
class Application:
|
||||||
@@ -123,33 +120,55 @@ class Application:
|
|||||||
async def run(self):
|
async def run(self):
|
||||||
try:
|
try:
|
||||||
await self.plugin_mgr.initialize_plugins()
|
await self.plugin_mgr.initialize_plugins()
|
||||||
|
|
||||||
# 后续可能会允许动态重启其他任务
|
# 后续可能会允许动态重启其他任务
|
||||||
# 故为了防止程序在非 Ctrl-C 情况下退出,这里创建一个不会结束的协程
|
# 故为了防止程序在非 Ctrl-C 情况下退出,这里创建一个不会结束的协程
|
||||||
async def never_ending():
|
async def never_ending():
|
||||||
while True:
|
while True:
|
||||||
await asyncio.sleep(1)
|
await asyncio.sleep(1)
|
||||||
|
|
||||||
self.task_mgr.create_task(self.platform_mgr.run(), name="platform-manager", scopes=[core_entities.LifecycleControlScope.APPLICATION, core_entities.LifecycleControlScope.PLATFORM])
|
self.task_mgr.create_task(
|
||||||
self.task_mgr.create_task(self.ctrl.run(), name="query-controller", scopes=[core_entities.LifecycleControlScope.APPLICATION])
|
self.platform_mgr.run(),
|
||||||
self.task_mgr.create_task(self.http_ctrl.run(), name="http-api-controller", scopes=[core_entities.LifecycleControlScope.APPLICATION])
|
name='platform-manager',
|
||||||
self.task_mgr.create_task(never_ending(), name="never-ending-task", scopes=[core_entities.LifecycleControlScope.APPLICATION])
|
scopes=[
|
||||||
|
core_entities.LifecycleControlScope.APPLICATION,
|
||||||
|
core_entities.LifecycleControlScope.PLATFORM,
|
||||||
|
],
|
||||||
|
)
|
||||||
|
self.task_mgr.create_task(
|
||||||
|
self.ctrl.run(),
|
||||||
|
name='query-controller',
|
||||||
|
scopes=[core_entities.LifecycleControlScope.APPLICATION],
|
||||||
|
)
|
||||||
|
self.task_mgr.create_task(
|
||||||
|
self.http_ctrl.run(),
|
||||||
|
name='http-api-controller',
|
||||||
|
scopes=[core_entities.LifecycleControlScope.APPLICATION],
|
||||||
|
)
|
||||||
|
self.task_mgr.create_task(
|
||||||
|
never_ending(),
|
||||||
|
name='never-ending-task',
|
||||||
|
scopes=[core_entities.LifecycleControlScope.APPLICATION],
|
||||||
|
)
|
||||||
|
|
||||||
await self.print_web_access_info()
|
await self.print_web_access_info()
|
||||||
await self.task_mgr.wait_all()
|
await self.task_mgr.wait_all()
|
||||||
except asyncio.CancelledError:
|
except asyncio.CancelledError:
|
||||||
pass
|
pass
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
self.logger.error(f"应用运行致命异常: {e}")
|
self.logger.error(f'应用运行致命异常: {e}')
|
||||||
self.logger.debug(f"Traceback: {traceback.format_exc()}")
|
self.logger.debug(f'Traceback: {traceback.format_exc()}')
|
||||||
|
|
||||||
async def print_web_access_info(self):
|
async def print_web_access_info(self):
|
||||||
"""打印访问 webui 的提示"""
|
"""打印访问 webui 的提示"""
|
||||||
|
|
||||||
if not os.path.exists(os.path.join(".", "web/out")):
|
if not os.path.exists(os.path.join('.', 'web/out')):
|
||||||
self.logger.warning("WebUI 文件缺失,请根据文档获取:https://docs.langbot.app/webui/intro.html")
|
self.logger.warning(
|
||||||
|
'WebUI 文件缺失,请根据文档获取:https://docs.langbot.app/webui/intro.html'
|
||||||
|
)
|
||||||
return
|
return
|
||||||
|
|
||||||
host_ip = "127.0.0.1"
|
host_ip = '127.0.0.1'
|
||||||
|
|
||||||
public_ip = await ip.get_myip()
|
public_ip = await ip.get_myip()
|
||||||
|
|
||||||
@@ -170,7 +189,7 @@ class Application:
|
|||||||
🤯 WebUI 仍处于 Beta 测试阶段,如有问题或建议请反馈到 https://github.com/RockChinQ/LangBot/issues
|
🤯 WebUI 仍处于 Beta 测试阶段,如有问题或建议请反馈到 https://github.com/RockChinQ/LangBot/issues
|
||||||
=======================================
|
=======================================
|
||||||
""".strip()
|
""".strip()
|
||||||
for line in tips.split("\n"):
|
for line in tips.split('\n'):
|
||||||
self.logger.info(line)
|
self.logger.info(line)
|
||||||
|
|
||||||
async def reload(
|
async def reload(
|
||||||
@@ -179,21 +198,28 @@ class Application:
|
|||||||
):
|
):
|
||||||
match scope:
|
match scope:
|
||||||
case core_entities.LifecycleControlScope.PLATFORM.value:
|
case core_entities.LifecycleControlScope.PLATFORM.value:
|
||||||
self.logger.info("执行热重载 scope="+scope)
|
self.logger.info('执行热重载 scope=' + scope)
|
||||||
await self.platform_mgr.shutdown()
|
await self.platform_mgr.shutdown()
|
||||||
|
|
||||||
self.platform_mgr = im_mgr.PlatformManager(self)
|
self.platform_mgr = im_mgr.PlatformManager(self)
|
||||||
|
|
||||||
await self.platform_mgr.initialize()
|
await self.platform_mgr.initialize()
|
||||||
|
|
||||||
self.task_mgr.create_task(self.platform_mgr.run(), name="platform-manager", scopes=[core_entities.LifecycleControlScope.APPLICATION, core_entities.LifecycleControlScope.PLATFORM])
|
self.task_mgr.create_task(
|
||||||
|
self.platform_mgr.run(),
|
||||||
|
name='platform-manager',
|
||||||
|
scopes=[
|
||||||
|
core_entities.LifecycleControlScope.APPLICATION,
|
||||||
|
core_entities.LifecycleControlScope.PLATFORM,
|
||||||
|
],
|
||||||
|
)
|
||||||
case core_entities.LifecycleControlScope.PLUGIN.value:
|
case core_entities.LifecycleControlScope.PLUGIN.value:
|
||||||
self.logger.info("执行热重载 scope="+scope)
|
self.logger.info('执行热重载 scope=' + scope)
|
||||||
await self.plugin_mgr.destroy_plugins()
|
await self.plugin_mgr.destroy_plugins()
|
||||||
|
|
||||||
# 删除 sys.module 中所有的 plugins/* 下的模块
|
# 删除 sys.module 中所有的 plugins/* 下的模块
|
||||||
for mod in list(sys.modules.keys()):
|
for mod in list(sys.modules.keys()):
|
||||||
if mod.startswith("plugins."):
|
if mod.startswith('plugins.'):
|
||||||
del sys.modules[mod]
|
del sys.modules[mod]
|
||||||
|
|
||||||
self.plugin_mgr = plugin_mgr.PluginManager(self)
|
self.plugin_mgr = plugin_mgr.PluginManager(self)
|
||||||
@@ -204,7 +230,7 @@ class Application:
|
|||||||
await self.plugin_mgr.load_plugins()
|
await self.plugin_mgr.load_plugins()
|
||||||
await self.plugin_mgr.initialize_plugins()
|
await self.plugin_mgr.initialize_plugins()
|
||||||
case core_entities.LifecycleControlScope.PROVIDER.value:
|
case core_entities.LifecycleControlScope.PROVIDER.value:
|
||||||
self.logger.info("执行热重载 scope="+scope)
|
self.logger.info('执行热重载 scope=' + scope)
|
||||||
|
|
||||||
await self.tool_mgr.shutdown()
|
await self.tool_mgr.shutdown()
|
||||||
|
|
||||||
|
|||||||
+13
-16
@@ -7,29 +7,30 @@ import os
|
|||||||
from . import app
|
from . import app
|
||||||
from ..audit import identifier
|
from ..audit import identifier
|
||||||
from . import stage
|
from . import stage
|
||||||
from ..utils import constants
|
from ..utils import constants, importutil
|
||||||
|
|
||||||
# 引入启动阶段实现以便注册
|
# 引入启动阶段实现以便注册
|
||||||
from .stages import load_config, setup_logger, build_app, migrate, show_notes, genkeys
|
from . import stages
|
||||||
|
|
||||||
|
importutil.import_modules_in_pkg(stages)
|
||||||
|
|
||||||
|
|
||||||
stage_order = [
|
stage_order = [
|
||||||
"LoadConfigStage",
|
'LoadConfigStage',
|
||||||
"MigrationStage",
|
'MigrationStage',
|
||||||
"GenKeysStage",
|
'GenKeysStage',
|
||||||
"SetupLoggerStage",
|
'SetupLoggerStage',
|
||||||
"BuildAppStage",
|
'BuildAppStage',
|
||||||
"ShowNotesStage"
|
'ShowNotesStage',
|
||||||
]
|
]
|
||||||
|
|
||||||
|
|
||||||
async def make_app(loop: asyncio.AbstractEventLoop) -> app.Application:
|
async def make_app(loop: asyncio.AbstractEventLoop) -> app.Application:
|
||||||
|
|
||||||
# 生成标识符
|
# 生成标识符
|
||||||
identifier.init()
|
identifier.init()
|
||||||
|
|
||||||
# 确定是否为调试模式
|
# 确定是否为调试模式
|
||||||
if "DEBUG" in os.environ and os.environ["DEBUG"] in ["true", "1"]:
|
if 'DEBUG' in os.environ and os.environ['DEBUG'] in ['true', '1']:
|
||||||
constants.debug_mode = True
|
constants.debug_mode = True
|
||||||
|
|
||||||
ap = app.Application()
|
ap = app.Application()
|
||||||
@@ -50,21 +51,17 @@ async def make_app(loop: asyncio.AbstractEventLoop) -> app.Application:
|
|||||||
|
|
||||||
async def main(loop: asyncio.AbstractEventLoop):
|
async def main(loop: asyncio.AbstractEventLoop):
|
||||||
try:
|
try:
|
||||||
|
|
||||||
# 挂系统信号处理
|
# 挂系统信号处理
|
||||||
import signal
|
import signal
|
||||||
|
|
||||||
ap: app.Application
|
|
||||||
|
|
||||||
def signal_handler(sig, frame):
|
def signal_handler(sig, frame):
|
||||||
print("[Signal] 程序退出.")
|
print('[Signal] 程序退出.')
|
||||||
# ap.shutdown()
|
# ap.shutdown()
|
||||||
os._exit(0)
|
os._exit(0)
|
||||||
|
|
||||||
signal.signal(signal.SIGINT, signal_handler)
|
signal.signal(signal.SIGINT, signal_handler)
|
||||||
|
|
||||||
app_inst = await make_app(loop)
|
app_inst = await make_app(loop)
|
||||||
ap = app_inst
|
|
||||||
await app_inst.run()
|
await app_inst.run()
|
||||||
except Exception as e:
|
except Exception:
|
||||||
traceback.print_exc()
|
traceback.print_exc()
|
||||||
|
|||||||
@@ -1,9 +1,7 @@
|
|||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
import json
|
|
||||||
|
|
||||||
from ...config import manager as config_mgr
|
from ...config import manager as config_mgr
|
||||||
from ...config.impls import pymodule
|
|
||||||
|
|
||||||
|
|
||||||
load_python_module_config = config_mgr.load_python_module_config
|
load_python_module_config = config_mgr.load_python_module_config
|
||||||
|
|||||||
+43
-38
@@ -5,39 +5,39 @@ from ...utils import pkgmgr
|
|||||||
# 检查依赖,防止用户未安装
|
# 检查依赖,防止用户未安装
|
||||||
# 左边为引入名称,右边为依赖名称
|
# 左边为引入名称,右边为依赖名称
|
||||||
required_deps = {
|
required_deps = {
|
||||||
"requests": "requests",
|
'requests': 'requests',
|
||||||
"openai": "openai",
|
'openai': 'openai',
|
||||||
"anthropic": "anthropic",
|
'anthropic': 'anthropic',
|
||||||
"colorlog": "colorlog",
|
'colorlog': 'colorlog',
|
||||||
"aiocqhttp": "aiocqhttp",
|
'aiocqhttp': 'aiocqhttp',
|
||||||
"botpy": "qq-botpy-rc",
|
'botpy': 'qq-botpy-rc',
|
||||||
"PIL": "pillow",
|
'PIL': 'pillow',
|
||||||
"nakuru": "nakuru-project-idk",
|
'nakuru': 'nakuru-project-idk',
|
||||||
"tiktoken": "tiktoken",
|
'tiktoken': 'tiktoken',
|
||||||
"yaml": "pyyaml",
|
'yaml': 'pyyaml',
|
||||||
"aiohttp": "aiohttp",
|
'aiohttp': 'aiohttp',
|
||||||
"psutil": "psutil",
|
'psutil': 'psutil',
|
||||||
"async_lru": "async-lru",
|
'async_lru': 'async-lru',
|
||||||
"ollama": "ollama",
|
'ollama': 'ollama',
|
||||||
"quart": "quart",
|
'quart': 'quart',
|
||||||
"quart_cors": "quart-cors",
|
'quart_cors': 'quart-cors',
|
||||||
"sqlalchemy": "sqlalchemy[asyncio]",
|
'sqlalchemy': 'sqlalchemy[asyncio]',
|
||||||
"aiosqlite": "aiosqlite",
|
'aiosqlite': 'aiosqlite',
|
||||||
"aiofiles": "aiofiles",
|
'aiofiles': 'aiofiles',
|
||||||
"aioshutil": "aioshutil",
|
'aioshutil': 'aioshutil',
|
||||||
"argon2": "argon2-cffi",
|
'argon2': 'argon2-cffi',
|
||||||
"jwt": "pyjwt",
|
'jwt': 'pyjwt',
|
||||||
"Crypto": "pycryptodome",
|
'Crypto': 'pycryptodome',
|
||||||
"lark_oapi": "lark-oapi",
|
'lark_oapi': 'lark-oapi',
|
||||||
"discord": "discord.py",
|
'discord': 'discord.py',
|
||||||
"cryptography": "cryptography",
|
'cryptography': 'cryptography',
|
||||||
"gewechat_client": "gewechat-client",
|
'gewechat_client': 'gewechat-client',
|
||||||
"dingtalk_stream": "dingtalk_stream",
|
'dingtalk_stream': 'dingtalk_stream',
|
||||||
"dashscope": "dashscope",
|
'dashscope': 'dashscope',
|
||||||
"telegram": "python-telegram-bot",
|
'telegram': 'python-telegram-bot',
|
||||||
"certifi": "certifi",
|
'certifi': 'certifi',
|
||||||
"mcp": "mcp",
|
'mcp': 'mcp',
|
||||||
"sqlmodel": "sqlmodel",
|
'sqlmodel': 'sqlmodel',
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
@@ -52,20 +52,25 @@ async def check_deps() -> list[str]:
|
|||||||
missing_deps.append(dep)
|
missing_deps.append(dep)
|
||||||
return missing_deps
|
return missing_deps
|
||||||
|
|
||||||
|
|
||||||
async def install_deps(deps: list[str]):
|
async def install_deps(deps: list[str]):
|
||||||
global required_deps
|
global required_deps
|
||||||
|
|
||||||
for dep in deps:
|
for dep in deps:
|
||||||
pip.main(["install", required_deps[dep]])
|
pip.main(['install', required_deps[dep]])
|
||||||
|
|
||||||
|
|
||||||
async def precheck_plugin_deps():
|
async def precheck_plugin_deps():
|
||||||
print('[Startup] Prechecking plugin dependencies...')
|
print('[Startup] Prechecking plugin dependencies...')
|
||||||
|
|
||||||
# 只有在plugins目录存在时才执行插件依赖安装
|
# 只有在plugins目录存在时才执行插件依赖安装
|
||||||
if os.path.exists("plugins"):
|
if os.path.exists('plugins'):
|
||||||
for dir in os.listdir("plugins"):
|
for dir in os.listdir('plugins'):
|
||||||
subdir = os.path.join("plugins", dir)
|
subdir = os.path.join('plugins', dir)
|
||||||
if not os.path.isdir(subdir):
|
if not os.path.isdir(subdir):
|
||||||
continue
|
continue
|
||||||
if 'requirements.txt' in os.listdir(subdir):
|
if 'requirements.txt' in os.listdir(subdir):
|
||||||
pkgmgr.install_requirements(os.path.join(subdir, 'requirements.txt'), extra_params=['-q', '-q', '-q'])
|
pkgmgr.install_requirements(
|
||||||
|
os.path.join(subdir, 'requirements.txt'),
|
||||||
|
extra_params=['-q', '-q', '-q'],
|
||||||
|
)
|
||||||
|
|||||||
@@ -2,23 +2,23 @@ from __future__ import annotations
|
|||||||
|
|
||||||
import os
|
import os
|
||||||
import shutil
|
import shutil
|
||||||
import sys
|
|
||||||
|
|
||||||
|
|
||||||
required_files = {
|
required_files = {
|
||||||
"plugins/__init__.py": "templates/__init__.py",
|
'plugins/__init__.py': 'templates/__init__.py',
|
||||||
"data/config.yaml": "templates/config.yaml",
|
'data/config.yaml': 'templates/config.yaml',
|
||||||
}
|
}
|
||||||
|
|
||||||
required_paths = [
|
required_paths = [
|
||||||
"temp",
|
'temp',
|
||||||
"data",
|
'data',
|
||||||
"data/metadata",
|
'data/metadata',
|
||||||
"data/logs",
|
'data/logs',
|
||||||
"data/labels",
|
'data/labels',
|
||||||
"plugins"
|
'plugins',
|
||||||
]
|
]
|
||||||
|
|
||||||
|
|
||||||
async def generate_files() -> list[str]:
|
async def generate_files() -> list[str]:
|
||||||
global required_files, required_paths
|
global required_files, required_paths
|
||||||
|
|
||||||
|
|||||||
+20
-16
@@ -1,5 +1,4 @@
|
|||||||
import logging
|
import logging
|
||||||
import os
|
|
||||||
import sys
|
import sys
|
||||||
import time
|
import time
|
||||||
|
|
||||||
@@ -9,11 +8,11 @@ from ...utils import constants
|
|||||||
|
|
||||||
|
|
||||||
log_colors_config = {
|
log_colors_config = {
|
||||||
"DEBUG": "green", # cyan white
|
'DEBUG': 'green', # cyan white
|
||||||
"INFO": "white",
|
'INFO': 'white',
|
||||||
"WARNING": "yellow",
|
'WARNING': 'yellow',
|
||||||
"ERROR": "red",
|
'ERROR': 'red',
|
||||||
"CRITICAL": "cyan",
|
'CRITICAL': 'cyan',
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
@@ -27,26 +26,31 @@ async def init_logging(extra_handlers: list[logging.Handler] = None) -> logging.
|
|||||||
if constants.debug_mode:
|
if constants.debug_mode:
|
||||||
level = logging.DEBUG
|
level = logging.DEBUG
|
||||||
|
|
||||||
log_file_name = "data/logs/langbot-%s.log" % time.strftime(
|
log_file_name = 'data/logs/langbot-%s.log' % time.strftime(
|
||||||
"%Y-%m-%d", time.localtime()
|
'%Y-%m-%d', time.localtime()
|
||||||
)
|
)
|
||||||
|
|
||||||
qcg_logger = logging.getLogger("langbot")
|
qcg_logger = logging.getLogger('langbot')
|
||||||
|
|
||||||
qcg_logger.setLevel(level)
|
qcg_logger.setLevel(level)
|
||||||
|
|
||||||
color_formatter = colorlog.ColoredFormatter(
|
color_formatter = colorlog.ColoredFormatter(
|
||||||
fmt="%(log_color)s[%(asctime)s.%(msecs)03d] %(filename)s (%(lineno)d) - [%(levelname)s] : %(message)s",
|
fmt='%(log_color)s[%(asctime)s.%(msecs)03d] %(filename)s (%(lineno)d) - [%(levelname)s] : %(message)s',
|
||||||
datefmt="%m-%d %H:%M:%S",
|
datefmt='%m-%d %H:%M:%S',
|
||||||
log_colors=log_colors_config,
|
log_colors=log_colors_config,
|
||||||
)
|
)
|
||||||
|
|
||||||
stream_handler = logging.StreamHandler(sys.stdout)
|
stream_handler = logging.StreamHandler(sys.stdout)
|
||||||
# stream_handler.setLevel(level)
|
# stream_handler.setLevel(level)
|
||||||
# stream_handler.setFormatter(color_formatter)
|
# stream_handler.setFormatter(color_formatter)
|
||||||
stream_handler.stream = open(sys.stdout.fileno(), mode='w', encoding='utf-8', buffering=1)
|
stream_handler.stream = open(
|
||||||
|
sys.stdout.fileno(), mode='w', encoding='utf-8', buffering=1
|
||||||
|
)
|
||||||
|
|
||||||
log_handlers: list[logging.Handler] = [stream_handler, logging.FileHandler(log_file_name, encoding='utf-8')]
|
log_handlers: list[logging.Handler] = [
|
||||||
|
stream_handler,
|
||||||
|
logging.FileHandler(log_file_name, encoding='utf-8'),
|
||||||
|
]
|
||||||
log_handlers += extra_handlers if extra_handlers is not None else []
|
log_handlers += extra_handlers if extra_handlers is not None else []
|
||||||
|
|
||||||
for handler in log_handlers:
|
for handler in log_handlers:
|
||||||
@@ -54,13 +58,13 @@ async def init_logging(extra_handlers: list[logging.Handler] = None) -> logging.
|
|||||||
handler.setFormatter(color_formatter)
|
handler.setFormatter(color_formatter)
|
||||||
qcg_logger.addHandler(handler)
|
qcg_logger.addHandler(handler)
|
||||||
|
|
||||||
qcg_logger.debug("日志初始化完成,日志级别:%s" % level)
|
qcg_logger.debug('日志初始化完成,日志级别:%s' % level)
|
||||||
logging.basicConfig(
|
logging.basicConfig(
|
||||||
level=logging.CRITICAL, # 设置日志输出格式
|
level=logging.CRITICAL, # 设置日志输出格式
|
||||||
format="[DEPR][%(asctime)s.%(msecs)03d] %(pathname)s (%(lineno)d) - [%(levelname)s] :\n%(message)s",
|
format='[DEPR][%(asctime)s.%(msecs)03d] %(pathname)s (%(lineno)d) - [%(levelname)s] :\n%(message)s',
|
||||||
# 日志输出的格式
|
# 日志输出的格式
|
||||||
# -8表示占位符,让输出左对齐,输出长度都为8位
|
# -8表示占位符,让输出左对齐,输出长度都为8位
|
||||||
datefmt="%Y-%m-%d %H:%M:%S", # 时间输出的格式
|
datefmt='%Y-%m-%d %H:%M:%S', # 时间输出的格式
|
||||||
handlers=[logging.NullHandler()],
|
handlers=[logging.NullHandler()],
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|||||||
+26
-15
@@ -8,21 +8,18 @@ import asyncio
|
|||||||
import pydantic.v1 as pydantic
|
import pydantic.v1 as pydantic
|
||||||
|
|
||||||
from ..provider import entities as llm_entities
|
from ..provider import entities as llm_entities
|
||||||
from ..provider.modelmgr import entities, modelmgr, requester
|
from ..provider.modelmgr import requester
|
||||||
from ..provider.tools import entities as tools_entities
|
from ..provider.tools import entities as tools_entities
|
||||||
from ..platform import adapter as msadapter
|
from ..platform import adapter as msadapter
|
||||||
from ..platform.types import message as platform_message
|
from ..platform.types import message as platform_message
|
||||||
from ..platform.types import events as platform_events
|
from ..platform.types import events as platform_events
|
||||||
from ..platform.types import entities as platform_entities
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
class LifecycleControlScope(enum.Enum):
|
class LifecycleControlScope(enum.Enum):
|
||||||
|
APPLICATION = 'application'
|
||||||
APPLICATION = "application"
|
PLATFORM = 'platform'
|
||||||
PLATFORM = "platform"
|
PLUGIN = 'plugin'
|
||||||
PLUGIN = "plugin"
|
PROVIDER = 'provider'
|
||||||
PROVIDER = "provider"
|
|
||||||
|
|
||||||
|
|
||||||
class LauncherTypes(enum.Enum):
|
class LauncherTypes(enum.Enum):
|
||||||
@@ -89,14 +86,17 @@ class Query(pydantic.BaseModel):
|
|||||||
use_funcs: typing.Optional[list[tools_entities.LLMFunction]] = None
|
use_funcs: typing.Optional[list[tools_entities.LLMFunction]] = None
|
||||||
"""使用的函数,由前置处理器阶段设置"""
|
"""使用的函数,由前置处理器阶段设置"""
|
||||||
|
|
||||||
resp_messages: typing.Optional[list[llm_entities.Message]] | typing.Optional[list[platform_message.MessageChain]] = []
|
resp_messages: (
|
||||||
|
typing.Optional[list[llm_entities.Message]]
|
||||||
|
| typing.Optional[list[platform_message.MessageChain]]
|
||||||
|
) = []
|
||||||
"""由Process阶段生成的回复消息对象列表"""
|
"""由Process阶段生成的回复消息对象列表"""
|
||||||
|
|
||||||
resp_message_chain: typing.Optional[list[platform_message.MessageChain]] = None
|
resp_message_chain: typing.Optional[list[platform_message.MessageChain]] = None
|
||||||
"""回复消息链,从resp_messages包装而得"""
|
"""回复消息链,从resp_messages包装而得"""
|
||||||
|
|
||||||
# ======= 内部保留 =======
|
# ======= 内部保留 =======
|
||||||
current_stage: "pkg.pipeline.pipelinemgr.StageInstContainer" = None
|
current_stage = None # pkg.pipeline.pipelinemgr.StageInstContainer
|
||||||
"""当前所处阶段"""
|
"""当前所处阶段"""
|
||||||
|
|
||||||
class Config:
|
class Config:
|
||||||
@@ -130,9 +130,13 @@ class Conversation(pydantic.BaseModel):
|
|||||||
|
|
||||||
messages: list[llm_entities.Message]
|
messages: list[llm_entities.Message]
|
||||||
|
|
||||||
create_time: typing.Optional[datetime.datetime] = pydantic.Field(default_factory=datetime.datetime.now)
|
create_time: typing.Optional[datetime.datetime] = pydantic.Field(
|
||||||
|
default_factory=datetime.datetime.now
|
||||||
|
)
|
||||||
|
|
||||||
update_time: typing.Optional[datetime.datetime] = pydantic.Field(default_factory=datetime.datetime.now)
|
update_time: typing.Optional[datetime.datetime] = pydantic.Field(
|
||||||
|
default_factory=datetime.datetime.now
|
||||||
|
)
|
||||||
|
|
||||||
use_llm_model: requester.RuntimeLLMModel
|
use_llm_model: requester.RuntimeLLMModel
|
||||||
|
|
||||||
@@ -147,6 +151,7 @@ class Conversation(pydantic.BaseModel):
|
|||||||
|
|
||||||
class Session(pydantic.BaseModel):
|
class Session(pydantic.BaseModel):
|
||||||
"""会话,一个 Session 对应一个 {launcher_type.value}_{launcher_id}"""
|
"""会话,一个 Session 对应一个 {launcher_type.value}_{launcher_id}"""
|
||||||
|
|
||||||
launcher_type: LauncherTypes
|
launcher_type: LauncherTypes
|
||||||
|
|
||||||
launcher_id: typing.Union[int, str]
|
launcher_id: typing.Union[int, str]
|
||||||
@@ -157,11 +162,17 @@ class Session(pydantic.BaseModel):
|
|||||||
|
|
||||||
using_conversation: typing.Optional[Conversation] = None
|
using_conversation: typing.Optional[Conversation] = None
|
||||||
|
|
||||||
conversations: typing.Optional[list[Conversation]] = pydantic.Field(default_factory=list)
|
conversations: typing.Optional[list[Conversation]] = pydantic.Field(
|
||||||
|
default_factory=list
|
||||||
|
)
|
||||||
|
|
||||||
create_time: typing.Optional[datetime.datetime] = pydantic.Field(default_factory=datetime.datetime.now)
|
create_time: typing.Optional[datetime.datetime] = pydantic.Field(
|
||||||
|
default_factory=datetime.datetime.now
|
||||||
|
)
|
||||||
|
|
||||||
update_time: typing.Optional[datetime.datetime] = pydantic.Field(default_factory=datetime.datetime.now)
|
update_time: typing.Optional[datetime.datetime] = pydantic.Field(
|
||||||
|
default_factory=datetime.datetime.now
|
||||||
|
)
|
||||||
|
|
||||||
semaphore: typing.Optional[asyncio.Semaphore] = None
|
semaphore: typing.Optional[asyncio.Semaphore] = None
|
||||||
"""当前会话的信号量,用于限制并发"""
|
"""当前会话的信号量,用于限制并发"""
|
||||||
|
|||||||
@@ -9,9 +9,10 @@ from . import app
|
|||||||
preregistered_migrations: list[typing.Type[Migration]] = []
|
preregistered_migrations: list[typing.Type[Migration]] = []
|
||||||
"""当前阶段暂不支持扩展"""
|
"""当前阶段暂不支持扩展"""
|
||||||
|
|
||||||
|
|
||||||
def migration_class(name: str, number: int):
|
def migration_class(name: str, number: int):
|
||||||
"""注册一个迁移
|
"""注册一个迁移"""
|
||||||
"""
|
|
||||||
def decorator(cls: typing.Type[Migration]) -> typing.Type[Migration]:
|
def decorator(cls: typing.Type[Migration]) -> typing.Type[Migration]:
|
||||||
cls.name = name
|
cls.name = name
|
||||||
cls.number = number
|
cls.number = number
|
||||||
@@ -22,8 +23,7 @@ def migration_class(name: str, number: int):
|
|||||||
|
|
||||||
|
|
||||||
class Migration(abc.ABC):
|
class Migration(abc.ABC):
|
||||||
"""一个版本的迁移
|
"""一个版本的迁移"""
|
||||||
"""
|
|
||||||
|
|
||||||
name: str
|
name: str
|
||||||
|
|
||||||
@@ -36,12 +36,10 @@ class Migration(abc.ABC):
|
|||||||
|
|
||||||
@abc.abstractmethod
|
@abc.abstractmethod
|
||||||
async def need_migrate(self) -> bool:
|
async def need_migrate(self) -> bool:
|
||||||
"""判断当前环境是否需要运行此迁移
|
"""判断当前环境是否需要运行此迁移"""
|
||||||
"""
|
|
||||||
pass
|
pass
|
||||||
|
|
||||||
@abc.abstractmethod
|
@abc.abstractmethod
|
||||||
async def run(self):
|
async def run(self):
|
||||||
"""执行迁移
|
"""执行迁移"""
|
||||||
"""
|
|
||||||
pass
|
pass
|
||||||
|
|||||||
@@ -1,26 +1,26 @@
|
|||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
import os
|
import os
|
||||||
import sys
|
|
||||||
|
|
||||||
from .. import migration
|
from .. import migration
|
||||||
|
|
||||||
|
|
||||||
@migration.migration_class("sensitive-word-migration", 1)
|
@migration.migration_class('sensitive-word-migration', 1)
|
||||||
class SensitiveWordMigration(migration.Migration):
|
class SensitiveWordMigration(migration.Migration):
|
||||||
"""敏感词迁移
|
"""敏感词迁移"""
|
||||||
"""
|
|
||||||
|
|
||||||
async def need_migrate(self) -> bool:
|
async def need_migrate(self) -> bool:
|
||||||
"""判断当前环境是否需要运行此迁移
|
"""判断当前环境是否需要运行此迁移"""
|
||||||
"""
|
return os.path.exists(
|
||||||
return os.path.exists("data/config/sensitive-words.json") and not os.path.exists("data/metadata/sensitive-words.json")
|
'data/config/sensitive-words.json'
|
||||||
|
) and not os.path.exists('data/metadata/sensitive-words.json')
|
||||||
|
|
||||||
async def run(self):
|
async def run(self):
|
||||||
"""执行迁移
|
"""执行迁移"""
|
||||||
"""
|
|
||||||
# 移动文件
|
# 移动文件
|
||||||
os.rename("data/config/sensitive-words.json", "data/metadata/sensitive-words.json")
|
os.rename(
|
||||||
|
'data/config/sensitive-words.json', 'data/metadata/sensitive-words.json'
|
||||||
|
)
|
||||||
|
|
||||||
# 重新加载配置
|
# 重新加载配置
|
||||||
await self.ap.sensitive_meta.load_config()
|
await self.ap.sensitive_meta.load_config()
|
||||||
|
|||||||
@@ -3,19 +3,16 @@ from __future__ import annotations
|
|||||||
from .. import migration
|
from .. import migration
|
||||||
|
|
||||||
|
|
||||||
@migration.migration_class("openai-config-migration", 2)
|
@migration.migration_class('openai-config-migration', 2)
|
||||||
class OpenAIConfigMigration(migration.Migration):
|
class OpenAIConfigMigration(migration.Migration):
|
||||||
"""OpenAI配置迁移
|
"""OpenAI配置迁移"""
|
||||||
"""
|
|
||||||
|
|
||||||
async def need_migrate(self) -> bool:
|
async def need_migrate(self) -> bool:
|
||||||
"""判断当前环境是否需要运行此迁移
|
"""判断当前环境是否需要运行此迁移"""
|
||||||
"""
|
|
||||||
return 'openai-config' in self.ap.provider_cfg.data
|
return 'openai-config' in self.ap.provider_cfg.data
|
||||||
|
|
||||||
async def run(self):
|
async def run(self):
|
||||||
"""执行迁移
|
"""执行迁移"""
|
||||||
"""
|
|
||||||
old_openai_config = self.ap.provider_cfg.data['openai-config'].copy()
|
old_openai_config = self.ap.provider_cfg.data['openai-config'].copy()
|
||||||
|
|
||||||
if 'keys' not in self.ap.provider_cfg.data:
|
if 'keys' not in self.ap.provider_cfg.data:
|
||||||
@@ -26,7 +23,9 @@ class OpenAIConfigMigration(migration.Migration):
|
|||||||
|
|
||||||
self.ap.provider_cfg.data['keys']['openai'] = old_openai_config['api-keys']
|
self.ap.provider_cfg.data['keys']['openai'] = old_openai_config['api-keys']
|
||||||
|
|
||||||
self.ap.provider_cfg.data['model'] = old_openai_config['chat-completions-params']['model']
|
self.ap.provider_cfg.data['model'] = old_openai_config[
|
||||||
|
'chat-completions-params'
|
||||||
|
]['model']
|
||||||
|
|
||||||
del old_openai_config['chat-completions-params']['model']
|
del old_openai_config['chat-completions-params']['model']
|
||||||
|
|
||||||
|
|||||||
@@ -3,26 +3,23 @@ from __future__ import annotations
|
|||||||
from .. import migration
|
from .. import migration
|
||||||
|
|
||||||
|
|
||||||
@migration.migration_class("anthropic-requester-config-completion", 3)
|
@migration.migration_class('anthropic-requester-config-completion', 3)
|
||||||
class AnthropicRequesterConfigCompletionMigration(migration.Migration):
|
class AnthropicRequesterConfigCompletionMigration(migration.Migration):
|
||||||
"""OpenAI配置迁移
|
"""OpenAI配置迁移"""
|
||||||
"""
|
|
||||||
|
|
||||||
async def need_migrate(self) -> bool:
|
async def need_migrate(self) -> bool:
|
||||||
"""判断当前环境是否需要运行此迁移
|
"""判断当前环境是否需要运行此迁移"""
|
||||||
"""
|
return (
|
||||||
return 'anthropic-messages' not in self.ap.provider_cfg.data['requester'] \
|
'anthropic-messages' not in self.ap.provider_cfg.data['requester']
|
||||||
or 'anthropic' not in self.ap.provider_cfg.data['keys']
|
or 'anthropic' not in self.ap.provider_cfg.data['keys']
|
||||||
|
)
|
||||||
|
|
||||||
async def run(self):
|
async def run(self):
|
||||||
"""执行迁移
|
"""执行迁移"""
|
||||||
"""
|
|
||||||
if 'anthropic-messages' not in self.ap.provider_cfg.data['requester']:
|
if 'anthropic-messages' not in self.ap.provider_cfg.data['requester']:
|
||||||
self.ap.provider_cfg.data['requester']['anthropic-messages'] = {
|
self.ap.provider_cfg.data['requester']['anthropic-messages'] = {
|
||||||
'base-url': 'https://api.anthropic.com',
|
'base-url': 'https://api.anthropic.com',
|
||||||
'args': {
|
'args': {'max_tokens': 1024},
|
||||||
'max_tokens': 1024
|
|
||||||
},
|
|
||||||
'timeout': 120,
|
'timeout': 120,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -3,20 +3,19 @@ from __future__ import annotations
|
|||||||
from .. import migration
|
from .. import migration
|
||||||
|
|
||||||
|
|
||||||
@migration.migration_class("moonshot-config-completion", 4)
|
@migration.migration_class('moonshot-config-completion', 4)
|
||||||
class MoonshotConfigCompletionMigration(migration.Migration):
|
class MoonshotConfigCompletionMigration(migration.Migration):
|
||||||
"""OpenAI配置迁移
|
"""OpenAI配置迁移"""
|
||||||
"""
|
|
||||||
|
|
||||||
async def need_migrate(self) -> bool:
|
async def need_migrate(self) -> bool:
|
||||||
"""判断当前环境是否需要运行此迁移
|
"""判断当前环境是否需要运行此迁移"""
|
||||||
"""
|
return (
|
||||||
return 'moonshot-chat-completions' not in self.ap.provider_cfg.data['requester'] \
|
'moonshot-chat-completions' not in self.ap.provider_cfg.data['requester']
|
||||||
or 'moonshot' not in self.ap.provider_cfg.data['keys']
|
or 'moonshot' not in self.ap.provider_cfg.data['keys']
|
||||||
|
)
|
||||||
|
|
||||||
async def run(self):
|
async def run(self):
|
||||||
"""执行迁移
|
"""执行迁移"""
|
||||||
"""
|
|
||||||
if 'moonshot-chat-completions' not in self.ap.provider_cfg.data['requester']:
|
if 'moonshot-chat-completions' not in self.ap.provider_cfg.data['requester']:
|
||||||
self.ap.provider_cfg.data['requester']['moonshot-chat-completions'] = {
|
self.ap.provider_cfg.data['requester']['moonshot-chat-completions'] = {
|
||||||
'base-url': 'https://api.moonshot.cn/v1',
|
'base-url': 'https://api.moonshot.cn/v1',
|
||||||
|
|||||||
@@ -3,20 +3,19 @@ from __future__ import annotations
|
|||||||
from .. import migration
|
from .. import migration
|
||||||
|
|
||||||
|
|
||||||
@migration.migration_class("deepseek-config-completion", 5)
|
@migration.migration_class('deepseek-config-completion', 5)
|
||||||
class DeepseekConfigCompletionMigration(migration.Migration):
|
class DeepseekConfigCompletionMigration(migration.Migration):
|
||||||
"""OpenAI配置迁移
|
"""OpenAI配置迁移"""
|
||||||
"""
|
|
||||||
|
|
||||||
async def need_migrate(self) -> bool:
|
async def need_migrate(self) -> bool:
|
||||||
"""判断当前环境是否需要运行此迁移
|
"""判断当前环境是否需要运行此迁移"""
|
||||||
"""
|
return (
|
||||||
return 'deepseek-chat-completions' not in self.ap.provider_cfg.data['requester'] \
|
'deepseek-chat-completions' not in self.ap.provider_cfg.data['requester']
|
||||||
or 'deepseek' not in self.ap.provider_cfg.data['keys']
|
or 'deepseek' not in self.ap.provider_cfg.data['keys']
|
||||||
|
)
|
||||||
|
|
||||||
async def run(self):
|
async def run(self):
|
||||||
"""执行迁移
|
"""执行迁移"""
|
||||||
"""
|
|
||||||
if 'deepseek-chat-completions' not in self.ap.provider_cfg.data['requester']:
|
if 'deepseek-chat-completions' not in self.ap.provider_cfg.data['requester']:
|
||||||
self.ap.provider_cfg.data['requester']['deepseek-chat-completions'] = {
|
self.ap.provider_cfg.data['requester']['deepseek-chat-completions'] = {
|
||||||
'base-url': 'https://api.deepseek.com',
|
'base-url': 'https://api.deepseek.com',
|
||||||
|
|||||||
@@ -3,17 +3,17 @@ from __future__ import annotations
|
|||||||
from .. import migration
|
from .. import migration
|
||||||
|
|
||||||
|
|
||||||
@migration.migration_class("vision-config", 6)
|
@migration.migration_class('vision-config', 6)
|
||||||
class VisionConfigMigration(migration.Migration):
|
class VisionConfigMigration(migration.Migration):
|
||||||
"""迁移"""
|
"""迁移"""
|
||||||
|
|
||||||
async def need_migrate(self) -> bool:
|
async def need_migrate(self) -> bool:
|
||||||
"""判断当前环境是否需要运行此迁移"""
|
"""判断当前环境是否需要运行此迁移"""
|
||||||
return "enable-vision" not in self.ap.provider_cfg.data
|
return 'enable-vision' not in self.ap.provider_cfg.data
|
||||||
|
|
||||||
async def run(self):
|
async def run(self):
|
||||||
"""执行迁移"""
|
"""执行迁移"""
|
||||||
if "enable-vision" not in self.ap.provider_cfg.data:
|
if 'enable-vision' not in self.ap.provider_cfg.data:
|
||||||
self.ap.provider_cfg.data["enable-vision"] = False
|
self.ap.provider_cfg.data['enable-vision'] = False
|
||||||
|
|
||||||
await self.ap.provider_cfg.dump_config()
|
await self.ap.provider_cfg.dump_config()
|
||||||
|
|||||||
@@ -3,18 +3,20 @@ from __future__ import annotations
|
|||||||
from .. import migration
|
from .. import migration
|
||||||
|
|
||||||
|
|
||||||
@migration.migration_class("qcg-center-url-config", 7)
|
@migration.migration_class('qcg-center-url-config', 7)
|
||||||
class QCGCenterURLConfigMigration(migration.Migration):
|
class QCGCenterURLConfigMigration(migration.Migration):
|
||||||
"""迁移"""
|
"""迁移"""
|
||||||
|
|
||||||
async def need_migrate(self) -> bool:
|
async def need_migrate(self) -> bool:
|
||||||
"""判断当前环境是否需要运行此迁移"""
|
"""判断当前环境是否需要运行此迁移"""
|
||||||
return "qcg-center-url" not in self.ap.system_cfg.data
|
return 'qcg-center-url' not in self.ap.system_cfg.data
|
||||||
|
|
||||||
async def run(self):
|
async def run(self):
|
||||||
"""执行迁移"""
|
"""执行迁移"""
|
||||||
|
|
||||||
if "qcg-center-url" not in self.ap.system_cfg.data:
|
if 'qcg-center-url' not in self.ap.system_cfg.data:
|
||||||
self.ap.system_cfg.data["qcg-center-url"] = "https://api.qchatgpt.rockchin.top/api/v2"
|
self.ap.system_cfg.data['qcg-center-url'] = (
|
||||||
|
'https://api.qchatgpt.rockchin.top/api/v2'
|
||||||
|
)
|
||||||
|
|
||||||
await self.ap.system_cfg.dump_config()
|
await self.ap.system_cfg.dump_config()
|
||||||
|
|||||||
@@ -3,27 +3,27 @@ from __future__ import annotations
|
|||||||
from .. import migration
|
from .. import migration
|
||||||
|
|
||||||
|
|
||||||
@migration.migration_class("ad-fixwin-cfg-migration", 8)
|
@migration.migration_class('ad-fixwin-cfg-migration', 8)
|
||||||
class AdFixwinConfigMigration(migration.Migration):
|
class AdFixwinConfigMigration(migration.Migration):
|
||||||
"""迁移"""
|
"""迁移"""
|
||||||
|
|
||||||
async def need_migrate(self) -> bool:
|
async def need_migrate(self) -> bool:
|
||||||
"""判断当前环境是否需要运行此迁移"""
|
"""判断当前环境是否需要运行此迁移"""
|
||||||
return isinstance(
|
return isinstance(
|
||||||
self.ap.pipeline_cfg.data["rate-limit"]["fixwin"]["default"],
|
self.ap.pipeline_cfg.data['rate-limit']['fixwin']['default'], int
|
||||||
int
|
|
||||||
)
|
)
|
||||||
|
|
||||||
async def run(self):
|
async def run(self):
|
||||||
"""执行迁移"""
|
"""执行迁移"""
|
||||||
|
|
||||||
for session_name in self.ap.pipeline_cfg.data["rate-limit"]["fixwin"]:
|
for session_name in self.ap.pipeline_cfg.data['rate-limit']['fixwin']:
|
||||||
|
|
||||||
temp_dict = {
|
temp_dict = {
|
||||||
"window-size": 60,
|
'window-size': 60,
|
||||||
"limit": self.ap.pipeline_cfg.data["rate-limit"]["fixwin"][session_name]
|
'limit': self.ap.pipeline_cfg.data['rate-limit']['fixwin'][
|
||||||
|
session_name
|
||||||
|
],
|
||||||
}
|
}
|
||||||
|
|
||||||
self.ap.pipeline_cfg.data["rate-limit"]["fixwin"][session_name] = temp_dict
|
self.ap.pipeline_cfg.data['rate-limit']['fixwin'][session_name] = temp_dict
|
||||||
|
|
||||||
await self.ap.pipeline_cfg.dump_config()
|
await self.ap.pipeline_cfg.dump_config()
|
||||||
@@ -3,7 +3,7 @@ from __future__ import annotations
|
|||||||
from .. import migration
|
from .. import migration
|
||||||
|
|
||||||
|
|
||||||
@migration.migration_class("msg-truncator-cfg-migration", 9)
|
@migration.migration_class('msg-truncator-cfg-migration', 9)
|
||||||
class MsgTruncatorConfigMigration(migration.Migration):
|
class MsgTruncatorConfigMigration(migration.Migration):
|
||||||
"""迁移"""
|
"""迁移"""
|
||||||
|
|
||||||
@@ -16,9 +16,7 @@ class MsgTruncatorConfigMigration(migration.Migration):
|
|||||||
|
|
||||||
self.ap.pipeline_cfg.data['msg-truncate'] = {
|
self.ap.pipeline_cfg.data['msg-truncate'] = {
|
||||||
'method': 'round',
|
'method': 'round',
|
||||||
'round': {
|
'round': {'max-round': 10},
|
||||||
'max-round': 10
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
await self.ap.pipeline_cfg.dump_config()
|
await self.ap.pipeline_cfg.dump_config()
|
||||||
|
|||||||
@@ -3,7 +3,7 @@ from __future__ import annotations
|
|||||||
from .. import migration
|
from .. import migration
|
||||||
|
|
||||||
|
|
||||||
@migration.migration_class("ollama-requester-config", 10)
|
@migration.migration_class('ollama-requester-config', 10)
|
||||||
class MsgTruncatorConfigMigration(migration.Migration):
|
class MsgTruncatorConfigMigration(migration.Migration):
|
||||||
"""迁移"""
|
"""迁移"""
|
||||||
|
|
||||||
@@ -15,9 +15,9 @@ class MsgTruncatorConfigMigration(migration.Migration):
|
|||||||
"""执行迁移"""
|
"""执行迁移"""
|
||||||
|
|
||||||
self.ap.provider_cfg.data['requester']['ollama-chat'] = {
|
self.ap.provider_cfg.data['requester']['ollama-chat'] = {
|
||||||
"base-url": "http://127.0.0.1:11434",
|
'base-url': 'http://127.0.0.1:11434',
|
||||||
"args": {},
|
'args': {},
|
||||||
"timeout": 600
|
'timeout': 600,
|
||||||
}
|
}
|
||||||
|
|
||||||
await self.ap.provider_cfg.dump_config()
|
await self.ap.provider_cfg.dump_config()
|
||||||
|
|||||||
@@ -3,7 +3,7 @@ from __future__ import annotations
|
|||||||
from .. import migration
|
from .. import migration
|
||||||
|
|
||||||
|
|
||||||
@migration.migration_class("command-prefix-config", 11)
|
@migration.migration_class('command-prefix-config', 11)
|
||||||
class CommandPrefixConfigMigration(migration.Migration):
|
class CommandPrefixConfigMigration(migration.Migration):
|
||||||
"""迁移"""
|
"""迁移"""
|
||||||
|
|
||||||
@@ -14,8 +14,6 @@ class CommandPrefixConfigMigration(migration.Migration):
|
|||||||
async def run(self):
|
async def run(self):
|
||||||
"""执行迁移"""
|
"""执行迁移"""
|
||||||
|
|
||||||
self.ap.command_cfg.data['command-prefix'] = [
|
self.ap.command_cfg.data['command-prefix'] = ['!', '!']
|
||||||
"!", "!"
|
|
||||||
]
|
|
||||||
|
|
||||||
await self.ap.command_cfg.dump_config()
|
await self.ap.command_cfg.dump_config()
|
||||||
|
|||||||
@@ -3,7 +3,7 @@ from __future__ import annotations
|
|||||||
from .. import migration
|
from .. import migration
|
||||||
|
|
||||||
|
|
||||||
@migration.migration_class("runner-config", 12)
|
@migration.migration_class('runner-config', 12)
|
||||||
class RunnerConfigMigration(migration.Migration):
|
class RunnerConfigMigration(migration.Migration):
|
||||||
"""迁移"""
|
"""迁移"""
|
||||||
|
|
||||||
|
|||||||
@@ -3,29 +3,30 @@ from __future__ import annotations
|
|||||||
from .. import migration
|
from .. import migration
|
||||||
|
|
||||||
|
|
||||||
@migration.migration_class("http-api-config", 13)
|
@migration.migration_class('http-api-config', 13)
|
||||||
class HttpApiConfigMigration(migration.Migration):
|
class HttpApiConfigMigration(migration.Migration):
|
||||||
"""迁移"""
|
"""迁移"""
|
||||||
|
|
||||||
async def need_migrate(self) -> bool:
|
async def need_migrate(self) -> bool:
|
||||||
"""判断当前环境是否需要运行此迁移"""
|
"""判断当前环境是否需要运行此迁移"""
|
||||||
return 'http-api' not in self.ap.system_cfg.data or "persistence" not in self.ap.system_cfg.data
|
return (
|
||||||
|
'http-api' not in self.ap.system_cfg.data
|
||||||
|
or 'persistence' not in self.ap.system_cfg.data
|
||||||
|
)
|
||||||
|
|
||||||
async def run(self):
|
async def run(self):
|
||||||
"""执行迁移"""
|
"""执行迁移"""
|
||||||
|
|
||||||
self.ap.system_cfg.data['http-api'] = {
|
self.ap.system_cfg.data['http-api'] = {
|
||||||
"enable": True,
|
'enable': True,
|
||||||
"host": "0.0.0.0",
|
'host': '0.0.0.0',
|
||||||
"port": 5300,
|
'port': 5300,
|
||||||
"jwt-expire": 604800
|
'jwt-expire': 604800,
|
||||||
}
|
}
|
||||||
|
|
||||||
self.ap.system_cfg.data['persistence'] = {
|
self.ap.system_cfg.data['persistence'] = {
|
||||||
"sqlite": {
|
'sqlite': {'path': 'data/persistence.db'},
|
||||||
"path": "data/persistence.db"
|
'use': 'sqlite',
|
||||||
},
|
|
||||||
"use": "sqlite"
|
|
||||||
}
|
}
|
||||||
|
|
||||||
await self.ap.system_cfg.dump_config()
|
await self.ap.system_cfg.dump_config()
|
||||||
|
|||||||
@@ -3,20 +3,20 @@ from __future__ import annotations
|
|||||||
from .. import migration
|
from .. import migration
|
||||||
|
|
||||||
|
|
||||||
@migration.migration_class("force-delay-config", 14)
|
@migration.migration_class('force-delay-config', 14)
|
||||||
class ForceDelayConfigMigration(migration.Migration):
|
class ForceDelayConfigMigration(migration.Migration):
|
||||||
"""迁移"""
|
"""迁移"""
|
||||||
|
|
||||||
async def need_migrate(self) -> bool:
|
async def need_migrate(self) -> bool:
|
||||||
"""判断当前环境是否需要运行此迁移"""
|
"""判断当前环境是否需要运行此迁移"""
|
||||||
return type(self.ap.platform_cfg.data['force-delay']) == list
|
return isinstance(self.ap.platform_cfg.data['force-delay'], list)
|
||||||
|
|
||||||
async def run(self):
|
async def run(self):
|
||||||
"""执行迁移"""
|
"""执行迁移"""
|
||||||
|
|
||||||
self.ap.platform_cfg.data['force-delay'] = {
|
self.ap.platform_cfg.data['force-delay'] = {
|
||||||
"min": self.ap.platform_cfg.data['force-delay'][0],
|
'min': self.ap.platform_cfg.data['force-delay'][0],
|
||||||
"max": self.ap.platform_cfg.data['force-delay'][1]
|
'max': self.ap.platform_cfg.data['force-delay'][1],
|
||||||
}
|
}
|
||||||
|
|
||||||
await self.ap.platform_cfg.dump_config()
|
await self.ap.platform_cfg.dump_config()
|
||||||
|
|||||||
@@ -3,24 +3,25 @@ from __future__ import annotations
|
|||||||
from .. import migration
|
from .. import migration
|
||||||
|
|
||||||
|
|
||||||
@migration.migration_class("gitee-ai-config", 15)
|
@migration.migration_class('gitee-ai-config', 15)
|
||||||
class GiteeAIConfigMigration(migration.Migration):
|
class GiteeAIConfigMigration(migration.Migration):
|
||||||
"""迁移"""
|
"""迁移"""
|
||||||
|
|
||||||
async def need_migrate(self) -> bool:
|
async def need_migrate(self) -> bool:
|
||||||
"""判断当前环境是否需要运行此迁移"""
|
"""判断当前环境是否需要运行此迁移"""
|
||||||
return 'gitee-ai-chat-completions' not in self.ap.provider_cfg.data['requester'] or 'gitee-ai' not in self.ap.provider_cfg.data['keys']
|
return (
|
||||||
|
'gitee-ai-chat-completions' not in self.ap.provider_cfg.data['requester']
|
||||||
|
or 'gitee-ai' not in self.ap.provider_cfg.data['keys']
|
||||||
|
)
|
||||||
|
|
||||||
async def run(self):
|
async def run(self):
|
||||||
"""执行迁移"""
|
"""执行迁移"""
|
||||||
self.ap.provider_cfg.data['requester']['gitee-ai-chat-completions'] = {
|
self.ap.provider_cfg.data['requester']['gitee-ai-chat-completions'] = {
|
||||||
"base-url": "https://ai.gitee.com/v1",
|
'base-url': 'https://ai.gitee.com/v1',
|
||||||
"args": {},
|
'args': {},
|
||||||
"timeout": 120
|
'timeout': 120,
|
||||||
}
|
}
|
||||||
|
|
||||||
self.ap.provider_cfg.data['keys']['gitee-ai'] = [
|
self.ap.provider_cfg.data['keys']['gitee-ai'] = ['XXXXX']
|
||||||
"XXXXX"
|
|
||||||
]
|
|
||||||
|
|
||||||
await self.ap.provider_cfg.dump_config()
|
await self.ap.provider_cfg.dump_config()
|
||||||
|
|||||||
@@ -3,7 +3,7 @@ from __future__ import annotations
|
|||||||
from .. import migration
|
from .. import migration
|
||||||
|
|
||||||
|
|
||||||
@migration.migration_class("dify-service-api-config", 16)
|
@migration.migration_class('dify-service-api-config', 16)
|
||||||
class DifyServiceAPICfgMigration(migration.Migration):
|
class DifyServiceAPICfgMigration(migration.Migration):
|
||||||
"""迁移"""
|
"""迁移"""
|
||||||
|
|
||||||
@@ -14,15 +14,10 @@ class DifyServiceAPICfgMigration(migration.Migration):
|
|||||||
async def run(self):
|
async def run(self):
|
||||||
"""执行迁移"""
|
"""执行迁移"""
|
||||||
self.ap.provider_cfg.data['dify-service-api'] = {
|
self.ap.provider_cfg.data['dify-service-api'] = {
|
||||||
"base-url": "https://api.dify.ai/v1",
|
'base-url': 'https://api.dify.ai/v1',
|
||||||
"app-type": "chat",
|
'app-type': 'chat',
|
||||||
"chat": {
|
'chat': {'api-key': 'app-1234567890'},
|
||||||
"api-key": "app-1234567890"
|
'workflow': {'api-key': 'app-1234567890', 'output-key': 'summary'},
|
||||||
},
|
|
||||||
"workflow": {
|
|
||||||
"api-key": "app-1234567890",
|
|
||||||
"output-key": "summary"
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
await self.ap.provider_cfg.dump_config()
|
await self.ap.provider_cfg.dump_config()
|
||||||
|
|||||||
@@ -3,22 +3,26 @@ from __future__ import annotations
|
|||||||
from .. import migration
|
from .. import migration
|
||||||
|
|
||||||
|
|
||||||
@migration.migration_class("dify-api-timeout-params", 17)
|
@migration.migration_class('dify-api-timeout-params', 17)
|
||||||
class DifyAPITimeoutParamsMigration(migration.Migration):
|
class DifyAPITimeoutParamsMigration(migration.Migration):
|
||||||
"""迁移"""
|
"""迁移"""
|
||||||
|
|
||||||
async def need_migrate(self) -> bool:
|
async def need_migrate(self) -> bool:
|
||||||
"""判断当前环境是否需要运行此迁移"""
|
"""判断当前环境是否需要运行此迁移"""
|
||||||
return 'timeout' not in self.ap.provider_cfg.data['dify-service-api']['chat'] or 'timeout' not in self.ap.provider_cfg.data['dify-service-api']['workflow'] \
|
return (
|
||||||
|
'timeout' not in self.ap.provider_cfg.data['dify-service-api']['chat']
|
||||||
|
or 'timeout'
|
||||||
|
not in self.ap.provider_cfg.data['dify-service-api']['workflow']
|
||||||
or 'agent' not in self.ap.provider_cfg.data['dify-service-api']
|
or 'agent' not in self.ap.provider_cfg.data['dify-service-api']
|
||||||
|
)
|
||||||
|
|
||||||
async def run(self):
|
async def run(self):
|
||||||
"""执行迁移"""
|
"""执行迁移"""
|
||||||
self.ap.provider_cfg.data['dify-service-api']['chat']['timeout'] = 120
|
self.ap.provider_cfg.data['dify-service-api']['chat']['timeout'] = 120
|
||||||
self.ap.provider_cfg.data['dify-service-api']['workflow']['timeout'] = 120
|
self.ap.provider_cfg.data['dify-service-api']['workflow']['timeout'] = 120
|
||||||
self.ap.provider_cfg.data['dify-service-api']['agent'] = {
|
self.ap.provider_cfg.data['dify-service-api']['agent'] = {
|
||||||
"api-key": "app-1234567890",
|
'api-key': 'app-1234567890',
|
||||||
"timeout": 120
|
'timeout': 120,
|
||||||
}
|
}
|
||||||
|
|
||||||
await self.ap.provider_cfg.dump_config()
|
await self.ap.provider_cfg.dump_config()
|
||||||
|
|||||||
@@ -3,7 +3,7 @@ from __future__ import annotations
|
|||||||
from .. import migration
|
from .. import migration
|
||||||
|
|
||||||
|
|
||||||
@migration.migration_class("xai-config", 18)
|
@migration.migration_class('xai-config', 18)
|
||||||
class XaiConfigMigration(migration.Migration):
|
class XaiConfigMigration(migration.Migration):
|
||||||
"""迁移"""
|
"""迁移"""
|
||||||
|
|
||||||
@@ -14,12 +14,10 @@ class XaiConfigMigration(migration.Migration):
|
|||||||
async def run(self):
|
async def run(self):
|
||||||
"""执行迁移"""
|
"""执行迁移"""
|
||||||
self.ap.provider_cfg.data['requester']['xai-chat-completions'] = {
|
self.ap.provider_cfg.data['requester']['xai-chat-completions'] = {
|
||||||
"base-url": "https://api.x.ai/v1",
|
'base-url': 'https://api.x.ai/v1',
|
||||||
"args": {},
|
'args': {},
|
||||||
"timeout": 120
|
'timeout': 120,
|
||||||
}
|
}
|
||||||
self.ap.provider_cfg.data['keys']['xai'] = [
|
self.ap.provider_cfg.data['keys']['xai'] = ['xai-1234567890']
|
||||||
"xai-1234567890"
|
|
||||||
]
|
|
||||||
|
|
||||||
await self.ap.provider_cfg.dump_config()
|
await self.ap.provider_cfg.dump_config()
|
||||||
|
|||||||
@@ -3,7 +3,7 @@ from __future__ import annotations
|
|||||||
from .. import migration
|
from .. import migration
|
||||||
|
|
||||||
|
|
||||||
@migration.migration_class("zhipuai-config", 19)
|
@migration.migration_class('zhipuai-config', 19)
|
||||||
class ZhipuaiConfigMigration(migration.Migration):
|
class ZhipuaiConfigMigration(migration.Migration):
|
||||||
"""迁移"""
|
"""迁移"""
|
||||||
|
|
||||||
@@ -14,12 +14,10 @@ class ZhipuaiConfigMigration(migration.Migration):
|
|||||||
async def run(self):
|
async def run(self):
|
||||||
"""执行迁移"""
|
"""执行迁移"""
|
||||||
self.ap.provider_cfg.data['requester']['zhipuai-chat-completions'] = {
|
self.ap.provider_cfg.data['requester']['zhipuai-chat-completions'] = {
|
||||||
"base-url": "https://open.bigmodel.cn/api/paas/v4",
|
'base-url': 'https://open.bigmodel.cn/api/paas/v4',
|
||||||
"args": {},
|
'args': {},
|
||||||
"timeout": 120
|
'timeout': 120,
|
||||||
}
|
}
|
||||||
self.ap.provider_cfg.data['keys']['zhipuai'] = [
|
self.ap.provider_cfg.data['keys']['zhipuai'] = ['xxxxxxx']
|
||||||
"xxxxxxx"
|
|
||||||
]
|
|
||||||
|
|
||||||
await self.ap.provider_cfg.dump_config()
|
await self.ap.provider_cfg.dump_config()
|
||||||
|
|||||||
@@ -3,7 +3,7 @@ from __future__ import annotations
|
|||||||
from .. import migration
|
from .. import migration
|
||||||
|
|
||||||
|
|
||||||
@migration.migration_class("wecom-config", 20)
|
@migration.migration_class('wecom-config', 20)
|
||||||
class WecomConfigMigration(migration.Migration):
|
class WecomConfigMigration(migration.Migration):
|
||||||
"""迁移"""
|
"""迁移"""
|
||||||
|
|
||||||
@@ -19,16 +19,18 @@ class WecomConfigMigration(migration.Migration):
|
|||||||
|
|
||||||
async def run(self):
|
async def run(self):
|
||||||
"""执行迁移"""
|
"""执行迁移"""
|
||||||
self.ap.platform_cfg.data['platform-adapters'].append({
|
self.ap.platform_cfg.data['platform-adapters'].append(
|
||||||
"adapter": "wecom",
|
{
|
||||||
"enable": False,
|
'adapter': 'wecom',
|
||||||
"host": "0.0.0.0",
|
'enable': False,
|
||||||
"port": 2290,
|
'host': '0.0.0.0',
|
||||||
"corpid": "",
|
'port': 2290,
|
||||||
"secret": "",
|
'corpid': '',
|
||||||
"token": "",
|
'secret': '',
|
||||||
"EncodingAESKey": "",
|
'token': '',
|
||||||
"contacts_secret": ""
|
'EncodingAESKey': '',
|
||||||
})
|
'contacts_secret': '',
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
await self.ap.platform_cfg.dump_config()
|
await self.ap.platform_cfg.dump_config()
|
||||||
|
|||||||
@@ -3,7 +3,7 @@ from __future__ import annotations
|
|||||||
from .. import migration
|
from .. import migration
|
||||||
|
|
||||||
|
|
||||||
@migration.migration_class("lark-config", 21)
|
@migration.migration_class('lark-config', 21)
|
||||||
class LarkConfigMigration(migration.Migration):
|
class LarkConfigMigration(migration.Migration):
|
||||||
"""迁移"""
|
"""迁移"""
|
||||||
|
|
||||||
@@ -19,15 +19,17 @@ class LarkConfigMigration(migration.Migration):
|
|||||||
|
|
||||||
async def run(self):
|
async def run(self):
|
||||||
"""执行迁移"""
|
"""执行迁移"""
|
||||||
self.ap.platform_cfg.data['platform-adapters'].append({
|
self.ap.platform_cfg.data['platform-adapters'].append(
|
||||||
"adapter": "lark",
|
{
|
||||||
"enable": False,
|
'adapter': 'lark',
|
||||||
"app_id": "cli_abcdefgh",
|
'enable': False,
|
||||||
"app_secret": "XXXXXXXXXX",
|
'app_id': 'cli_abcdefgh',
|
||||||
"bot_name": "LangBot",
|
'app_secret': 'XXXXXXXXXX',
|
||||||
"enable-webhook": False,
|
'bot_name': 'LangBot',
|
||||||
"port": 2285,
|
'enable-webhook': False,
|
||||||
"encrypt-key": "xxxxxxxxx"
|
'port': 2285,
|
||||||
})
|
'encrypt-key': 'xxxxxxxxx',
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
await self.ap.platform_cfg.dump_config()
|
await self.ap.platform_cfg.dump_config()
|
||||||
|
|||||||
@@ -3,7 +3,7 @@ from __future__ import annotations
|
|||||||
from .. import migration
|
from .. import migration
|
||||||
|
|
||||||
|
|
||||||
@migration.migration_class("lmstudio-config", 22)
|
@migration.migration_class('lmstudio-config', 22)
|
||||||
class LmStudioConfigMigration(migration.Migration):
|
class LmStudioConfigMigration(migration.Migration):
|
||||||
"""迁移"""
|
"""迁移"""
|
||||||
|
|
||||||
@@ -15,9 +15,9 @@ class LmStudioConfigMigration(migration.Migration):
|
|||||||
async def run(self):
|
async def run(self):
|
||||||
"""执行迁移"""
|
"""执行迁移"""
|
||||||
self.ap.provider_cfg.data['requester']['lmstudio-chat-completions'] = {
|
self.ap.provider_cfg.data['requester']['lmstudio-chat-completions'] = {
|
||||||
"base-url": "http://127.0.0.1:1234/v1",
|
'base-url': 'http://127.0.0.1:1234/v1',
|
||||||
"args": {},
|
'args': {},
|
||||||
"timeout": 120
|
'timeout': 120,
|
||||||
}
|
}
|
||||||
|
|
||||||
await self.ap.provider_cfg.dump_config()
|
await self.ap.provider_cfg.dump_config()
|
||||||
|
|||||||
@@ -3,25 +3,25 @@ from __future__ import annotations
|
|||||||
from .. import migration
|
from .. import migration
|
||||||
|
|
||||||
|
|
||||||
@migration.migration_class("siliconflow-config", 23)
|
@migration.migration_class('siliconflow-config', 23)
|
||||||
class SiliconFlowConfigMigration(migration.Migration):
|
class SiliconFlowConfigMigration(migration.Migration):
|
||||||
"""迁移"""
|
"""迁移"""
|
||||||
|
|
||||||
async def need_migrate(self) -> bool:
|
async def need_migrate(self) -> bool:
|
||||||
"""判断当前环境是否需要运行此迁移"""
|
"""判断当前环境是否需要运行此迁移"""
|
||||||
|
|
||||||
return 'siliconflow-chat-completions' not in self.ap.provider_cfg.data['requester']
|
return (
|
||||||
|
'siliconflow-chat-completions' not in self.ap.provider_cfg.data['requester']
|
||||||
|
)
|
||||||
|
|
||||||
async def run(self):
|
async def run(self):
|
||||||
"""执行迁移"""
|
"""执行迁移"""
|
||||||
self.ap.provider_cfg.data['keys']['siliconflow'] = [
|
self.ap.provider_cfg.data['keys']['siliconflow'] = ['xxxxxxx']
|
||||||
"xxxxxxx"
|
|
||||||
]
|
|
||||||
|
|
||||||
self.ap.provider_cfg.data['requester']['siliconflow-chat-completions'] = {
|
self.ap.provider_cfg.data['requester']['siliconflow-chat-completions'] = {
|
||||||
"base-url": "https://api.siliconflow.cn/v1",
|
'base-url': 'https://api.siliconflow.cn/v1',
|
||||||
"args": {},
|
'args': {},
|
||||||
"timeout": 120
|
'timeout': 120,
|
||||||
}
|
}
|
||||||
|
|
||||||
await self.ap.provider_cfg.dump_config()
|
await self.ap.provider_cfg.dump_config()
|
||||||
|
|||||||
@@ -3,7 +3,7 @@ from __future__ import annotations
|
|||||||
from .. import migration
|
from .. import migration
|
||||||
|
|
||||||
|
|
||||||
@migration.migration_class("discord-config", 24)
|
@migration.migration_class('discord-config', 24)
|
||||||
class DiscordConfigMigration(migration.Migration):
|
class DiscordConfigMigration(migration.Migration):
|
||||||
"""迁移"""
|
"""迁移"""
|
||||||
|
|
||||||
@@ -19,11 +19,13 @@ class DiscordConfigMigration(migration.Migration):
|
|||||||
|
|
||||||
async def run(self):
|
async def run(self):
|
||||||
"""执行迁移"""
|
"""执行迁移"""
|
||||||
self.ap.platform_cfg.data['platform-adapters'].append({
|
self.ap.platform_cfg.data['platform-adapters'].append(
|
||||||
"adapter": "discord",
|
{
|
||||||
"enable": False,
|
'adapter': 'discord',
|
||||||
"client_id": "1234567890",
|
'enable': False,
|
||||||
"token": "XXXXXXXXXX"
|
'client_id': '1234567890',
|
||||||
})
|
'token': 'XXXXXXXXXX',
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
await self.ap.platform_cfg.dump_config()
|
await self.ap.platform_cfg.dump_config()
|
||||||
|
|||||||
@@ -3,7 +3,7 @@ from __future__ import annotations
|
|||||||
from .. import migration
|
from .. import migration
|
||||||
|
|
||||||
|
|
||||||
@migration.migration_class("gewechat-config", 25)
|
@migration.migration_class('gewechat-config', 25)
|
||||||
class GewechatConfigMigration(migration.Migration):
|
class GewechatConfigMigration(migration.Migration):
|
||||||
"""迁移"""
|
"""迁移"""
|
||||||
|
|
||||||
@@ -19,15 +19,17 @@ class GewechatConfigMigration(migration.Migration):
|
|||||||
|
|
||||||
async def run(self):
|
async def run(self):
|
||||||
"""执行迁移"""
|
"""执行迁移"""
|
||||||
self.ap.platform_cfg.data['platform-adapters'].append({
|
self.ap.platform_cfg.data['platform-adapters'].append(
|
||||||
"adapter": "gewechat",
|
{
|
||||||
"enable": False,
|
'adapter': 'gewechat',
|
||||||
"gewechat_url": "http://your-gewechat-server:2531",
|
'enable': False,
|
||||||
"gewechat_file_url": "http://your-gewechat-server:2532",
|
'gewechat_url': 'http://your-gewechat-server:2531',
|
||||||
"port": 2286,
|
'gewechat_file_url': 'http://your-gewechat-server:2532',
|
||||||
"callback_url": "http://your-callback-url:2286/gewechat/callback",
|
'port': 2286,
|
||||||
"app_id": "",
|
'callback_url': 'http://your-callback-url:2286/gewechat/callback',
|
||||||
"token": ""
|
'app_id': '',
|
||||||
})
|
'token': '',
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
await self.ap.platform_cfg.dump_config()
|
await self.ap.platform_cfg.dump_config()
|
||||||
|
|||||||
@@ -3,7 +3,7 @@ from __future__ import annotations
|
|||||||
from .. import migration
|
from .. import migration
|
||||||
|
|
||||||
|
|
||||||
@migration.migration_class("qqofficial-config", 26)
|
@migration.migration_class('qqofficial-config', 26)
|
||||||
class QQOfficialConfigMigration(migration.Migration):
|
class QQOfficialConfigMigration(migration.Migration):
|
||||||
"""迁移"""
|
"""迁移"""
|
||||||
|
|
||||||
@@ -19,13 +19,15 @@ class QQOfficialConfigMigration(migration.Migration):
|
|||||||
|
|
||||||
async def run(self):
|
async def run(self):
|
||||||
"""执行迁移"""
|
"""执行迁移"""
|
||||||
self.ap.platform_cfg.data['platform-adapters'].append({
|
self.ap.platform_cfg.data['platform-adapters'].append(
|
||||||
"adapter": "qqofficial",
|
{
|
||||||
"enable": False,
|
'adapter': 'qqofficial',
|
||||||
"appid": "",
|
'enable': False,
|
||||||
"secret": "",
|
'appid': '',
|
||||||
"port": 2284,
|
'secret': '',
|
||||||
"token": ""
|
'port': 2284,
|
||||||
})
|
'token': '',
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
await self.ap.platform_cfg.dump_config()
|
await self.ap.platform_cfg.dump_config()
|
||||||
|
|||||||
@@ -3,7 +3,7 @@ from __future__ import annotations
|
|||||||
from .. import migration
|
from .. import migration
|
||||||
|
|
||||||
|
|
||||||
@migration.migration_class("wx-official-account-config", 27)
|
@migration.migration_class('wx-official-account-config', 27)
|
||||||
class WXOfficialAccountConfigMigration(migration.Migration):
|
class WXOfficialAccountConfigMigration(migration.Migration):
|
||||||
"""迁移"""
|
"""迁移"""
|
||||||
|
|
||||||
@@ -19,15 +19,17 @@ class WXOfficialAccountConfigMigration(migration.Migration):
|
|||||||
|
|
||||||
async def run(self):
|
async def run(self):
|
||||||
"""执行迁移"""
|
"""执行迁移"""
|
||||||
self.ap.platform_cfg.data['platform-adapters'].append({
|
self.ap.platform_cfg.data['platform-adapters'].append(
|
||||||
"adapter": "officialaccount",
|
{
|
||||||
"enable": False,
|
'adapter': 'officialaccount',
|
||||||
"token": "",
|
'enable': False,
|
||||||
"EncodingAESKey": "",
|
'token': '',
|
||||||
"AppID": "",
|
'EncodingAESKey': '',
|
||||||
"AppSecret": "",
|
'AppID': '',
|
||||||
"host": "0.0.0.0",
|
'AppSecret': '',
|
||||||
"port": 2287
|
'host': '0.0.0.0',
|
||||||
})
|
'port': 2287,
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
await self.ap.platform_cfg.dump_config()
|
await self.ap.platform_cfg.dump_config()
|
||||||
|
|||||||
@@ -3,7 +3,7 @@ from __future__ import annotations
|
|||||||
from .. import migration
|
from .. import migration
|
||||||
|
|
||||||
|
|
||||||
@migration.migration_class("bailian-requester-config", 28)
|
@migration.migration_class('bailian-requester-config', 28)
|
||||||
class BailianRequesterConfigMigration(migration.Migration):
|
class BailianRequesterConfigMigration(migration.Migration):
|
||||||
"""迁移"""
|
"""迁移"""
|
||||||
|
|
||||||
@@ -14,14 +14,12 @@ class BailianRequesterConfigMigration(migration.Migration):
|
|||||||
|
|
||||||
async def run(self):
|
async def run(self):
|
||||||
"""执行迁移"""
|
"""执行迁移"""
|
||||||
self.ap.provider_cfg.data['keys']['bailian'] = [
|
self.ap.provider_cfg.data['keys']['bailian'] = ['sk-xxxxxxx']
|
||||||
"sk-xxxxxxx"
|
|
||||||
]
|
|
||||||
|
|
||||||
self.ap.provider_cfg.data['requester']['bailian-chat-completions'] = {
|
self.ap.provider_cfg.data['requester']['bailian-chat-completions'] = {
|
||||||
"base-url": "https://dashscope.aliyuncs.com/compatible-mode/v1",
|
'base-url': 'https://dashscope.aliyuncs.com/compatible-mode/v1',
|
||||||
"args": {},
|
'args': {},
|
||||||
"timeout": 120
|
'timeout': 120,
|
||||||
}
|
}
|
||||||
|
|
||||||
await self.ap.provider_cfg.dump_config()
|
await self.ap.provider_cfg.dump_config()
|
||||||
|
|||||||
@@ -3,7 +3,7 @@ from __future__ import annotations
|
|||||||
from .. import migration
|
from .. import migration
|
||||||
|
|
||||||
|
|
||||||
@migration.migration_class("dashscope-app-api-config", 29)
|
@migration.migration_class('dashscope-app-api-config', 29)
|
||||||
class DashscopeAppAPICfgMigration(migration.Migration):
|
class DashscopeAppAPICfgMigration(migration.Migration):
|
||||||
"""迁移"""
|
"""迁移"""
|
||||||
|
|
||||||
@@ -14,20 +14,14 @@ class DashscopeAppAPICfgMigration(migration.Migration):
|
|||||||
async def run(self):
|
async def run(self):
|
||||||
"""执行迁移"""
|
"""执行迁移"""
|
||||||
self.ap.provider_cfg.data['dashscope-app-api'] = {
|
self.ap.provider_cfg.data['dashscope-app-api'] = {
|
||||||
"app-type": "agent",
|
'app-type': 'agent',
|
||||||
"api-key": "sk-1234567890",
|
'api-key': 'sk-1234567890',
|
||||||
"agent": {
|
'agent': {'app-id': 'Your_app_id', 'references_quote': '参考资料来自:'},
|
||||||
"app-id": "Your_app_id",
|
'workflow': {
|
||||||
"references_quote": "参考资料来自:"
|
'app-id': 'Your_app_id',
|
||||||
|
'references_quote': '参考资料来自:',
|
||||||
|
'biz_params': {'city': '北京', 'date': '2023-08-10'},
|
||||||
},
|
},
|
||||||
"workflow": {
|
|
||||||
"app-id": "Your_app_id",
|
|
||||||
"references_quote": "参考资料来自:",
|
|
||||||
"biz_params": {
|
|
||||||
"city": "北京",
|
|
||||||
"date": "2023-08-10"
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
await self.ap.provider_cfg.dump_config()
|
await self.ap.provider_cfg.dump_config()
|
||||||
|
|||||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user