refactor: 恢复所有审计API调用

This commit is contained in:
RockChinQ
2024-01-31 00:02:19 +08:00
parent c1c751a9ab
commit 32162afa65
11 changed files with 172 additions and 48 deletions
+2 -1
View File
@@ -38,6 +38,7 @@ class APIGroup(metaclass=abc.ABCMeta):
url = self.prefix + path url = self.prefix + path
data = json.dumps(data) data = json.dumps(data)
headers['Content-Type'] = 'application/json' headers['Content-Type'] = 'application/json'
try: try:
async with aiohttp.ClientSession() as session: async with aiohttp.ClientSession() as session:
async with session.request( async with session.request(
@@ -49,7 +50,7 @@ class APIGroup(metaclass=abc.ABCMeta):
**kwargs **kwargs
) as resp: ) as resp:
self.ap.logger.debug("data: %s", data) self.ap.logger.debug("data: %s", data)
self.ap.logger.debug("ret: %s", await resp.json()) self.ap.logger.debug("ret: %s", await resp.text())
except Exception as e: except Exception as e:
self.ap.logger.debug(f'上报失败: {e}') self.ap.logger.debug(f'上报失败: {e}')
+1 -1
View File
@@ -9,7 +9,7 @@ class V2MainDataAPI(apigroup.APIGroup):
def __init__(self, prefix: str, ap: app.Application): def __init__(self, prefix: str, ap: app.Application):
self.ap = ap self.ap = ap
super().__init__(prefix+"/usage", ap) super().__init__(prefix+"/main", ap)
async def do(self, *args, **kwargs): async def do(self, *args, **kwargs):
config = self.ap.cfg_mgr.data config = self.ap.cfg_mgr.data
+1 -1
View File
@@ -9,7 +9,7 @@ class V2PluginDataAPI(apigroup.APIGroup):
def __init__(self, prefix: str, ap: app.Application): def __init__(self, prefix: str, ap: app.Application):
self.ap = ap self.ap = ap
super().__init__(prefix+"/usage", ap) super().__init__(prefix+"/plugin", ap)
async def do(self, *args, **kwargs): async def do(self, *args, **kwargs):
config = self.ap.cfg_mgr.data config = self.ap.cfg_mgr.data
+5 -3
View File
@@ -167,7 +167,7 @@ class PluginDelOperator(operator.CommandOperator):
yield entities.CommandReturn(error=errors.CommandError("插件删除失败: "+str(e))) yield entities.CommandReturn(error=errors.CommandError("插件删除失败: "+str(e)))
def update_plugin_status(plugin_name: str, new_status: bool, ap: app.Application): async def update_plugin_status(plugin_name: str, new_status: bool, ap: app.Application):
if ap.plugin_mgr.get_plugin_by_name(plugin_name) is not None: if ap.plugin_mgr.get_plugin_by_name(plugin_name) is not None:
for plugin in ap.plugin_mgr.plugins: for plugin in ap.plugin_mgr.plugins:
if plugin.plugin_name == plugin_name: if plugin.plugin_name == plugin_name:
@@ -176,6 +176,8 @@ def update_plugin_status(plugin_name: str, new_status: bool, ap: app.Application
for func in plugin.content_functions: for func in plugin.content_functions:
func.enable = new_status func.enable = new_status
await ap.plugin_mgr.setting.dump_container_setting(ap.plugin_mgr.plugins)
break break
return True return True
@@ -202,7 +204,7 @@ class PluginEnableOperator(operator.CommandOperator):
plugin_name = context.crt_params[0] plugin_name = context.crt_params[0]
try: try:
if update_plugin_status(plugin_name, True, self.ap): if await update_plugin_status(plugin_name, True, self.ap):
yield entities.CommandReturn(text="已启用插件: {}".format(plugin_name)) yield entities.CommandReturn(text="已启用插件: {}".format(plugin_name))
else: else:
yield entities.CommandReturn(error=errors.CommandError("插件状态修改失败: 未找到插件 {}".format(plugin_name))) yield entities.CommandReturn(error=errors.CommandError("插件状态修改失败: 未找到插件 {}".format(plugin_name)))
@@ -230,7 +232,7 @@ class PluginDisableOperator(operator.CommandOperator):
plugin_name = context.crt_params[0] plugin_name = context.crt_params[0]
try: try:
if update_plugin_status(plugin_name, False, self.ap): if await update_plugin_status(plugin_name, False, self.ap):
yield entities.CommandReturn(text="已禁用插件: {}".format(plugin_name)) yield entities.CommandReturn(text="已禁用插件: {}".format(plugin_name))
else: else:
yield entities.CommandReturn(error=errors.CommandError("插件状态修改失败: 未找到插件 {}".format(plugin_name))) yield entities.CommandReturn(error=errors.CommandError("插件状态修改失败: 未找到插件 {}".format(plugin_name)))
+25 -41
View File
@@ -35,7 +35,7 @@ async def make_app() -> app.Application:
print("以下文件不存在,已自动生成,请修改配置文件后重启:") print("以下文件不存在,已自动生成,请修改配置文件后重启:")
for file in generated_files: for file in generated_files:
print("-", file) print("-", file)
sys.exit(0) sys.exit(0)
missing_deps = await deps.check_deps() missing_deps = await deps.check_deps()
@@ -52,28 +52,24 @@ async def make_app() -> app.Application:
# 生成标识符 # 生成标识符
identifier.init() identifier.init()
cfg_mgr = await config.load_python_module_config( cfg_mgr = await config.load_python_module_config("config.py", "config-template.py")
"config.py",
"config-template.py"
)
cfg = cfg_mgr.data cfg = cfg_mgr.data
# 检查是否携带了 --override 或 -r 参数 # 检查是否携带了 --override 或 -r 参数
if '--override' in sys.argv or '-r' in sys.argv: if "--override" in sys.argv or "-r" in sys.argv:
use_override = True use_override = True
if use_override: if use_override:
overrided = await config.override_config_manager(cfg_mgr) overrided = await config.override_config_manager(cfg_mgr)
if overrided: if overrided:
qcg_logger.info("以下配置项已使用 override.json 覆盖:" + ",".join(overrided)) qcg_logger.info("以下配置项已使用 override.json 覆盖:" + ",".join(overrided))
tips_mgr = await config.load_python_module_config( tips_mgr = await config.load_python_module_config(
"tips.py", "tips.py", "tips-custom-template.py"
"tips-custom-template.py"
) )
# 检查管理员QQ号 # 检查管理员QQ号
if cfg_mgr.data['admin_qq'] == 0: if cfg_mgr.data["admin_qq"] == 0:
qcg_logger.warning("未设置管理员QQ号,将无法使用管理员命令,请在 config.py 中修改 admin_qq") qcg_logger.warning("未设置管理员QQ号,将无法使用管理员命令,请在 config.py 中修改 admin_qq")
# 构建组建实例 # 构建组建实例
@@ -85,50 +81,38 @@ async def make_app() -> app.Application:
proxy_mgr = proxy.ProxyManager(ap) proxy_mgr = proxy.ProxyManager(ap)
await proxy_mgr.initialize() await proxy_mgr.initialize()
ap.proxy_mgr = proxy_mgr ap.proxy_mgr = proxy_mgr
# 发送公告
ann_mgr = announce.AnnouncementManager(ap)
try:
announcements = await ann_mgr.fetch_new()
for ann in announcements:
ap.logger.info(f'[公告] {ann.time}: {ann.content}')
except Exception as e:
ap.logger.warning(f'获取公告时出错: {e}')
ap.query_pool = pool.QueryPool()
ver_mgr = version.VersionManager(ap) ver_mgr = version.VersionManager(ap)
await ver_mgr.initialize() await ver_mgr.initialize()
ap.ver_mgr = ver_mgr ap.ver_mgr = ver_mgr
try:
if await ap.ver_mgr.is_new_version_available():
ap.logger.info("有新版本可用,请使用 !update 命令更新")
except Exception as e:
ap.logger.warning(f"检查版本更新时出错: {e}")
plugin_mgr_inst = plugin_mgr.PluginManager(ap)
await plugin_mgr_inst.initialize()
ap.plugin_mgr = plugin_mgr_inst
center_v2_api = center_v2.V2CenterAPI( center_v2_api = center_v2.V2CenterAPI(
ap, ap,
basic_info={ basic_info={
"host_id": identifier.identifier['host_id'], "host_id": identifier.identifier["host_id"],
"instance_id": identifier.identifier['instance_id'], "instance_id": identifier.identifier["instance_id"],
"semantic_version": ver_mgr.get_current_version(), "semantic_version": ver_mgr.get_current_version(),
"platform": sys.platform, "platform": sys.platform,
}, },
runtime_info={ runtime_info={
"admin_id": "{}".format(cfg['admin_qq']), "admin_id": "{}".format(cfg["admin_qq"]),
"msg_source": cfg['msg_source_adapter'], "msg_source": cfg["msg_source_adapter"],
} },
) )
ap.ctr_mgr = center_v2_api ap.ctr_mgr = center_v2_api
# 发送公告
ann_mgr = announce.AnnouncementManager(ap)
await ann_mgr.show_announcements()
ap.query_pool = pool.QueryPool()
await ap.ver_mgr.show_version_update()
plugin_mgr_inst = plugin_mgr.PluginManager(ap)
await plugin_mgr_inst.initialize()
ap.plugin_mgr = plugin_mgr_inst
cmd_mgr_inst = cmdmgr.CommandManager(ap) cmd_mgr_inst = cmdmgr.CommandManager(ap)
await cmd_mgr_inst.initialize() await cmd_mgr_inst.initialize()
ap.cmd_mgr = cmd_mgr_inst ap.cmd_mgr = cmd_mgr_inst
@@ -159,7 +143,7 @@ async def make_app() -> app.Application:
ctrl = controller.Controller(ap) ctrl = controller.Controller(ap)
ap.ctrl = ctrl ap.ctrl = ctrl
await ap.initialize() await ap.initialize()
return ap return ap
+18
View File
@@ -1,6 +1,7 @@
from __future__ import annotations from __future__ import annotations
import typing import typing
import time
import mirai import mirai
@@ -84,9 +85,16 @@ class ChatMessageHandler(handler.MessageHandler):
called_functions = [] called_functions = []
text_length = 0
start_time = time.time()
async for result in conversation.use_model.requester.request(query, conversation): async for result in conversation.use_model.requester.request(query, conversation):
conversation.messages.append(result) conversation.messages.append(result)
if result.content is not None:
text_length += len(result.content)
# 转换成可读消息 # 转换成可读消息
if result.role == 'assistant': if result.role == 'assistant':
@@ -172,3 +180,13 @@ class ChatMessageHandler(handler.MessageHandler):
result_type=entities.ResultType.CONTINUE, result_type=entities.ResultType.CONTINUE,
new_query=query new_query=query
) )
await self.ap.ctr_mgr.usage.post_query_record(
session_type=session.launcher_type.value,
session_id=str(session.launcher_id),
query_ability_provider="QChatGPT.Chat",
usage=text_length,
model_name=conversation.use_model.name,
response_seconds=int(time.time() - start_time),
retry_times=-1,
)
+52
View File
@@ -64,6 +64,15 @@ class PluginManager:
""" """
await self.installer.install_plugin(plugin_source) await self.installer.install_plugin(plugin_source)
await self.ap.ctr_mgr.plugin.post_install_record(
{
"name": "unknown",
"remote": plugin_source,
"author": "unknown",
"version": "HEAD"
}
)
async def uninstall_plugin( async def uninstall_plugin(
self, self,
plugin_name: str, plugin_name: str,
@@ -72,6 +81,17 @@ class PluginManager:
""" """
await self.installer.uninstall_plugin(plugin_name) await self.installer.uninstall_plugin(plugin_name)
plugin_container = self.get_plugin_by_name(plugin_name)
await self.ap.ctr_mgr.plugin.post_remove_record(
{
"name": plugin_name,
"remote": plugin_container.plugin_source,
"author": plugin_container.plugin_author,
"version": plugin_container.plugin_version
}
)
async def update_plugin( async def update_plugin(
self, self,
plugin_name: str, plugin_name: str,
@@ -80,6 +100,19 @@ class PluginManager:
"""更新插件 """更新插件
""" """
await self.installer.update_plugin(plugin_name, plugin_source) await self.installer.update_plugin(plugin_name, plugin_source)
plugin_container = self.get_plugin_by_name(plugin_name)
await self.ap.ctr_mgr.plugin.post_update_record(
plugin={
"name": plugin_name,
"remote": plugin_container.plugin_source,
"author": plugin_container.plugin_author,
"version": plugin_container.plugin_version
},
old_version=plugin_container.plugin_version,
new_version="HEAD"
)
def get_plugin_by_name(self, plugin_name: str) -> context.RuntimeContainer: def get_plugin_by_name(self, plugin_name: str) -> context.RuntimeContainer:
"""通过插件名获取插件 """通过插件名获取插件
@@ -98,10 +131,14 @@ class PluginManager:
event=event event=event
) )
emitted_plugins: list[context.RuntimeContainer] = []
for plugin in self.plugins: for plugin in self.plugins:
if plugin.enabled: if plugin.enabled:
if event.__class__ in plugin.event_handlers: if event.__class__ in plugin.event_handlers:
emitted_plugins.append(plugin)
is_prevented_default_before_call = ctx.is_prevented_default() is_prevented_default_before_call = ctx.is_prevented_default()
try: try:
@@ -126,4 +163,19 @@ class PluginManager:
self.ap.logger.debug(f'事件 {event.__class__.__name__}({ctx.eid}) 处理完成,返回值 {ctx.__return_value__}') self.ap.logger.debug(f'事件 {event.__class__.__name__}({ctx.eid}) 处理完成,返回值 {ctx.__return_value__}')
if emitted_plugins:
plugins_info: list[dict] = [
{
'name': plugin.plugin_name,
'remote': plugin.plugin_source,
'version': plugin.plugin_version,
'author': plugin.plugin_author
} for plugin in emitted_plugins
]
await self.ap.ctr_mgr.usage.post_event_record(
plugins=plugins_info,
event_name=event.__class__.__name__
)
return ctx return ctx
+19
View File
@@ -59,6 +59,25 @@ class SettingManager:
await self.settings.dump_config() await self.settings.dump_config()
async def dump_container_setting(
self,
plugin_containers: list[context.RuntimeContainer]
):
"""保存插件容器设置
"""
for plugin in plugin_containers:
for ps in self.settings.data['plugins']:
if ps['name'] == plugin.plugin_name:
plugin_dict = plugin.to_setting_dict()
for key in plugin_dict:
ps[key] = plugin_dict[key]
break
await self.settings.dump_config()
async def record_installed_plugin_source( async def record_installed_plugin_source(
self, self,
pkg_path: str, pkg_path: str,
+21
View File
@@ -84,3 +84,24 @@ class ToolManager:
self.ap.logger.error(f'执行函数 {name} 时发生错误: {e}') self.ap.logger.error(f'执行函数 {name} 时发生错误: {e}')
traceback.print_exc() traceback.print_exc()
return f'error occurred when executing function {name}: {e}' return f'error occurred when executing function {name}: {e}'
finally:
plugin = None
for p in self.ap.plugin_mgr.plugins:
if function in p.content_functions:
plugin = p
break
if plugin is not None:
await self.ap.ctr_mgr.usage.post_function_record(
plugin={
'name': plugin.plugin_name,
'remote': plugin.plugin_source,
'version': plugin.plugin_version,
'author': plugin.plugin_author
},
function_name=function.name,
function_description=function.description,
)
+17
View File
@@ -104,3 +104,20 @@ class AnnouncementManager:
await self.write_saved(all) await self.write_saved(all)
return to_show return to_show
async def show_announcements(
self
):
"""显示公告"""
try:
announcements = await self.fetch_new()
for ann in announcements:
self.ap.logger.info(f'[公告] {ann.time}: {ann.content}')
if announcements:
await self.ap.ctr_mgr.main.post_announcement_showed(
ids=[item.id for item in announcements]
)
except Exception as e:
self.ap.logger.warning(f'获取公告时出错: {e}')
+11 -1
View File
@@ -148,7 +148,7 @@ class VersionManager:
with open("current_tag", "w") as f: with open("current_tag", "w") as f:
f.write(current_tag) f.write(current_tag)
self.ap.ctr_mgr.main.post_update_record( await self.ap.ctr_mgr.main.post_update_record(
spent_seconds=int(time.time()-start_time), spent_seconds=int(time.time()-start_time),
infer_reason="update", infer_reason="update",
old_version=old_tag, old_version=old_tag,
@@ -224,3 +224,13 @@ class VersionManager:
return 0 return 0
async def show_version_update(
self
):
try:
if await self.ap.ver_mgr.is_new_version_available():
self.ap.logger.info("有新版本可用,请使用 !update 命令更新")
except Exception as e:
self.ap.logger.warning(f"检查版本更新时出错: {e}")