refactor: 审计api改为异步

This commit is contained in:
RockChinQ
2024-01-29 21:58:47 +08:00
parent 13ab647dc0
commit 3945ac95d1
10 changed files with 62 additions and 80 deletions

View File

@@ -2,10 +2,13 @@ import abc
import uuid
import json
import logging
import threading
import asyncio
import aiohttp
import requests
from ...core import app
class APIGroup(metaclass=abc.ABCMeta):
"""API 组抽象类"""
@@ -14,10 +17,13 @@ class APIGroup(metaclass=abc.ABCMeta):
prefix = None
def __init__(self, prefix: str):
self.prefix = prefix
ap: app.Application
def do(
def __init__(self, prefix: str, ap: app.Application):
self.prefix = prefix
self.ap = ap
async def _do(
self,
method: str,
path: str,
@@ -26,47 +32,38 @@ class APIGroup(metaclass=abc.ABCMeta):
headers: dict = {},
**kwargs
):
"""执行一个请求"""
def thr_wrapper(
self,
method: str,
path: str,
data: dict = None,
params: dict = None,
headers: dict = {},
**kwargs
):
try:
url = self.prefix + path
data = json.dumps(data)
headers['Content-Type'] = 'application/json'
ret = requests.request(
url = self.prefix + path
data = json.dumps(data)
headers['Content-Type'] = 'application/json'
try:
async with aiohttp.ClientSession() as session:
async with session.request(
method,
url,
data=data,
params=params,
headers=headers,
**kwargs
)
) as resp:
self.ap.logger.debug("data: %s", data)
self.ap.logger.debug("ret: %s", await resp.json())
logging.debug("data: %s", data)
except Exception as e:
self.ap.logger.debug(f'上报失败: {e}')
async def do(
self,
method: str,
path: str,
data: dict = None,
params: dict = None,
headers: dict = {},
**kwargs
) -> asyncio.Task:
"""执行请求"""
asyncio.create_task(self._do(method, path, data, params, headers, **kwargs))
logging.debug("ret: %s", ret.json())
except Exception as e:
logging.debug("上报数据失败: %s", e)
thr = threading.Thread(target=thr_wrapper, args=(
self,
method,
path,
data,
params,
headers,
), kwargs=kwargs)
thr.start()
def gen_rid(
self
):

View File

@@ -7,19 +7,17 @@ from ....core import app
class V2MainDataAPI(apigroup.APIGroup):
"""主程序相关 数据API"""
ap: app.Application
def __init__(self, prefix: str, ap: app.Application):
self.ap = ap
super().__init__(prefix+"/usage")
super().__init__(prefix+"/usage", ap)
def do(self, *args, **kwargs):
async def do(self, *args, **kwargs):
config = self.ap.cfg_mgr.data
if not config['report_usage']:
return None
return super().do(*args, **kwargs)
return await super().do(*args, **kwargs)
def post_update_record(
async def post_update_record(
self,
spent_seconds: int,
infer_reason: str,
@@ -27,7 +25,7 @@ class V2MainDataAPI(apigroup.APIGroup):
new_version: str,
):
"""提交更新记录"""
return self.do(
return await self.do(
"POST",
"/update",
data={
@@ -41,12 +39,12 @@ class V2MainDataAPI(apigroup.APIGroup):
}
)
def post_announcement_showed(
async def post_announcement_showed(
self,
ids: list[int],
):
"""提交公告已阅"""
return self.do(
return await self.do(
"POST",
"/announcement",
data={

View File

@@ -7,24 +7,22 @@ from .. import apigroup
class V2PluginDataAPI(apigroup.APIGroup):
"""插件数据相关 API"""
ap: app.Application
def __init__(self, prefix: str, ap: app.Application):
self.ap = ap
super().__init__(prefix+"/usage")
super().__init__(prefix+"/usage", ap)
def do(self, *args, **kwargs):
async def do(self, *args, **kwargs):
config = self.ap.cfg_mgr.data
if not config['report_usage']:
return None
return super().do(*args, **kwargs)
return await super().do(*args, **kwargs)
def post_install_record(
async def post_install_record(
self,
plugin: dict
):
"""提交插件安装记录"""
return self.do(
return await self.do(
"POST",
"/install",
data={
@@ -33,12 +31,12 @@ class V2PluginDataAPI(apigroup.APIGroup):
}
)
def post_remove_record(
async def post_remove_record(
self,
plugin: dict
):
"""提交插件卸载记录"""
return self.do(
return await self.do(
"POST",
"/remove",
data={
@@ -47,14 +45,14 @@ class V2PluginDataAPI(apigroup.APIGroup):
}
)
def post_update_record(
async def post_update_record(
self,
plugin: dict,
old_version: str,
new_version: str,
):
"""提交插件更新记录"""
return self.do(
return await self.do(
"POST",
"/update",
data={

View File

@@ -7,19 +7,17 @@ from ....core import app
class V2UsageDataAPI(apigroup.APIGroup):
"""使用量数据相关 API"""
ap: app.Application
def __init__(self, prefix: str, ap: app.Application):
self.ap = ap
super().__init__(prefix+"/usage")
super().__init__(prefix+"/usage", ap)
def do(self, *args, **kwargs):
async def do(self, *args, **kwargs):
config = self.ap.cfg_mgr.data
if not config['report_usage']:
return None
return super().do(*args, **kwargs)
def post_query_record(
return await super().do(*args, **kwargs)
async def post_query_record(
self,
session_type: str,
session_id: str,
@@ -30,7 +28,7 @@ class V2UsageDataAPI(apigroup.APIGroup):
retry_times: int,
):
"""提交请求记录"""
return self.do(
return await self.do(
"POST",
"/query",
data={
@@ -50,13 +48,13 @@ class V2UsageDataAPI(apigroup.APIGroup):
}
)
def post_event_record(
async def post_event_record(
self,
plugins: list[dict],
event_name: str,
):
"""提交事件触发记录"""
return self.do(
return await self.do(
"POST",
"/event",
data={
@@ -69,14 +67,14 @@ class V2UsageDataAPI(apigroup.APIGroup):
}
)
def post_function_record(
async def post_function_record(
self,
plugin: dict,
function_name: str,
function_description: str,
):
"""提交内容函数使用记录"""
return self.do(
return await self.do(
"POST",
"/function",
data={

View File

@@ -56,15 +56,6 @@ class Application:
async def initialize(self):
pass
# 把现有的所有内容函数加到toolmgr里
# for func in plugin_host.__callable_functions__:
# self.tool_mgr.register_legacy_function(
# name=func['name'],
# description=func['description'],
# parameters=func['parameters'],
# func=plugin_host.__function_inst_map__[func['name']]
# )
async def run(self):
await self.plugin_mgr.load_plugins()
await self.plugin_mgr.initialize_plugins()

View File

@@ -20,7 +20,7 @@ from ..provider.tools import toolmgr as llm_tool_mgr
from ..platform import manager as im_mgr
from ..command import cmdmgr
from ..plugin import manager as plugin_mgr
from ..utils.center import v2 as center_v2
from ..audit.center import v2 as center_v2
from ..utils import version, proxy
use_override = False

View File

@@ -54,7 +54,7 @@ class QQBotManager:
self.bot_account_id = self.adapter.bot_account_id
# 保存 account_id 到审计模块
from ..utils.center import apigroup
from ..audit.center import apigroup
apigroup.APIGroup._runtime_info['account_id'] = "{}".format(self.bot_account_id)
async def on_friend_message(event: FriendMessage):