refactor: 审计api改为异步

This commit is contained in:
RockChinQ
2024-01-29 21:58:47 +08:00
parent 13ab647dc0
commit 3945ac95d1
10 changed files with 62 additions and 80 deletions
@@ -2,10 +2,13 @@ import abc
import uuid import uuid
import json import json
import logging import logging
import threading import asyncio
import aiohttp
import requests import requests
from ...core import app
class APIGroup(metaclass=abc.ABCMeta): class APIGroup(metaclass=abc.ABCMeta):
"""API 组抽象类""" """API 组抽象类"""
@@ -14,10 +17,13 @@ class APIGroup(metaclass=abc.ABCMeta):
prefix = None prefix = None
def __init__(self, prefix: str): ap: app.Application
self.prefix = prefix
def do( def __init__(self, prefix: str, ap: app.Application):
self.prefix = prefix
self.ap = ap
async def _do(
self, self,
method: str, method: str,
path: str, path: str,
@@ -26,47 +32,38 @@ class APIGroup(metaclass=abc.ABCMeta):
headers: dict = {}, headers: dict = {},
**kwargs **kwargs
): ):
"""执行一个请求"""
def thr_wrapper( url = self.prefix + path
self, data = json.dumps(data)
method: str, headers['Content-Type'] = 'application/json'
path: str, try:
data: dict = None, async with aiohttp.ClientSession() as session:
params: dict = None, async with session.request(
headers: dict = {},
**kwargs
):
try:
url = self.prefix + path
data = json.dumps(data)
headers['Content-Type'] = 'application/json'
ret = requests.request(
method, method,
url, url,
data=data, data=data,
params=params, params=params,
headers=headers, headers=headers,
**kwargs **kwargs
) ) as resp:
self.ap.logger.debug("data: %s", data)
self.ap.logger.debug("ret: %s", await resp.json())
logging.debug("data: %s", data) except Exception as e:
self.ap.logger.debug(f'上报失败: {e}')
async def do(
self,
method: str,
path: str,
data: dict = None,
params: dict = None,
headers: dict = {},
**kwargs
) -> asyncio.Task:
"""执行请求"""
asyncio.create_task(self._do(method, path, data, params, headers, **kwargs))
logging.debug("ret: %s", ret.json())
except Exception as e:
logging.debug("上报数据失败: %s", e)
thr = threading.Thread(target=thr_wrapper, args=(
self,
method,
path,
data,
params,
headers,
), kwargs=kwargs)
thr.start()
def gen_rid( def gen_rid(
self self
): ):
@@ -7,19 +7,17 @@ from ....core import app
class V2MainDataAPI(apigroup.APIGroup): class V2MainDataAPI(apigroup.APIGroup):
"""主程序相关 数据API""" """主程序相关 数据API"""
ap: app.Application
def __init__(self, prefix: str, ap: app.Application): def __init__(self, prefix: str, ap: app.Application):
self.ap = ap self.ap = ap
super().__init__(prefix+"/usage") super().__init__(prefix+"/usage", ap)
def do(self, *args, **kwargs): async def do(self, *args, **kwargs):
config = self.ap.cfg_mgr.data config = self.ap.cfg_mgr.data
if not config['report_usage']: if not config['report_usage']:
return None return None
return super().do(*args, **kwargs) return await super().do(*args, **kwargs)
def post_update_record( async def post_update_record(
self, self,
spent_seconds: int, spent_seconds: int,
infer_reason: str, infer_reason: str,
@@ -27,7 +25,7 @@ class V2MainDataAPI(apigroup.APIGroup):
new_version: str, new_version: str,
): ):
"""提交更新记录""" """提交更新记录"""
return self.do( return await self.do(
"POST", "POST",
"/update", "/update",
data={ data={
@@ -41,12 +39,12 @@ class V2MainDataAPI(apigroup.APIGroup):
} }
) )
def post_announcement_showed( async def post_announcement_showed(
self, self,
ids: list[int], ids: list[int],
): ):
"""提交公告已阅""" """提交公告已阅"""
return self.do( return await self.do(
"POST", "POST",
"/announcement", "/announcement",
data={ data={
@@ -7,24 +7,22 @@ from .. import apigroup
class V2PluginDataAPI(apigroup.APIGroup): class V2PluginDataAPI(apigroup.APIGroup):
"""插件数据相关 API""" """插件数据相关 API"""
ap: app.Application
def __init__(self, prefix: str, ap: app.Application): def __init__(self, prefix: str, ap: app.Application):
self.ap = ap self.ap = ap
super().__init__(prefix+"/usage") super().__init__(prefix+"/usage", ap)
def do(self, *args, **kwargs): async def do(self, *args, **kwargs):
config = self.ap.cfg_mgr.data config = self.ap.cfg_mgr.data
if not config['report_usage']: if not config['report_usage']:
return None return None
return super().do(*args, **kwargs) return await super().do(*args, **kwargs)
def post_install_record( async def post_install_record(
self, self,
plugin: dict plugin: dict
): ):
"""提交插件安装记录""" """提交插件安装记录"""
return self.do( return await self.do(
"POST", "POST",
"/install", "/install",
data={ data={
@@ -33,12 +31,12 @@ class V2PluginDataAPI(apigroup.APIGroup):
} }
) )
def post_remove_record( async def post_remove_record(
self, self,
plugin: dict plugin: dict
): ):
"""提交插件卸载记录""" """提交插件卸载记录"""
return self.do( return await self.do(
"POST", "POST",
"/remove", "/remove",
data={ data={
@@ -47,14 +45,14 @@ class V2PluginDataAPI(apigroup.APIGroup):
} }
) )
def post_update_record( async def post_update_record(
self, self,
plugin: dict, plugin: dict,
old_version: str, old_version: str,
new_version: str, new_version: str,
): ):
"""提交插件更新记录""" """提交插件更新记录"""
return self.do( return await self.do(
"POST", "POST",
"/update", "/update",
data={ data={
@@ -7,19 +7,17 @@ from ....core import app
class V2UsageDataAPI(apigroup.APIGroup): class V2UsageDataAPI(apigroup.APIGroup):
"""使用量数据相关 API""" """使用量数据相关 API"""
ap: app.Application
def __init__(self, prefix: str, ap: app.Application): def __init__(self, prefix: str, ap: app.Application):
self.ap = ap self.ap = ap
super().__init__(prefix+"/usage") super().__init__(prefix+"/usage", ap)
def do(self, *args, **kwargs): async def do(self, *args, **kwargs):
config = self.ap.cfg_mgr.data config = self.ap.cfg_mgr.data
if not config['report_usage']: if not config['report_usage']:
return None return None
return super().do(*args, **kwargs) return await super().do(*args, **kwargs)
def post_query_record( async def post_query_record(
self, self,
session_type: str, session_type: str,
session_id: str, session_id: str,
@@ -30,7 +28,7 @@ class V2UsageDataAPI(apigroup.APIGroup):
retry_times: int, retry_times: int,
): ):
"""提交请求记录""" """提交请求记录"""
return self.do( return await self.do(
"POST", "POST",
"/query", "/query",
data={ data={
@@ -50,13 +48,13 @@ class V2UsageDataAPI(apigroup.APIGroup):
} }
) )
def post_event_record( async def post_event_record(
self, self,
plugins: list[dict], plugins: list[dict],
event_name: str, event_name: str,
): ):
"""提交事件触发记录""" """提交事件触发记录"""
return self.do( return await self.do(
"POST", "POST",
"/event", "/event",
data={ data={
@@ -69,14 +67,14 @@ class V2UsageDataAPI(apigroup.APIGroup):
} }
) )
def post_function_record( async def post_function_record(
self, self,
plugin: dict, plugin: dict,
function_name: str, function_name: str,
function_description: str, function_description: str,
): ):
"""提交内容函数使用记录""" """提交内容函数使用记录"""
return self.do( return await self.do(
"POST", "POST",
"/function", "/function",
data={ data={
-9
View File
@@ -56,15 +56,6 @@ class Application:
async def initialize(self): async def initialize(self):
pass pass
# 把现有的所有内容函数加到toolmgr里
# for func in plugin_host.__callable_functions__:
# self.tool_mgr.register_legacy_function(
# name=func['name'],
# description=func['description'],
# parameters=func['parameters'],
# func=plugin_host.__function_inst_map__[func['name']]
# )
async def run(self): async def run(self):
await self.plugin_mgr.load_plugins() await self.plugin_mgr.load_plugins()
await self.plugin_mgr.initialize_plugins() await self.plugin_mgr.initialize_plugins()
+1 -1
View File
@@ -20,7 +20,7 @@ from ..provider.tools import toolmgr as llm_tool_mgr
from ..platform import manager as im_mgr from ..platform import manager as im_mgr
from ..command import cmdmgr from ..command import cmdmgr
from ..plugin import manager as plugin_mgr from ..plugin import manager as plugin_mgr
from ..utils.center import v2 as center_v2 from ..audit.center import v2 as center_v2
from ..utils import version, proxy from ..utils import version, proxy
use_override = False use_override = False
+1 -1
View File
@@ -54,7 +54,7 @@ class QQBotManager:
self.bot_account_id = self.adapter.bot_account_id self.bot_account_id = self.adapter.bot_account_id
# 保存 account_id 到审计模块 # 保存 account_id 到审计模块
from ..utils.center import apigroup from ..audit.center import apigroup
apigroup.APIGroup._runtime_info['account_id'] = "{}".format(self.bot_account_id) apigroup.APIGroup._runtime_info['account_id'] = "{}".format(self.bot_account_id)
async def on_friend_message(event: FriendMessage): async def on_friend_message(event: FriendMessage):