From 6cc468866045c69ad2e48db7b0fb7222912a5f41 Mon Sep 17 00:00:00 2001 From: RockChinQ <1010553892@qq.com> Date: Mon, 29 Jan 2024 21:22:27 +0800 Subject: [PATCH] =?UTF-8?q?refactor:=20=E9=87=8D=E6=9E=84=E6=8F=92?= =?UTF-8?q?=E4=BB=B6=E7=B3=BB=E7=BB=9F?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- pkg/audit/gatherer.py | 114 ---- pkg/command/cmdmgr.py | 3 +- pkg/command/entities.py | 3 +- pkg/command/operator.py | 1 - pkg/command/operators/func.py | 6 +- pkg/command/operators/plugin.py | 54 +- pkg/command/operators/update.py | 3 +- pkg/command/operators/version.py | 5 +- pkg/config/impls/json.py | 3 + pkg/core/app.py | 35 +- pkg/core/boot.py | 29 +- pkg/core/entities.py | 47 ++ pkg/platform/manager.py | 7 +- pkg/plugin/context.py | 207 +++++++ pkg/plugin/errors.py | 24 + pkg/plugin/events.py | 96 +++ pkg/plugin/host.py | 581 +----------------- pkg/plugin/installer.py | 45 ++ .../api => plugin/installers}/__init__.py | 0 pkg/plugin/installers/github.py | 137 +++++ pkg/plugin/loader.py | 25 + pkg/plugin/loaders/__init__.py | 0 pkg/plugin/loaders/legacy.py | 155 +++++ pkg/plugin/manager.py | 112 ++++ pkg/plugin/metadata.py | 87 --- pkg/plugin/models.py | 300 +-------- pkg/plugin/setting.py | 83 +++ pkg/plugin/settings.py | 103 ---- pkg/plugin/switch.py | 94 --- pkg/provider/api/chat_completion.py | 232 ------- pkg/provider/api/completion.py | 100 --- pkg/provider/api/model.py | 40 -- pkg/provider/requester/api.py | 3 +- pkg/provider/requester/apis/chatcmpl.py | 38 +- pkg/provider/requester/entities.py | 8 +- pkg/provider/requester/modelmgr.py | 22 +- pkg/provider/requester/tokenizer.py | 29 + pkg/provider/requester/tokenizers/__init__.py | 0 pkg/provider/requester/tokenizers/tiktoken.py | 28 + pkg/provider/session/entities.py | 53 -- pkg/provider/session/sessionmgr.py | 11 +- pkg/provider/tools/entities.py | 2 + pkg/provider/tools/toolmgr.py | 34 +- pkg/utils/center/groups/main.py | 11 +- pkg/utils/center/groups/plugin.py | 11 +- pkg/utils/center/groups/usage.py | 9 +- pkg/utils/center/v2.py | 13 +- pkg/utils/network.py | 11 - pkg/utils/pkgmgr.py | 10 +- pkg/utils/proxy.py | 30 + pkg/utils/updater.py | 113 ---- pkg/utils/version.py | 130 ++++ res/templates/plugin-setting-template.json | 3 + 53 files changed, 1307 insertions(+), 1993 deletions(-) delete mode 100644 pkg/audit/gatherer.py create mode 100644 pkg/plugin/context.py create mode 100644 pkg/plugin/errors.py create mode 100644 pkg/plugin/events.py create mode 100644 pkg/plugin/installer.py rename pkg/{provider/api => plugin/installers}/__init__.py (100%) create mode 100644 pkg/plugin/installers/github.py create mode 100644 pkg/plugin/loader.py create mode 100644 pkg/plugin/loaders/__init__.py create mode 100644 pkg/plugin/loaders/legacy.py create mode 100644 pkg/plugin/manager.py delete mode 100644 pkg/plugin/metadata.py create mode 100644 pkg/plugin/setting.py delete mode 100644 pkg/plugin/settings.py delete mode 100644 pkg/plugin/switch.py delete mode 100644 pkg/provider/api/chat_completion.py delete mode 100644 pkg/provider/api/completion.py delete mode 100644 pkg/provider/api/model.py create mode 100644 pkg/provider/requester/tokenizer.py create mode 100644 pkg/provider/requester/tokenizers/__init__.py create mode 100644 pkg/provider/requester/tokenizers/tiktoken.py delete mode 100644 pkg/provider/session/entities.py delete mode 100644 pkg/utils/network.py create mode 100644 pkg/utils/proxy.py create mode 100644 pkg/utils/version.py create mode 100644 res/templates/plugin-setting-template.json diff --git a/pkg/audit/gatherer.py b/pkg/audit/gatherer.py deleted file mode 100644 index 01bb7f2d..00000000 --- a/pkg/audit/gatherer.py +++ /dev/null @@ -1,114 +0,0 @@ -""" -使用量统计以及数据上报功能实现 -""" - -import hashlib -import json -import logging -import threading - -import requests - -from ..utils import context -from ..utils import updater - - -class DataGatherer: - """数据收集器""" - - usage = {} - """各api-key的使用量 - - 以key值md5为key,{ - "text": { - "gpt-3.5-turbo": 文字量:int, - }, - "image": { - "256x256": 图片数量:int, - } - }为值的字典""" - - version_str = "undetermined" - - def __init__(self): - self.load_from_db() - try: - self.version_str = updater.get_current_tag() # 从updater模块获取版本号 - except: - pass - - def get_usage(self, key_md5): - return self.usage[key_md5] if key_md5 in self.usage else {} - - def report_text_model_usage(self, model, total_tokens): - """调用方报告文字模型请求文字使用量""" - - key_md5 = context.get_openai_manager().key_mgr.get_using_key_md5() # 以key的md5进行储存 - - if key_md5 not in self.usage: - self.usage[key_md5] = {} - - if "text" not in self.usage[key_md5]: - self.usage[key_md5]["text"] = {} - - if model not in self.usage[key_md5]["text"]: - self.usage[key_md5]["text"][model] = 0 - - length = total_tokens - self.usage[key_md5]["text"][model] += length - self.dump_to_db() - - def report_image_model_usage(self, size): - """调用方报告图片模型请求图片使用量""" - - key_md5 = context.get_openai_manager().key_mgr.get_using_key_md5() - - if key_md5 not in self.usage: - self.usage[key_md5] = {} - - if "image" not in self.usage[key_md5]: - self.usage[key_md5]["image"] = {} - - if size not in self.usage[key_md5]["image"]: - self.usage[key_md5]["image"][size] = 0 - - self.usage[key_md5]["image"][size] += 1 - self.dump_to_db() - - def get_text_length_of_key(self, key): - """获取指定api-key (明文) 的文字总使用量(本地记录)""" - key_md5 = hashlib.md5(key.encode('utf-8')).hexdigest() - if key_md5 not in self.usage: - return 0 - if "text" not in self.usage[key_md5]: - return 0 - # 遍历其中所有模型,求和 - return sum(self.usage[key_md5]["text"].values()) - - def get_image_count_of_key(self, key): - """获取指定api-key (明文) 的图片总使用量(本地记录)""" - - key_md5 = hashlib.md5(key.encode('utf-8')).hexdigest() - if key_md5 not in self.usage: - return 0 - if "image" not in self.usage[key_md5]: - return 0 - # 遍历其中所有模型,求和 - return sum(self.usage[key_md5]["image"].values()) - - def get_total_text_length(self): - """获取所有api-key的文字总使用量(本地记录)""" - total = 0 - for key in self.usage: - if "text" not in self.usage[key]: - continue - total += sum(self.usage[key]["text"].values()) - return total - - def dump_to_db(self): - context.get_database_manager().dump_usage_json(self.usage) - - def load_from_db(self): - json_str = context.get_database_manager().load_usage_json() - if json_str is not None: - self.usage = json.loads(json_str) diff --git a/pkg/command/cmdmgr.py b/pkg/command/cmdmgr.py index 530a717a..cf2dbc88 100644 --- a/pkg/command/cmdmgr.py +++ b/pkg/command/cmdmgr.py @@ -4,7 +4,6 @@ import typing from ..core import app, entities as core_entities from ..provider import entities as llm_entities -from ..provider.session import entities as session_entities from . import entities, operator, errors from .operators import func, plugin, default, reset, list as list_cmd, last, next, delc, resend, prompt, cfg, cmd, help, version, update @@ -80,7 +79,7 @@ class CommandManager: self, command_text: str, query: core_entities.Query, - session: session_entities.Session + session: core_entities.Session ) -> typing.AsyncGenerator[entities.CommandReturn, None]: """执行命令 """ diff --git a/pkg/command/entities.py b/pkg/command/entities.py index 7e6ff549..f5f8bef5 100644 --- a/pkg/command/entities.py +++ b/pkg/command/entities.py @@ -6,7 +6,6 @@ import pydantic import mirai from ..core import app, entities as core_entities -from ..provider.session import entities as session_entities from . import errors, operator @@ -28,7 +27,7 @@ class ExecuteContext(pydantic.BaseModel): query: core_entities.Query - session: session_entities.Session + session: core_entities.Session command_text: str diff --git a/pkg/command/operator.py b/pkg/command/operator.py index c5529cef..1b29e6c0 100644 --- a/pkg/command/operator.py +++ b/pkg/command/operator.py @@ -4,7 +4,6 @@ import typing import abc from ..core import app, entities as core_entities -from ..provider.session import entities as session_entities from . import entities diff --git a/pkg/command/operators/func.py b/pkg/command/operators/func.py index a4e81c35..33031bfb 100644 --- a/pkg/command/operators/func.py +++ b/pkg/command/operators/func.py @@ -2,7 +2,6 @@ from __future__ import annotations from typing import AsyncGenerator from .. import operator, entities, cmdmgr -from ...plugin import host as plugin_host @operator.operator_class(name="func", help="查看所有已注册的内容函数", usage='!func') @@ -13,7 +12,10 @@ class FuncOperator(operator.CommandOperator): reply_str = "当前已加载的内容函数: \n\n" index = 1 - for func in self.ap.tool_mgr.all_functions: + + all_functions = await self.ap.tool_mgr.get_all_functions() + + for func in all_functions: reply_str += "{}. {}{}:\n{}\n\n".format( index, ("(已禁用) " if not func.enable else ""), diff --git a/pkg/command/operators/plugin.py b/pkg/command/operators/plugin.py index 195852ae..27973913 100644 --- a/pkg/command/operators/plugin.py +++ b/pkg/command/operators/plugin.py @@ -3,8 +3,6 @@ import typing import traceback from .. import operator, entities, cmdmgr, errors -from ...plugin import host as plugin_host -from ...utils import updater from ...core import app @@ -20,16 +18,15 @@ class PluginOperator(operator.CommandOperator): context: entities.ExecuteContext ) -> typing.AsyncGenerator[entities.CommandReturn, None]: - plugin_list = plugin_host.__plugins__ - reply_str = "所有插件({}):\n".format(len(plugin_host.__plugins__)) + plugin_list = self.ap.plugin_mgr.plugins + reply_str = "所有插件({}):\n".format(len(plugin_list)) idx = 0 - for key in plugin_host.iter_plugins_name(): - plugin = plugin_list[key] + for plugin in plugin_list: reply_str += "\n#{} {} {}\n{}\nv{}\n作者: {}\n"\ - .format((idx+1), plugin['name'], - "[已禁用]" if not plugin['enabled'] else "", - plugin['description'], - plugin['version'], plugin['author']) + .format((idx+1), plugin.plugin_name, + "[已禁用]" if not plugin.enabled else "", + plugin.plugin_description, + plugin.plugin_version, plugin.plugin_author) # TODO 从元数据调远程地址 # if updater.is_repo("/".join(plugin['path'].split('/')[:-1])): @@ -63,7 +60,7 @@ class PluginGetOperator(operator.CommandOperator): yield entities.CommandReturn(text="正在安装插件...") try: - plugin_host.install_plugin(repo) + await self.ap.plugin_mgr.install_plugin(repo) yield entities.CommandReturn(text="插件安装成功,请重启程序以加载插件") except Exception as e: traceback.print_exc() @@ -89,11 +86,11 @@ class PluginUpdateOperator(operator.CommandOperator): plugin_name = context.crt_params[0] try: - plugin_path_name = plugin_host.get_plugin_path_name_by_plugin_name(plugin_name) + plugin_container = self.ap.plugin_mgr.get_plugin_by_name(plugin_name) - if plugin_path_name is not None: + if plugin_container is not None: yield entities.CommandReturn(text="正在更新插件...") - plugin_host.update_plugin(plugin_name) + await self.ap.plugin_mgr.update_plugin(plugin_name) yield entities.CommandReturn(text="插件更新成功,请重启程序以加载插件") else: yield entities.CommandReturn(error=errors.CommandError("插件更新失败: 未找到插件")) @@ -115,17 +112,17 @@ class PluginUpdateAllOperator(operator.CommandOperator): ) -> typing.AsyncGenerator[entities.CommandReturn, None]: try: - plugins = [] - - for key in plugin_host.__plugins__: - plugins.append(key) + plugins = [ + p.plugin_name + for p in self.ap.plugin_mgr.plugins + ] if plugins: yield entities.CommandReturn(text="正在更新插件...") updated = [] try: for plugin_name in plugins: - plugin_host.update_plugin(plugin_name) + await self.ap.plugin_mgr.update_plugin(plugin_name) updated.append(plugin_name) except Exception as e: traceback.print_exc() @@ -157,11 +154,11 @@ class PluginDelOperator(operator.CommandOperator): plugin_name = context.crt_params[0] try: - plugin_path_name = plugin_host.get_plugin_path_name_by_plugin_name(plugin_name) + plugin_container = self.ap.plugin_mgr.get_plugin_by_name(plugin_name) - if plugin_path_name is not None: + if plugin_container is not None: yield entities.CommandReturn(text="正在删除插件...") - plugin_host.uninstall_plugin(plugin_name) + await self.ap.plugin_mgr.uninstall_plugin(plugin_name) yield entities.CommandReturn(text="插件删除成功,请重启程序以加载插件") else: yield entities.CommandReturn(error=errors.CommandError("插件删除失败: 未找到插件")) @@ -171,12 +168,15 @@ class PluginDelOperator(operator.CommandOperator): def update_plugin_status(plugin_name: str, new_status: bool, ap: app.Application): - if plugin_name in plugin_host.__plugins__: - plugin_host.__plugins__[plugin_name]['enabled'] = new_status + if ap.plugin_mgr.get_plugin_by_name(plugin_name) is not None: + for plugin in ap.plugin_mgr.plugins: + if plugin.plugin_name == plugin_name: + plugin.enabled = new_status - for func in ap.tool_mgr.all_functions: - if func.name.startswith(plugin_name+'-'): - func.enable = new_status + for func in plugin.content_functions: + func.enable = new_status + + break return True else: diff --git a/pkg/command/operators/update.py b/pkg/command/operators/update.py index db493b6a..524a26dd 100644 --- a/pkg/command/operators/update.py +++ b/pkg/command/operators/update.py @@ -4,7 +4,6 @@ import typing import traceback from .. import operator, entities, cmdmgr, errors -from ...utils import updater @operator.operator_class( @@ -22,7 +21,7 @@ class UpdateCommand(operator.CommandOperator): try: yield entities.CommandReturn(text="正在进行更新...") - if updater.update_all(): + if await self.ap.ver_mgr.update_all(): yield entities.CommandReturn(text="更新完成,请重启程序以应用更新") else: yield entities.CommandReturn(text="当前已是最新版本") diff --git a/pkg/command/operators/version.py b/pkg/command/operators/version.py index c2235800..864826bd 100644 --- a/pkg/command/operators/version.py +++ b/pkg/command/operators/version.py @@ -3,7 +3,6 @@ from __future__ import annotations import typing from .. import operator, cmdmgr, entities, errors -from ...utils import updater @operator.operator_class( @@ -17,10 +16,10 @@ class VersionCommand(operator.CommandOperator): self, context: entities.ExecuteContext ) -> typing.AsyncGenerator[entities.CommandReturn, None]: - reply_str = f"当前版本: \n{updater.get_current_version_info()}" + reply_str = f"当前版本: \n{await self.ap.ver_mgr.get_current_version_info()}" try: - if updater.is_new_version_available(): + if await self.ap.ver_mgr.is_new_version_available(): reply_str += "\n\n有新版本可用, 使用 !update 更新" except: pass diff --git a/pkg/config/impls/json.py b/pkg/config/impls/json.py index cfc284cb..544f1a85 100644 --- a/pkg/config/impls/json.py +++ b/pkg/config/impls/json.py @@ -26,6 +26,9 @@ class JSONConfigFile(file_model.ConfigFile): async def load(self) -> dict: + if not self.exists(): + await self.create() + with open(self.config_file_name, 'r', encoding='utf-8') as f: cfg = json.load(f) diff --git a/pkg/core/app.py b/pkg/core/app.py index 3768d373..a2783def 100644 --- a/pkg/core/app.py +++ b/pkg/core/app.py @@ -3,22 +3,22 @@ from __future__ import annotations import logging import asyncio -from ..platform import manager as qqbot_mgr +from ..platform import manager as im_mgr from ..provider.session import sessionmgr as llm_session_mgr from ..provider.requester import modelmgr as llm_model_mgr from ..provider.sysprompt import sysprompt as llm_prompt_mgr from ..provider.tools import toolmgr as llm_tool_mgr from ..config import manager as config_mgr -from ..database import manager as database_mgr -from ..utils.center import v2 as center_mgr +# from ..utils.center import v2 as center_mgr from ..command import cmdmgr -from ..plugin import host as plugin_host +from ..plugin import manager as plugin_mgr from . import pool, controller from ..pipeline import stagemgr +from ..utils import version as version_mgr, proxy as proxy_mgr class Application: - im_mgr: qqbot_mgr.QQBotManager = None + im_mgr: im_mgr.QQBotManager = None cmd_mgr: cmdmgr.CommandManager = None @@ -34,9 +34,9 @@ class Application: tips_mgr: config_mgr.ConfigManager = None - db_mgr: database_mgr.DatabaseManager = None + # ctr_mgr: center_mgr.V2CenterAPI = None - ctr_mgr: center_mgr.V2CenterAPI = None + plugin_mgr: plugin_mgr.PluginManager = None query_pool: pool.QueryPool = None @@ -44,24 +44,29 @@ class Application: stage_mgr: stagemgr.StageManager = None + ver_mgr: version_mgr.VersionManager = None + + proxy_mgr: proxy_mgr.ProxyManager = None + logger: logging.Logger = None def __init__(self): pass async def initialize(self): - plugin_host.initialize_plugins() + pass # 把现有的所有内容函数加到toolmgr里 - for func in plugin_host.__callable_functions__: - self.tool_mgr.register_legacy_function( - name=func['name'], - description=func['description'], - parameters=func['parameters'], - func=plugin_host.__function_inst_map__[func['name']] - ) + # for func in plugin_host.__callable_functions__: + # self.tool_mgr.register_legacy_function( + # name=func['name'], + # description=func['description'], + # parameters=func['parameters'], + # func=plugin_host.__function_inst_map__[func['name']] + # ) async def run(self): + await self.plugin_mgr.load_plugins() tasks = [ asyncio.create_task(self.im_mgr.run()), diff --git a/pkg/core/boot.py b/pkg/core/boot.py index 9bebc526..1332e486 100644 --- a/pkg/core/boot.py +++ b/pkg/core/boot.py @@ -13,17 +13,15 @@ from . import pool from . import controller from ..pipeline import stagemgr from ..audit import identifier -from ..database import manager as db_mgr from ..provider.session import sessionmgr as llm_session_mgr from ..provider.requester import modelmgr as llm_model_mgr from ..provider.sysprompt import sysprompt as llm_prompt_mgr from ..provider.tools import toolmgr as llm_tool_mgr from ..platform import manager as im_mgr from ..command import cmdmgr -from ..plugin import host as plugin_host +from ..plugin import manager as plugin_mgr from ..utils.center import v2 as center_v2 -from ..utils import updater -from ..utils import context +from ..utils import version, proxy use_override = False @@ -58,7 +56,6 @@ async def make_app() -> app.Application: "config.py", "config-template.py" ) - context.set_config_manager(cfg_mgr) cfg = cfg_mgr.data # 检查是否携带了 --override 或 -r 参数 @@ -87,11 +84,20 @@ async def make_app() -> app.Application: ap.query_pool = pool.QueryPool() + proxy_mgr = proxy.ProxyManager(ap) + await proxy_mgr.initialize() + ap.proxy_mgr = proxy_mgr + + ver_mgr = version.VersionManager(ap) + await ver_mgr.initialize() + ap.ver_mgr = ver_mgr + center_v2_api = center_v2.V2CenterAPI( + ap, basic_info={ "host_id": identifier.identifier['host_id'], "instance_id": identifier.identifier['instance_id'], - "semantic_version": updater.get_current_tag(), + "semantic_version": ver_mgr.get_current_version(), "platform": sys.platform, }, runtime_info={ @@ -99,12 +105,7 @@ async def make_app() -> app.Application: "msg_source": cfg['msg_source_adapter'], } ) - ap.ctr_mgr = center_v2_api - - db_mgr_inst = db_mgr.DatabaseManager(ap) - # TODO make it async - db_mgr_inst.initialize_database() - ap.db_mgr = db_mgr_inst + # ap.ctr_mgr = center_v2_api cmd_mgr_inst = cmdmgr.CommandManager(ap) await cmd_mgr_inst.initialize() @@ -138,7 +139,9 @@ async def make_app() -> app.Application: ap.ctrl = ctrl # TODO make it async - plugin_host.load_plugins() + plugin_mgr_inst = plugin_mgr.PluginManager(ap) + await plugin_mgr_inst.initialize() + ap.plugin_mgr = plugin_mgr_inst await ap.initialize() diff --git a/pkg/core/entities.py b/pkg/core/entities.py index 505112ff..b7dccf30 100644 --- a/pkg/core/entities.py +++ b/pkg/core/entities.py @@ -2,10 +2,17 @@ from __future__ import annotations import enum import typing +import datetime +import asyncio import pydantic import mirai +from ..provider import entities as llm_entities +from ..provider.requester import entities +from ..provider.sysprompt import entities as sysprompt_entities +from ..provider.tools import entities as tools_entities + class LauncherTypes(enum.Enum): @@ -39,3 +46,43 @@ class Query(pydantic.BaseModel): resp_message_chain: typing.Optional[mirai.MessageChain] = None """回复消息链""" + + +class Conversation(pydantic.BaseModel): + """对话""" + + prompt: sysprompt_entities.Prompt + + messages: list[llm_entities.Message] + + create_time: typing.Optional[datetime.datetime] = pydantic.Field(default_factory=datetime.datetime.now) + + update_time: typing.Optional[datetime.datetime] = pydantic.Field(default_factory=datetime.datetime.now) + + use_model: entities.LLMModelInfo + + use_funcs: typing.Optional[list[tools_entities.LLMFunction]] + + +class Session(pydantic.BaseModel): + """会话""" + launcher_type: LauncherTypes + + launcher_id: int + + sender_id: typing.Optional[int] = 0 + + use_prompt_name: typing.Optional[str] = 'default' + + using_conversation: typing.Optional[Conversation] = None + + conversations: typing.Optional[list[Conversation]] = [] + + create_time: typing.Optional[datetime.datetime] = pydantic.Field(default_factory=datetime.datetime.now) + + update_time: typing.Optional[datetime.datetime] = pydantic.Field(default_factory=datetime.datetime.now) + + semaphore: typing.Optional[asyncio.Semaphore] = None + + class Config: + arbitrary_types_allowed = True diff --git a/pkg/platform/manager.py b/pkg/platform/manager.py index e088954b..521cc4c0 100644 --- a/pkg/platform/manager.py +++ b/pkg/platform/manager.py @@ -12,7 +12,6 @@ import func_timeout from ..provider import session as openai_session -from ..utils import context import tips as tips_custom from ..platform import adapter as msadapter from .ratelim import ratelim @@ -40,7 +39,7 @@ class QQBotManager: async def initialize(self): await self.ratelimiter.initialize() - config = context.get_config_manager().data + config = self.ap.cfg_mgr.data logging.debug("Use adapter:" + config['msg_source_adapter']) if config['msg_source_adapter'] == 'yirimirai': @@ -106,7 +105,7 @@ class QQBotManager: ) async def send(self, event, msg, check_quote=True, check_at_sender=True): - config = context.get_config_manager().data + config = self.ap.cfg_mgr.data if check_at_sender and config['at_sender']: msg.insert( @@ -134,7 +133,7 @@ class QQBotManager: await self.notify_admin_message_chain(MessageChain([Plain("[bot]{}".format(message))])) async def notify_admin_message_chain(self, message: mirai.MessageChain): - config = context.get_config_manager().data + config = self.ap.cfg_mgr.data if config['admin_qq'] != 0 and config['admin_qq'] != []: logging.info("通知管理员:{}".format(message)) diff --git a/pkg/plugin/context.py b/pkg/plugin/context.py new file mode 100644 index 00000000..a982232f --- /dev/null +++ b/pkg/plugin/context.py @@ -0,0 +1,207 @@ +from __future__ import annotations + +import typing +import abc +import pydantic + +from . import events +from ..provider.tools import entities as tools_entities +from ..core import app + + +class BasePlugin(metaclass=abc.ABCMeta): + """插件基类""" + + host: APIHost + + +class APIHost: + """QChatGPT API 宿主""" + + ap: app.Application + + def __init__(self, ap: app.Application): + self.ap = ap + + async def initialize(self): + pass + + def require_ver( + self, + ge: str, + le: str='v999.999.999', + ) -> bool: + """插件版本要求装饰器 + + Args: + ge (str): 最低版本要求 + le (str, optional): 最高版本要求 + + Returns: + bool: 是否满足要求, False时为无法获取版本号,True时为满足要求,报错为不满足要求 + """ + qchatgpt_version = "" + + try: + qchatgpt_version = self.ap.ver_mgr.get_current_version() # 从updater模块获取版本号 + except: + return False + + if self.ap.ver_mgr.compare_version_str(qchatgpt_version, ge) < 0 or \ + (self.ap.ver_mgr.compare_version_str(qchatgpt_version, le) > 0): + raise Exception("QChatGPT 版本不满足要求,某些功能(可能是由插件提供的)无法正常使用。(要求版本:{}-{},但当前版本:{})".format(ge, le, qchatgpt_version)) + + return True + + +class EventContext: + """事件上下文, 保存此次事件运行的信息""" + + eid = 0 + """事件编号""" + + host: APIHost = None + + event: events.BaseEventModel = None + + __prevent_default__ = False + """是否阻止默认行为""" + + __prevent_postorder__ = False + """是否阻止后续插件的执行""" + + __return_value__ = {} + """ 返回值 + 示例: + { + "example": [ + 'value1', + 'value2', + 3, + 4, + { + 'key1': 'value1', + }, + ['value1', 'value2'] + ] + } + """ + + def add_return(self, key: str, ret): + """添加返回值""" + if key not in self.__return_value__: + self.__return_value__[key] = [] + self.__return_value__[key].append(ret) + + def get_return(self, key: str) -> list: + """获取key的所有返回值""" + if key in self.__return_value__: + return self.__return_value__[key] + return None + + def get_return_value(self, key: str): + """获取key的首个返回值""" + if key in self.__return_value__: + return self.__return_value__[key][0] + return None + + def prevent_default(self): + """阻止默认行为""" + self.__prevent_default__ = True + + def prevent_postorder(self): + """阻止后续插件执行""" + self.__prevent_postorder__ = True + + def is_prevented_default(self): + """是否阻止默认行为""" + return self.__prevent_default__ + + def is_prevented_postorder(self): + """是否阻止后序插件执行""" + return self.__prevent_postorder__ + + def __init__(self, host: APIHost, event: events.BaseEventModel): + + self.eid = EventContext.eid + self.host = host + self.event = event + self.__prevent_default__ = False + self.__prevent_postorder__ = False + self.__return_value__ = {} + EventContext.eid += 1 + + +class RuntimeContainer(pydantic.BaseModel): + """运行时的插件容器 + + 运行期间存储单个插件的信息 + """ + + plugin_name: str + """插件名称""" + + plugin_description: str + """插件描述""" + + plugin_version: str + """插件版本""" + + plugin_author: str + """插件作者""" + + plugin_source: str + """插件源码地址""" + + main_file: str + """插件主文件路径""" + + pkg_path: str + """插件包路径""" + + plugin_class: typing.Type[BasePlugin] = None + """插件类""" + + enabled: typing.Optional[bool] = True + """是否启用""" + + priority: typing.Optional[int] = 0 + """优先级""" + + plugin_inst: typing.Optional[BasePlugin] = None + """插件实例""" + + event_handlers: dict[typing.Type[events.BaseEventModel], typing.Callable[ + [BasePlugin, EventContext], typing.Awaitable[None] + ]] = {} + """事件处理器""" + + content_functions: list[tools_entities.LLMFunction] = [] + """内容函数""" + + class Config: + arbitrary_types_allowed = True + + def to_setting_dict(self): + return { + 'name': self.plugin_name, + 'description': self.plugin_description, + 'version': self.plugin_version, + 'author': self.plugin_author, + 'source': self.plugin_source, + 'main_file': self.main_file, + 'pkg_path': self.pkg_path, + 'priority': self.priority, + 'enabled': self.enabled, + } + + def set_from_setting_dict( + self, + setting: dict + ): + self.plugin_source = setting['source'] + self.priority = setting['priority'] + self.enabled = setting['enabled'] + + for function in self.content_functions: + function.enable = self.enabled diff --git a/pkg/plugin/errors.py b/pkg/plugin/errors.py new file mode 100644 index 00000000..bd6199e3 --- /dev/null +++ b/pkg/plugin/errors.py @@ -0,0 +1,24 @@ +from __future__ import annotations + + +class PluginSystemError(Exception): + + message: str + + def __init__(self, message: str): + self.message = message + + def __str__(self): + return self.message + + +class PluginNotFoundError(PluginSystemError): + + def __init__(self, message: str): + super().__init__(f"未找到插件: {message}") + + +class PluginInstallerError(PluginSystemError): + + def __init__(self, message: str): + super().__init__(f"安装器操作错误: {message}") diff --git a/pkg/plugin/events.py b/pkg/plugin/events.py new file mode 100644 index 00000000..6b60b233 --- /dev/null +++ b/pkg/plugin/events.py @@ -0,0 +1,96 @@ +from __future__ import annotations + +import typing + +import pydantic +import mirai + +from . import context +from ..core import entities as core_entities + + +class BaseEventModel(pydantic.BaseModel): + + class Config: + arbitrary_types_allowed = True + + +class PersonMessageReceived(BaseEventModel): + """收到任何私聊消息时""" + + launcher_type: str + """发起对象类型(group/person)""" + + launcher_id: int + """发起对象ID(群号/QQ号)""" + + sender_id: int + """发送者ID(QQ号)""" + + message_chain: mirai.MessageChain + + query: core_entities.Query + """此次请求的上下文""" + + +class GroupMessageReceived(BaseEventModel): + """收到任何群聊消息时""" + + launcher_type: str + + launcher_id: int + + sender_id: int + + message_chain: mirai.MessageChain + + query: core_entities.Query + """此次请求的上下文""" + + +class PersonNormalMessageReceived(BaseEventModel): + """判断为应该处理的私聊普通消息时触发""" + + launcher_type: str + + launcher_id: int + + sender_id: int + + text_message: str + + query: core_entities.Query + """此次请求的上下文""" + + alter: typing.Optional[str] = None + """修改后的消息文本""" + + reply: typing.Optional[list] = None + """回复消息组件列表""" + + +class PersonCommandSent(BaseEventModel): + """判断为应该处理的私聊命令时触发""" + + launcher_type: str + + launcher_id: int + + sender_id: int + + command: str + + params: list[str] + + text_message: str + + is_admin: bool + + query: core_entities.Query + """此次请求的上下文""" + + alter: typing.Optional[str] = None + """修改后的完整命令文本""" + + reply: typing.Optional[list] = None + """回复消息组件列表""" diff --git a/pkg/plugin/host.py b/pkg/plugin/host.py index 11602cfe..6149da62 100644 --- a/pkg/plugin/host.py +++ b/pkg/plugin/host.py @@ -1,578 +1,5 @@ -# 插件管理模块 -import asyncio -import logging -import importlib -import os -import pkgutil -import sys -import shutil -import traceback -import time -import re +from . events import * +from . context import EventContext, APIHost as PluginHost -from ..utils import updater as updater -from ..utils import network as network -from ..utils import context as context -from ..plugin import switch as switch -from ..plugin import settings as settings -from ..platform import adapter as msadapter -from ..plugin import metadata as metadata - -from mirai import Mirai -import requests - -from CallingGPT.session.session import Session - -__plugins__ = {} -"""插件列表 - -示例: -{ - "example": { - "path": "plugins/example/main.py", - "enabled: True, - "name": "example", - "description": "example", - "version": "0.0.1", - "author": "RockChinQ", - "class": , - "hooks": { - "person_message": [ - - ] - }, - "instance": None - } -} -""" - -__plugins_order__ = [] -"""插件顺序""" - -__enable_content_functions__ = True -"""是否启用内容函数""" - -__callable_functions__ = [] -"""供GPT调用的函数结构""" - -__function_inst_map__: dict[str, callable] = {} -"""函数名:实例 映射""" - - -def generate_plugin_order(): - """根据__plugin__生成插件初始顺序,无视是否启用""" - global __plugins_order__ - __plugins_order__ = [] - for plugin_name in __plugins__: - __plugins_order__.append(plugin_name) - - -def iter_plugins(): - """按照顺序迭代插件""" - for plugin_name in __plugins_order__: - if plugin_name not in __plugins__: - continue - yield __plugins__[plugin_name] - - -def iter_plugins_name(): - """迭代插件名""" - for plugin_name in __plugins_order__: - yield plugin_name - - -__current_module_path__ = "" - - -def walk_plugin_path(module, prefix="", path_prefix=""): - global __current_module_path__ - """遍历插件路径""" - for item in pkgutil.iter_modules(module.__path__): - if item.ispkg: - logging.debug("扫描插件包: plugins/{}".format(path_prefix + item.name)) - walk_plugin_path( - __import__(module.__name__ + "." + item.name, fromlist=[""]), - prefix + item.name + ".", - path_prefix + item.name + "/", - ) - else: - try: - logging.debug( - "扫描插件模块: plugins/{}".format(path_prefix + item.name + ".py") - ) - __current_module_path__ = "plugins/" + path_prefix + item.name + ".py" - - importlib.import_module(module.__name__ + "." + item.name) - logging.debug( - "加载模块: plugins/{} 成功".format(path_prefix + item.name + ".py") - ) - except: - logging.error( - "加载模块: plugins/{} 失败: {}".format( - path_prefix + item.name + ".py", sys.exc_info() - ) - ) - traceback.print_exc() - - -def load_plugins(): - """加载插件""" - logging.debug("加载插件") - PluginHost() - walk_plugin_path(__import__("plugins")) - - logging.debug(__plugins__) - - # 加载开关数据 - switch.load_switch() - - # 生成初始顺序 - generate_plugin_order() - # 加载插件顺序 - settings.load_settings() - - logging.debug("registered plugins: {}".format(__plugins__)) - - # 输出已注册的内容函数列表 - logging.debug("registered content functions: {}".format(__callable_functions__)) - logging.debug("function instance map: {}".format(__function_inst_map__)) - - # 迁移插件源地址记录 - metadata.do_plugin_git_repo_migrate() - - -def initialize_plugins(): - """初始化插件""" - logging.debug("初始化插件") - import pkg.plugin.models as models - - successfully_initialized_plugins = [] - - for plugin in iter_plugins(): - # if not plugin['enabled']: - # continue - try: - models.__current_registering_plugin__ = plugin["name"] - plugin["instance"] = plugin["class"](plugin_host=context.get_plugin_host()) - # logging.info("插件 {} 已初始化".format(plugin['name'])) - successfully_initialized_plugins.append(plugin["name"]) - except: - logging.error("插件{}初始化时发生错误: {}".format(plugin["name"], sys.exc_info())) - logging.debug(traceback.format_exc()) - - logging.info("以下插件已初始化: {}".format(", ".join(successfully_initialized_plugins))) - - -def unload_plugins(): - """卸载插件""" - # 不再显式卸载插件,因为当程序结束时,插件的析构函数会被系统执行 - # for plugin in __plugins__.values(): - # if plugin['enabled'] and plugin['instance'] is not None: - # if not hasattr(plugin['instance'], '__del__'): - # logging.warning("插件{}没有定义析构函数".format(plugin['name'])) - # else: - # try: - # plugin['instance'].__del__() - # logging.info("卸载插件: {}".format(plugin['name'])) - # plugin['instance'] = None - # except: - # logging.error("插件{}卸载时发生错误: {}".format(plugin['name'], sys.exc_info())) - - -def get_github_plugin_repo_label(repo_url: str) -> list[str]: - """获取username, repo""" - - # 提取 username/repo , 正则表达式 - repo = re.findall( - r"(?:https?://github\.com/|git@github\.com:)([^/]+/[^/]+?)(?:\.git|/|$)", - repo_url, - ) - - if len(repo) > 0: # github - return repo[0].split("/") - else: - return None - - -def download_plugin_source_code(repo_url: str, target_path: str) -> str: - """下载插件源码""" - # 检查源类型 - - # 提取 username/repo , 正则表达式 - repo = get_github_plugin_repo_label(repo_url) - - target_path += repo[1] - - if repo is not None: # github - logging.info("从 GitHub 下载插件源码...") - - zipball_url = f"https://api.github.com/repos/{'/'.join(repo)}/zipball/HEAD" - - zip_resp = requests.get( - url=zipball_url, proxies=network.wrapper_proxies(), stream=True - ) - - if zip_resp.status_code != 200: - raise Exception("下载源码失败: {}".format(zip_resp.text)) - - if os.path.exists("temp/" + target_path): - shutil.rmtree("temp/" + target_path) - - if os.path.exists(target_path): - shutil.rmtree(target_path) - - os.makedirs("temp/" + target_path) - - with open("temp/" + target_path + "/source.zip", "wb") as f: - for chunk in zip_resp.iter_content(chunk_size=1024): - if chunk: - f.write(chunk) - - logging.info("下载完成, 解压...") - import zipfile - - with zipfile.ZipFile("temp/" + target_path + "/source.zip", "r") as zip_ref: - zip_ref.extractall("temp/" + target_path) - os.remove("temp/" + target_path + "/source.zip") - - # 目标是 username-repo-hash , 用正则表达式提取完整的文件夹名,复制到 plugins/repo - import glob - - # 获取解压后的文件夹名 - unzip_dir = glob.glob("temp/" + target_path + "/*")[0] - - # 复制到 plugins/repo - shutil.copytree(unzip_dir, target_path + "/") - - # 删除解压后的文件夹 - shutil.rmtree(unzip_dir) - - logging.info("解压完成") - else: - raise Exception("暂不支持的源类型,请使用 GitHub 仓库发行插件。") - - return repo[1] - - -def check_requirements(path: str): - # 检查此目录是否包含requirements.txt - if os.path.exists(path + "/requirements.txt"): - logging.info("检测到requirements.txt,正在安装依赖") - import pkg.utils.pkgmgr - - pkg.utils.pkgmgr.install_requirements(path + "/requirements.txt") - - import pkg.utils.log as log - - log.reset_logging() - - -def install_plugin(repo_url: str): - """安装插件,从git储存库获取并解决依赖""" - - repo_label = download_plugin_source_code(repo_url, "plugins/") - - check_requirements("plugins/" + repo_label) - - metadata.set_plugin_metadata(repo_label, repo_url, int(time.time()), "HEAD") - - # 上报安装记录 - context.get_center_v2_api().plugin.post_install_record( - plugin={ - "name": "unknown", - "remote": repo_url, - "author": "unknown", - "version": "HEAD", - } - ) - - -def uninstall_plugin(plugin_name: str) -> str: - """卸载插件""" - if plugin_name not in __plugins__: - raise Exception("插件不存在") - - plugin_info = get_plugin_info_for_audit(plugin_name) - - # 获取文件夹路径 - plugin_path = __plugins__[plugin_name]["path"].replace("\\", "/") - - # 剪切路径为plugins/插件名 - plugin_path = plugin_path.split("plugins/")[1].split("/")[0] - - # 删除文件夹 - shutil.rmtree("plugins/" + plugin_path) - - # 上报卸载记录 - context.get_center_v2_api().plugin.post_remove_record( - plugin=plugin_info - ) - - return "plugins/" + plugin_path - - -def update_plugin(plugin_name: str): - """更新插件""" - # 检查是否有远程地址记录 - plugin_path_name = get_plugin_path_name_by_plugin_name(plugin_name) - - meta = metadata.get_plugin_metadata(plugin_path_name) - - if meta == {}: - raise Exception("没有此插件元数据信息,无法更新") - - old_plugin_info = get_plugin_info_for_audit(plugin_name) - - context.get_center_v2_api().plugin.post_update_record( - plugin=old_plugin_info, - old_version=old_plugin_info['version'], - new_version='HEAD', - ) - - remote_url = meta["source"] - if ( - remote_url == "https://github.com/RockChinQ/QChatGPT" - or remote_url == "https://gitee.com/RockChin/QChatGPT" - or remote_url == "" - or remote_url is None - or remote_url == "http://github.com/RockChinQ/QChatGPT" - or remote_url == "http://gitee.com/RockChin/QChatGPT" - ): - raise Exception("插件没有远程地址记录,无法更新") - - # 重新安装插件 - logging.info("正在重新安装插件以进行更新...") - - install_plugin(remote_url) - - -def get_plugin_name_by_path_name(plugin_path_name: str) -> str: - for k, v in __plugins__.items(): - if v["path"] == "plugins/" + plugin_path_name + "/main.py": - return k - return None - - -def get_plugin_path_name_by_plugin_name(plugin_name: str) -> str: - if plugin_name not in __plugins__: - return None - - plugin_main_module_path = __plugins__[plugin_name]["path"] - - plugin_main_module_path = plugin_main_module_path.replace("\\", "/") - - spt = plugin_main_module_path.split("/") - - return spt[1] - - -def get_plugin_info_for_audit(plugin_name: str) -> dict: - """获取插件信息""" - if plugin_name not in __plugins__: - return {} - plugin = __plugins__[plugin_name] - - name = plugin["name"] - meta = metadata.get_plugin_metadata(get_plugin_path_name_by_plugin_name(name)) - remote = meta["source"] if meta != {} else "" - author = plugin["author"] - version = plugin["version"] - - return { - "name": name, - "remote": remote, - "author": author, - "version": version, - } - - -class EventContext: - """事件上下文""" - - eid = 0 - """事件编号""" - - name = "" - - __prevent_default__ = False - """是否阻止默认行为""" - - __prevent_postorder__ = False - """是否阻止后续插件的执行""" - - __return_value__ = {} - """ 返回值 - 示例: - { - "example": [ - 'value1', - 'value2', - 3, - 4, - { - 'key1': 'value1', - }, - ['value1', 'value2'] - ] - } - """ - - def add_return(self, key: str, ret): - """添加返回值""" - if key not in self.__return_value__: - self.__return_value__[key] = [] - self.__return_value__[key].append(ret) - - def get_return(self, key: str) -> list: - """获取key的所有返回值""" - if key in self.__return_value__: - return self.__return_value__[key] - return None - - def get_return_value(self, key: str): - """获取key的首个返回值""" - if key in self.__return_value__: - return self.__return_value__[key][0] - return None - - def prevent_default(self): - """阻止默认行为""" - self.__prevent_default__ = True - - def prevent_postorder(self): - """阻止后续插件执行""" - self.__prevent_postorder__ = True - - def is_prevented_default(self): - """是否阻止默认行为""" - return self.__prevent_default__ - - def is_prevented_postorder(self): - """是否阻止后序插件执行""" - return self.__prevent_postorder__ - - def __init__(self, name: str): - self.name = name - self.eid = EventContext.eid - self.__prevent_default__ = False - self.__prevent_postorder__ = False - self.__return_value__ = {} - EventContext.eid += 1 - - -def emit(event_name: str, **kwargs) -> EventContext: - """触发事件""" - import pkg.utils.context as context - - if context.get_plugin_host() is None: - return None - return context.get_plugin_host().emit(event_name, **kwargs) - - -class PluginHost: - """插件宿主""" - - def __init__(self): - """初始化插件宿主""" - context.set_plugin_host(self) - self.calling_gpt_session = Session([]) - - def get_runtime_context(self) -> context: - """获取运行时上下文(pkg.utils.context模块的对象) - - 此上下文用于和主程序其他模块交互(数据库、QQ机器人、OpenAI接口等) - 详见pkg.utils.context模块 - 其中的context变量保存了其他重要模块的类对象,可以使用这些对象进行交互 - """ - return context - - def get_bot(self) -> Mirai: - """获取机器人对象""" - return context.get_qqbot_manager().bot - - def get_bot_adapter(self) -> msadapter.MessageSourceAdapter: - """获取消息源适配器""" - return context.get_qqbot_manager().adapter - - def send_person_message(self, person, message): - """发送私聊消息""" - self.get_bot_adapter().send_message("person", person, message) - - def send_group_message(self, group, message): - """发送群消息""" - self.get_bot_adapter().send_message("group", group, message) - - def notify_admin(self, message): - """通知管理员""" - context.get_qqbot_manager().notify_admin(message) - - def emit(self, event_name: str, **kwargs) -> EventContext: - """触发事件""" - import json - - event_context = EventContext(event_name) - logging.debug("触发事件: {} ({})".format(event_name, event_context.eid)) - - emitted_plugins = [] - for plugin in iter_plugins(): - if not plugin["enabled"]: - continue - - # if plugin['instance'] is None: - # # 从关闭状态切到开启状态之后,重新加载插件 - # try: - # plugin['instance'] = plugin["class"](plugin_host=self) - # logging.info("插件 {} 已初始化".format(plugin['name'])) - # except: - # logging.error("插件 {} 初始化时发生错误: {}".format(plugin['name'], sys.exc_info())) - # continue - - if "hooks" not in plugin or event_name not in plugin["hooks"]: - continue - - emitted_plugins.append(plugin['name']) - - hooks = [] - if event_name in plugin["hooks"]: - hooks = plugin["hooks"][event_name] - for hook in hooks: - try: - already_prevented_default = event_context.is_prevented_default() - - kwargs["host"] = context.get_plugin_host() - kwargs["event"] = event_context - - hook(plugin["instance"], **kwargs) - - if ( - event_context.is_prevented_default() - and not already_prevented_default - ): - logging.debug( - "插件 {} 已要求阻止事件 {} 的默认行为".format(plugin["name"], event_name) - ) - - except Exception as e: - logging.error("插件{}响应事件{}时发生错误".format(plugin["name"], event_name)) - logging.error(traceback.format_exc()) - - # print("done:{}".format(plugin['name'])) - if event_context.is_prevented_postorder(): - logging.debug("插件 {} 阻止了后序插件的执行".format(plugin["name"])) - break - - logging.debug( - "事件 {} ({}) 处理完毕,返回值: {}".format( - event_name, event_context.eid, event_context.__return_value__ - ) - ) - - if len(emitted_plugins) > 0: - plugins_info = [get_plugin_info_for_audit(p) for p in emitted_plugins] - - context.get_center_v2_api().usage.post_event_record( - plugins=plugins_info, - event_name=event_name, - ) - - return event_context +def emit(*args, **kwargs): + print('插件调用了已弃用的函数 pkg.plugin.host.emit()') \ No newline at end of file diff --git a/pkg/plugin/installer.py b/pkg/plugin/installer.py new file mode 100644 index 00000000..6a089438 --- /dev/null +++ b/pkg/plugin/installer.py @@ -0,0 +1,45 @@ +from __future__ import annotations + +import typing +import abc + +from ..core import app + + +class PluginInstaller(metaclass=abc.ABCMeta): + + ap: app.Application + + def __init__(self, ap: app.Application): + self.ap = ap + + async def initialize(self): + pass + + @abc.abstractmethod + async def install_plugin( + self, + plugin_source: str, + ): + """安装插件 + """ + raise NotImplementedError + + @abc.abstractmethod + async def uninstall_plugin( + self, + plugin_name: str, + ): + """卸载插件 + """ + raise NotImplementedError + + @abc.abstractmethod + async def update_plugin( + self, + plugin_name: str, + plugin_source: str=None, + ): + """更新插件 + """ + raise NotImplementedError diff --git a/pkg/provider/api/__init__.py b/pkg/plugin/installers/__init__.py similarity index 100% rename from pkg/provider/api/__init__.py rename to pkg/plugin/installers/__init__.py diff --git a/pkg/plugin/installers/github.py b/pkg/plugin/installers/github.py new file mode 100644 index 00000000..8908f181 --- /dev/null +++ b/pkg/plugin/installers/github.py @@ -0,0 +1,137 @@ +from __future__ import annotations + +import re +import os +import shutil +import zipfile + +import requests + +from .. import installer, errors +from ...utils import pkgmgr + + +class GitHubRepoInstaller(installer.PluginInstaller): + + def get_github_plugin_repo_label(self, repo_url: str) -> list[str]: + """获取username, repo""" + + # 提取 username/repo , 正则表达式 + repo = re.findall( + r"(?:https?://github\.com/|git@github\.com:)([^/]+/[^/]+?)(?:\.git|/|$)", + repo_url, + ) + + if len(repo) > 0: # github + return repo[0].split("/") + else: + return None + + async def download_plugin_source_code(self, repo_url: str, target_path: str) -> str: + """下载插件源码""" + # 检查源类型 + + # 提取 username/repo , 正则表达式 + repo = self.get_github_plugin_repo_label(repo_url) + + target_path += repo[1] + + if repo is not None: # github + self.ap.logger.debug("正在下载源码...") + + zipball_url = f"https://api.github.com/repos/{'/'.join(repo)}/zipball/HEAD" + + zip_resp = requests.get( + url=zipball_url, proxies=self.ap.proxy_mgr.get_forward_proxies(), stream=True + ) + + if zip_resp.status_code != 200: + raise Exception("下载源码失败: {}".format(zip_resp.text)) + + if os.path.exists("temp/" + target_path): + shutil.rmtree("temp/" + target_path) + + if os.path.exists(target_path): + shutil.rmtree(target_path) + + os.makedirs("temp/" + target_path) + + with open("temp/" + target_path + "/source.zip", "wb") as f: + for chunk in zip_resp.iter_content(chunk_size=1024): + if chunk: + f.write(chunk) + + self.ap.logger.debug("解压中...") + + with zipfile.ZipFile("temp/" + target_path + "/source.zip", "r") as zip_ref: + zip_ref.extractall("temp/" + target_path) + os.remove("temp/" + target_path + "/source.zip") + + # 目标是 username-repo-hash , 用正则表达式提取完整的文件夹名,复制到 plugins/repo + import glob + + # 获取解压后的文件夹名 + unzip_dir = glob.glob("temp/" + target_path + "/*")[0] + + # 复制到 plugins/repo + shutil.copytree(unzip_dir, target_path + "/") + + # 删除解压后的文件夹 + shutil.rmtree(unzip_dir) + + self.ap.logger.debug("源码下载完成。") + else: + raise errors.PluginInstallerError('仅支持GitHub仓库地址') + + return repo[1] + + async def install_requirements(self, path: str): + if os.path.exists(path + "/requirements.txt"): + pkgmgr.install_requirements(path + "/requirements.txt") + + async def install_plugin( + self, + plugin_source: str, + ): + """安装插件 + """ + repo_label = await self.download_plugin_source_code(plugin_source, "plugins/") + + await self.install_requirements("plugins/" + repo_label) + + await self.ap.plugin_mgr.setting.record_installed_plugin_source( + "plugins/"+repo_label+'/', plugin_source + ) + + async def uninstall_plugin( + self, + plugin_name: str, + ): + """卸载插件 + """ + plugin_container = self.ap.plugin_mgr.get_plugin_by_name(plugin_name) + + if plugin_container is None: + raise errors.PluginInstallerError('插件不存在或未成功加载') + else: + shutil.rmtree(plugin_container.pkg_path) + + async def update_plugin( + self, + plugin_name: str, + plugin_source: str=None, + ): + """更新插件 + """ + plugin_container = self.ap.plugin_mgr.get_plugin_by_name(plugin_name) + + if plugin_container is None: + raise errors.PluginInstallerError('插件不存在或未成功加载') + else: + if plugin_container.plugin_source: + plugin_source = plugin_container.plugin_source + + await self.install_plugin(plugin_source) + + else: + raise errors.PluginInstallerError('插件无源码信息,无法更新') diff --git a/pkg/plugin/loader.py b/pkg/plugin/loader.py new file mode 100644 index 00000000..d74bcde7 --- /dev/null +++ b/pkg/plugin/loader.py @@ -0,0 +1,25 @@ +from __future__ import annotations +from abc import ABCMeta + +import typing +import abc + +from ..core import app +from . import context, events + + +class PluginLoader(metaclass=abc.ABCMeta): + """插件加载器""" + + ap: app.Application + + def __init__(self, ap: app.Application): + self.ap = ap + + async def initialize(self): + pass + + @abc.abstractmethod + async def load_plugins(self) -> list[context.RuntimeContainer]: + pass + diff --git a/pkg/plugin/loaders/__init__.py b/pkg/plugin/loaders/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/pkg/plugin/loaders/legacy.py b/pkg/plugin/loaders/legacy.py new file mode 100644 index 00000000..1ba0e54f --- /dev/null +++ b/pkg/plugin/loaders/legacy.py @@ -0,0 +1,155 @@ +from __future__ import annotations + +import typing +import pkgutil +import importlib +import traceback + +from CallingGPT.entities.namespace import get_func_schema + +from .. import loader, events, context, models, host +from ...core import entities as core_entities +from ...provider.tools import entities as tools_entities + + +class PluginLoader(loader.PluginLoader): + """加载 plugins/ 目录下的插件""" + + _current_pkg_path = '' + + _current_module_path = '' + + _current_container: context.RuntimeContainer = None + + containers: list[context.RuntimeContainer] = [] + + async def initialize(self): + """初始化""" + setattr(models, 'register', self.register) + setattr(models, 'on', self.on) + setattr(models, 'func', self.func) + + def register( + self, + name: str, + description: str, + version: str, + author: str + ) -> typing.Callable[[typing.Type[context.BasePlugin]], typing.Type[context.BasePlugin]]: + self.ap.logger.debug(f'注册插件 {name} {version} by {author}') + container = context.RuntimeContainer( + plugin_name=name, + plugin_description=description, + plugin_version=version, + plugin_author=author, + plugin_source='', + pkg_path=self._current_pkg_path, + main_file=self._current_module_path, + event_handlers={}, + content_functions=[], + ) + + self._current_container = container + + def wrapper(cls: context.BasePlugin) -> typing.Type[context.BasePlugin]: + container.plugin_class = cls + return cls + + return wrapper + + def on( + self, + event: typing.Type[events.BaseEventModel] + ) -> typing.Callable[[typing.Callable], typing.Callable]: + """注册过时的事件处理器""" + self.ap.logger.debug(f'注册事件处理器 {event.__name__}') + def wrapper(func: typing.Callable) -> typing.Callable: + + async def handler(plugin: context.BasePlugin, ctx: context.EventContext) -> None: + args = { + 'host': ctx.host, + 'event': ctx, + } + + # 把 ctx.event 所有的属性都放到 args 里 + for k, v in ctx.event.dict().items(): + args[k] = v + + await func(plugin, **args) + + self._current_container.event_handlers[event] = handler + + return func + + return wrapper + + def func( + self, + name: str=None, + ) -> typing.Callable: + """注册过时的内容函数""" + self.ap.logger.debug(f'注册内容函数 {name}') + def wrapper(func: typing.Callable) -> typing.Callable: + + function_schema = get_func_schema(func) + function_name = self._current_container.plugin_name + '-' + (func.__name__ if name is None else name) + + async def handler( + query: core_entities.Query, + *args, + **kwargs + ): + return func(*args, **kwargs) + + llm_function = tools_entities.LLMFunction( + name=function_name, + human_desc='', + description=function_schema['description'], + enable=True, + parameters=function_schema['parameters'], + func=handler, + ) + + self._current_container.content_functions.append(llm_function) + + return func + + return wrapper + + async def _walk_plugin_path( + self, + module, + prefix='', + path_prefix='' + ): + """遍历插件路径 + """ + for item in pkgutil.iter_modules(module.__path__): + if item.ispkg: + await self._walk_plugin_path( + __import__(module.__name__ + "." + item.name, fromlist=[""]), + prefix + item.name + ".", + path_prefix + item.name + "/", + ) + else: + try: + self._current_pkg_path = "plugins/" + path_prefix + self._current_module_path = "plugins/" + path_prefix + item.name + ".py" + + self._current_container = None + + importlib.import_module(module.__name__ + "." + item.name) + + if self._current_container is not None: + self.containers.append(self._current_container) + self.ap.logger.debug(f'插件 {self._current_container} 已加载') + except: + self.ap.logger.error(f'加载插件模块 {prefix + item.name} 时发生错误') + traceback.print_exc() + + async def load_plugins(self) -> list[context.RuntimeContainer]: + """加载插件 + """ + await self._walk_plugin_path(__import__("plugins", fromlist=[""])) + + return self.containers diff --git a/pkg/plugin/manager.py b/pkg/plugin/manager.py new file mode 100644 index 00000000..6591839d --- /dev/null +++ b/pkg/plugin/manager.py @@ -0,0 +1,112 @@ +from __future__ import annotations + +import typing + +from ..core import app +from . import context, loader, events, installer, setting, models +from .loaders import legacy +from .installers import github + + +class PluginManager: + + ap: app.Application + + loader: loader.PluginLoader + + installer: installer.PluginInstaller + + setting: setting.SettingManager + + api_host: context.APIHost + + plugins: list[context.RuntimeContainer] + + def __init__(self, ap: app.Application): + self.ap = ap + self.loader = legacy.PluginLoader(ap) + self.installer = github.GitHubRepoInstaller(ap) + self.setting = setting.SettingManager(ap) + self.api_host = context.APIHost(ap) + self.plugins = [] + + async def initialize(self): + await self.loader.initialize() + await self.installer.initialize() + await self.setting.initialize() + await self.api_host.initialize() + + setattr(models, 'require_ver', self.api_host.require_ver) + + async def load_plugins(self): + self.plugins = await self.loader.load_plugins() + + await self.setting.sync_setting(self.plugins) + + # 按优先级倒序 + self.plugins.sort(key=lambda x: x.priority, reverse=True) + + async def initialize_plugins(self): + pass + + async def install_plugin( + self, + plugin_source: str, + ): + """安装插件 + """ + await self.installer.install_plugin(plugin_source) + + async def uninstall_plugin( + self, + plugin_name: str, + ): + """卸载插件 + """ + await self.installer.uninstall_plugin(plugin_name) + + async def update_plugin( + self, + plugin_name: str, + plugin_source: str=None, + ): + """更新插件 + """ + await self.installer.update_plugin(plugin_name, plugin_source) + + def get_plugin_by_name(self, plugin_name: str) -> context.RuntimeContainer: + """通过插件名获取插件 + """ + for plugin in self.plugins: + if plugin.plugin_name == plugin_name: + return plugin + return None + + async def emit_event(self, event: events.BaseEventModel) -> context.EventContext: + """触发事件 + """ + + ctx = context.EventContext( + host=self.api_host, + event=event + ) + + for plugin in self.plugins: + if plugin.enabled: + if event.__class__ in plugin.event_handlers: + try: + await plugin.event_handlers[event.__class__]( + plugin.plugin_inst, + ctx + ) + except Exception as e: + self.ap.logger.error(f'插件 {plugin.plugin_name} 触发事件 {event.__class__.__name__} 时发生错误: {e}') + self.ap.logger.exception(e) + + if ctx.is_prevented_postorder(): + self.ap.logger.debug(f'插件 {plugin.plugin_name} 阻止了后序插件的执行') + break + + self.ap.logger.debug(f'事件 {event.__class__.__name__}({ctx.eid}) 处理完成,返回值 {ctx.__return_value__}') + + return ctx \ No newline at end of file diff --git a/pkg/plugin/metadata.py b/pkg/plugin/metadata.py deleted file mode 100644 index 51de742e..00000000 --- a/pkg/plugin/metadata.py +++ /dev/null @@ -1,87 +0,0 @@ -import os -import shutil -import json -import time - -import dulwich.errors as dulwich_err - -from ..utils import updater - - -def read_metadata_file() -> dict: - # 读取 plugins/metadata.json 文件 - if not os.path.exists('plugins/metadata.json'): - return {} - with open('plugins/metadata.json', 'r') as f: - return json.load(f) - - -def write_metadata_file(metadata: dict): - if not os.path.exists('plugins'): - os.mkdir('plugins') - - with open('plugins/metadata.json', 'w') as f: - json.dump(metadata, f, indent=4, ensure_ascii=False) - - -def do_plugin_git_repo_migrate(): - # 仅在 plugins/metadata.json 不存在时执行 - if os.path.exists('plugins/metadata.json'): - return - - metadata = read_metadata_file() - - # 遍历 plugins 下所有目录,获取目录的git远程地址 - for plugin_name in os.listdir('plugins'): - plugin_path = os.path.join('plugins', plugin_name) - if not os.path.isdir(plugin_path): - continue - - remote_url = None - try: - remote_url = updater.get_remote_url(plugin_path) - except dulwich_err.NotGitRepository: - continue - if remote_url == "https://github.com/RockChinQ/QChatGPT" or remote_url == "https://gitee.com/RockChin/QChatGPT" \ - or remote_url == "" or remote_url is None or remote_url == "http://github.com/RockChinQ/QChatGPT" or remote_url == "http://gitee.com/RockChin/QChatGPT": - continue - - from . import host - - if plugin_name not in metadata: - metadata[plugin_name] = { - 'source': remote_url, - 'install_timestamp': int(time.time()), - 'ref': 'HEAD', - } - - write_metadata_file(metadata) - - -def set_plugin_metadata( - plugin_name: str, - source: str, - install_timestamp: int, - ref: str, -): - metadata = read_metadata_file() - metadata[plugin_name] = { - 'source': source, - 'install_timestamp': install_timestamp, - 'ref': ref, - } - write_metadata_file(metadata) - - -def remove_plugin_metadata(plugin_name: str): - metadata = read_metadata_file() - if plugin_name in metadata: - del metadata[plugin_name] - write_metadata_file(metadata) - - -def get_plugin_metadata(plugin_name: str) -> dict: - metadata = read_metadata_file() - if plugin_name in metadata: - return metadata[plugin_name] - return {} \ No newline at end of file diff --git a/pkg/plugin/models.py b/pkg/plugin/models.py index a606612d..19580309 100644 --- a/pkg/plugin/models.py +++ b/pkg/plugin/models.py @@ -1,299 +1 @@ -import logging - -from ..plugin import host -from ..utils import context - -PersonMessageReceived = "person_message_received" -"""收到私聊消息时,在判断是否应该响应前触发 - kwargs: - launcher_type: str 发起对象类型(group/person) - launcher_id: int 发起对象ID(群号/QQ号) - sender_id: int 发送者ID(QQ号) - message_chain: mirai.models.message.MessageChain 消息链 -""" - -GroupMessageReceived = "group_message_received" -"""收到群聊消息时,在判断是否应该响应前触发(所有群消息) - kwargs: - launcher_type: str 发起对象类型(group/person) - launcher_id: int 发起对象ID(群号/QQ号) - sender_id: int 发送者ID(QQ号) - message_chain: mirai.models.message.MessageChain 消息链 -""" - -PersonNormalMessageReceived = "person_normal_message_received" -"""判断为应该处理的私聊普通消息时触发 - kwargs: - launcher_type: str 发起对象类型(group/person) - launcher_id: int 发起对象ID(群号/QQ号) - sender_id: int 发送者ID(QQ号) - text_message: str 消息文本 - - returns (optional): - alter: str 修改后的消息文本 - reply: list 回复消息组件列表 -""" - -PersonCommandSent = "person_command_sent" -"""判断为应该处理的私聊命令时触发 - kwargs: - launcher_type: str 发起对象类型(group/person) - launcher_id: int 发起对象ID(群号/QQ号) - sender_id: int 发送者ID(QQ号) - command: str 命令 - params: list[str] 参数列表 - text_message: str 完整命令文本 - is_admin: bool 是否为管理员 - - returns (optional): - alter: str 修改后的完整命令文本 - reply: list 回复消息组件列表 -""" - -GroupNormalMessageReceived = "group_normal_message_received" -"""判断为应该处理的群聊普通消息时触发 - kwargs: - launcher_type: str 发起对象类型(group/person) - launcher_id: int 发起对象ID(群号/QQ号) - sender_id: int 发送者ID(QQ号) - text_message: str 消息文本 - - returns (optional): - alter: str 修改后的消息文本 - reply: list 回复消息组件列表 -""" - -GroupCommandSent = "group_command_sent" -"""判断为应该处理的群聊命令时触发 - kwargs: - launcher_type: str 发起对象类型(group/person) - launcher_id: int 发起对象ID(群号/QQ号) - sender_id: int 发送者ID(QQ号) - command: str 命令 - params: list[str] 参数列表 - text_message: str 完整命令文本 - is_admin: bool 是否为管理员 - - returns (optional): - alter: str 修改后的完整命令文本 - reply: list 回复消息组件列表 -""" - -NormalMessageResponded = "normal_message_responded" -"""获取到对普通消息的文字响应时触发 - kwargs: - launcher_type: str 发起对象类型(group/person) - launcher_id: int 发起对象ID(群号/QQ号) - sender_id: int 发送者ID(QQ号) - session: pkg.openai.session.Session 会话对象 - prefix: str 回复文字消息的前缀 - response_text: str 响应文本 - finish_reason: str 响应结束原因 - funcs_called: list[str] 此次响应中调用的函数列表 - - returns (optional): - prefix: str 修改后的回复文字消息的前缀 - reply: list 替换回复消息组件列表 -""" - -SessionFirstMessageReceived = "session_first_message_received" -"""会话被第一次交互时触发 - kwargs: - session_name: str 会话名称(_) - session: pkg.openai.session.Session 会话对象 - default_prompt: str 预设值 -""" - -SessionExplicitReset = "session_reset" -"""会话被用户手动重置时触发,此事件不支持阻止默认行为 - kwargs: - session_name: str 会话名称(_) - session: pkg.openai.session.Session 会话对象 -""" - -SessionExpired = "session_expired" -"""会话过期时触发 - kwargs: - session_name: str 会话名称(_) - session: pkg.openai.session.Session 会话对象 - session_expire_time: int 已设置的会话过期时间(秒) -""" - -KeyExceeded = "key_exceeded" -"""api-key超额时触发 - kwargs: - key_name: str 超额的api-key名称 - usage: dict 超额的api-key使用情况 - exceeded_keys: list[str] 超额的api-key列表 -""" - -KeySwitched = "key_switched" -"""api-key超额切换成功时触发,此事件不支持阻止默认行为 - kwargs: - key_name: str 切换成功的api-key名称 - key_list: list[str] api-key列表 -""" - -PromptPreProcessing = "prompt_pre_processing" -"""每回合调用接口前对prompt进行预处理时触发,此事件不支持阻止默认行为 - kwargs: - session_name: str 会话名称(_) - default_prompt: list 此session使用的情景预设内容 - prompt: list 此session现有的prompt内容 - text_message: str 用户发送的消息文本 - - returns (optional): - default_prompt: list 修改后的情景预设内容 - prompt: list 修改后的prompt内容 - text_message: str 修改后的消息文本 -""" - - -def on(*args, **kwargs): - """注册事件监听器 - """ - return Plugin.on(*args, **kwargs) - -def func(*args, **kwargs): - """注册内容函数,声明此函数为一个内容函数,在对话中将发送此函数给GPT以供其调用 - 此函数可以具有任意的参数,但必须按照[此文档](https://github.com/RockChinQ/CallingGPT/wiki/1.-Function-Format#function-format) - 所述的格式编写函数的docstring。 - 此功能仅支持在使用gpt-3.5或gpt-4系列模型时使用。 - """ - return Plugin.func(*args, **kwargs) - - -__current_registering_plugin__ = "" - - -def require_ver(ge: str, le: str="v999.9.9") -> bool: - """插件版本要求装饰器 - - Args: - ge (str): 最低版本要求 - le (str, optional): 最高版本要求 - - Returns: - bool: 是否满足要求, False时为无法获取版本号,True时为满足要求,报错为不满足要求 - """ - qchatgpt_version = "" - - from pkg.utils.updater import get_current_tag, compare_version_str - - try: - qchatgpt_version = get_current_tag() # 从updater模块获取版本号 - except: - return False - - if compare_version_str(qchatgpt_version, ge) < 0 or \ - (compare_version_str(qchatgpt_version, le) > 0): - raise Exception("QChatGPT 版本不满足要求,某些功能(可能是由插件提供的)无法正常使用。(要求版本:{}-{},但当前版本:{})".format(ge, le, qchatgpt_version)) - - return True - - -class Plugin: - """插件基类""" - - host: host.PluginHost - """插件宿主,提供插件的一些基础功能""" - - @classmethod - def on(cls, event): - """事件处理器装饰器 - - :param - event: 事件类型 - :return: - None - """ - global __current_registering_plugin__ - - def wrapper(func): - plugin_hooks = host.__plugins__[__current_registering_plugin__]["hooks"] - - if event not in plugin_hooks: - plugin_hooks[event] = [] - plugin_hooks[event].append(func) - - # print("registering hook: p='{}', e='{}', f={}".format(__current_registering_plugin__, event, func)) - - host.__plugins__[__current_registering_plugin__]["hooks"] = plugin_hooks - - return func - - return wrapper - - @classmethod - def func(cls, name: str=None): - """内容函数装饰器 - """ - global __current_registering_plugin__ - from CallingGPT.entities.namespace import get_func_schema - - def wrapper(func): - - function_schema = get_func_schema(func) - function_schema['name'] = __current_registering_plugin__ + '-' + (func.__name__ if name is None else name) - - function_schema['enabled'] = True - - host.__function_inst_map__[function_schema['name']] = function_schema['function'] - - del function_schema['function'] - - # logging.debug("registering content function: p='{}', f='{}', s={}".format(__current_registering_plugin__, func, function_schema)) - - host.__callable_functions__.append( - function_schema - ) - - return func - - return wrapper - - -def register(name: str, description: str, version: str, author: str): - """注册插件, 此函数作为装饰器使用 - - Args: - name (str): 插件名称 - description (str): 插件描述 - version (str): 插件版本 - author (str): 插件作者 - - Returns: - None - """ - global __current_registering_plugin__ - - __current_registering_plugin__ = name - # print("registering plugin: n='{}', d='{}', v={}, a='{}'".format(name, description, version, author)) - host.__plugins__[name] = { - "name": name, - "description": description, - "version": version, - "author": author, - "hooks": {}, - "path": host.__current_module_path__, - "enabled": True, - "instance": None, - } - - def wrapper(cls: Plugin): - cls.name = name - cls.description = description - cls.version = version - cls.author = author - cls.host = context.get_plugin_host() - cls.enabled = True - cls.path = host.__current_module_path__ - - # 存到插件列表 - host.__plugins__[name]["class"] = cls - - logging.info("插件注册完成: n='{}', d='{}', v={}, a='{}' ({})".format(name, description, version, author, cls)) - - return cls - - return wrapper +from .context import BasePlugin as Plugin \ No newline at end of file diff --git a/pkg/plugin/setting.py b/pkg/plugin/setting.py new file mode 100644 index 00000000..baa1c4d0 --- /dev/null +++ b/pkg/plugin/setting.py @@ -0,0 +1,83 @@ +from __future__ import annotations + +from ..core import app +from ..config import manager as cfg_mgr +from . import context + + +class SettingManager: + + ap: app.Application + + settings: cfg_mgr.ConfigManager + + def __init__(self, ap: app.Application): + self.ap = ap + + async def initialize(self): + self.settings = await cfg_mgr.load_json_config( + 'plugins/plugins.json', + 'res/templates/plugin-setting-template.json' + ) + + async def sync_setting( + self, + plugin_containers: list[context.RuntimeContainer], + ): + """同步设置 + """ + + not_matched_source_record = [] + + for value in self.settings.data['plugins']: + + if 'name' not in value: # 只有远程地址的,应用到pkg_path相同的插件容器上 + matched = False + + for plugin_container in plugin_containers: + if plugin_container.pkg_path == value['pkg_path']: + matched = True + + plugin_container.plugin_source = value['source'] + break + + if not matched: + not_matched_source_record.append(value) + else: # 正常的插件设置 + for plugin_container in plugin_containers: + if plugin_container.plugin_name == value['name']: + plugin_container.set_from_setting_dict(value) + + self.settings.data = { + 'plugins': [ + p.to_setting_dict() + for p in plugin_containers + ] + } + + self.settings.data['plugins'].extend(not_matched_source_record) + + await self.settings.dump_config() + + async def record_installed_plugin_source( + self, + pkg_path: str, + source: str + ): + found = False + + for value in self.settings.data['plugins']: + if value['pkg_path'] == pkg_path: + value['source'] = source + found = True + break + + if not found: + + self.settings.data['plugins'].append( + { + 'pkg_path': pkg_path, + 'source': source + } + ) + await self.settings.dump_config() \ No newline at end of file diff --git a/pkg/plugin/settings.py b/pkg/plugin/settings.py deleted file mode 100644 index 6824906a..00000000 --- a/pkg/plugin/settings.py +++ /dev/null @@ -1,103 +0,0 @@ -import json -import os - -import logging - -from ..plugin import host - -def wrapper_dict_from_runtime_context() -> dict: - """从变量中包装settings.json的数据字典""" - settings = { - "order": [], - "functions": { - "enabled": host.__enable_content_functions__ - } - } - - for plugin_name in host.__plugins_order__: - settings["order"].append(plugin_name) - - return settings - - -def apply_settings(settings: dict): - """将settings.json数据应用到变量中""" - if "order" in settings: - host.__plugins_order__ = settings["order"] - - if "functions" in settings: - if "enabled" in settings["functions"]: - host.__enable_content_functions__ = settings["functions"]["enabled"] - # logging.debug("set content function enabled: {}".format(host.__enable_content_functions__)) - - -def dump_settings(): - """保存settings.json数据""" - logging.debug("保存plugins/settings.json数据") - - settings = wrapper_dict_from_runtime_context() - - with open("plugins/settings.json", "w", encoding="utf-8") as f: - json.dump(settings, f, indent=4, ensure_ascii=False) - - -def load_settings(): - """加载settings.json数据""" - logging.debug("加载plugins/settings.json数据") - - # 读取plugins/settings.json - settings = { - } - - # 检查文件是否存在 - if not os.path.exists("plugins/settings.json"): - # 不存在则创建 - with open("plugins/settings.json", "w", encoding="utf-8") as f: - json.dump(wrapper_dict_from_runtime_context(), f, indent=4, ensure_ascii=False) - - with open("plugins/settings.json", "r", encoding="utf-8") as f: - settings = json.load(f) - - if settings is None: - settings = { - } - - # 检查每个设置项 - if "order" not in settings: - settings["order"] = [] - - settings_modified = False - - settings_copy = settings.copy() - - # 检查settings中多余的插件项 - - # order - for plugin_name in settings_copy["order"]: - if plugin_name not in host.__plugins_order__: - settings["order"].remove(plugin_name) - settings_modified = True - - # 检查settings中缺少的插件项 - - # order - for plugin_name in host.__plugins_order__: - if plugin_name not in settings_copy["order"]: - settings["order"].append(plugin_name) - settings_modified = True - - if "functions" not in settings: - settings["functions"] = { - "enabled": host.__enable_content_functions__ - } - settings_modified = True - elif "enabled" not in settings["functions"]: - settings["functions"]["enabled"] = host.__enable_content_functions__ - settings_modified = True - - logging.info("已全局{}内容函数。".format("启用" if settings["functions"]["enabled"] else "禁用")) - - apply_settings(settings) - - if settings_modified: - dump_settings() diff --git a/pkg/plugin/switch.py b/pkg/plugin/switch.py deleted file mode 100644 index ccc96c8c..00000000 --- a/pkg/plugin/switch.py +++ /dev/null @@ -1,94 +0,0 @@ -# 控制插件的开关 -import json -import logging -import os - -from ..plugin import host - - -def wrapper_dict_from_plugin_list() -> dict: - """将插件列表转换为开关json""" - switch = {} - - for plugin_name in host.__plugins__: - plugin = host.__plugins__[plugin_name] - - switch[plugin_name] = { - "path": plugin["path"], - "enabled": plugin["enabled"], - } - - return switch - - -def apply_switch(switch: dict): - """将开关数据应用到插件列表中""" - # print("将开关数据应用到插件列表中") - # print(switch) - for plugin_name in switch: - host.__plugins__[plugin_name]["enabled"] = switch[plugin_name]["enabled"] - - # 查找此插件的所有内容函数 - for func in host.__callable_functions__: - if func['name'].startswith(plugin_name + '-'): - func['enabled'] = switch[plugin_name]["enabled"] - - -def dump_switch(): - """保存开关数据""" - logging.debug("保存开关数据") - # 将开关数据写入plugins/switch.json - - switch = wrapper_dict_from_plugin_list() - - with open("plugins/switch.json", "w", encoding="utf-8") as f: - json.dump(switch, f, indent=4, ensure_ascii=False) - - -def load_switch(): - """加载开关数据""" - logging.debug("加载开关数据") - # 读取plugins/switch.json - - switch = {} - - # 检查文件是否存在 - if not os.path.exists("plugins/switch.json"): - # 不存在则创建 - with open("plugins/switch.json", "w", encoding="utf-8") as f: - json.dump(switch, f, indent=4, ensure_ascii=False) - - with open("plugins/switch.json", "r", encoding="utf-8") as f: - switch = json.load(f) - - if switch is None: - switch = {} - - switch_modified = False - - switch_copy = switch.copy() - # 检查switch中多余的和path不相符的 - for plugin_name in switch_copy: - if plugin_name not in host.__plugins__: - del switch[plugin_name] - switch_modified = True - elif switch[plugin_name]["path"] != host.__plugins__[plugin_name]["path"]: - # 删除此不相符的 - del switch[plugin_name] - switch_modified = True - - # 检查plugin中多余的 - for plugin_name in host.__plugins__: - if plugin_name not in switch: - switch[plugin_name] = { - "path": host.__plugins__[plugin_name]["path"], - "enabled": host.__plugins__[plugin_name]["enabled"], - } - switch_modified = True - - # 应用开关数据 - apply_switch(switch) - - # 如果switch有修改,保存 - if switch_modified: - dump_switch() diff --git a/pkg/provider/api/chat_completion.py b/pkg/provider/api/chat_completion.py deleted file mode 100644 index 1e0e1bc5..00000000 --- a/pkg/provider/api/chat_completion.py +++ /dev/null @@ -1,232 +0,0 @@ -import json -import logging - -import openai -from openai.types.chat import chat_completion_message - -from .model import RequestBase -from .. import funcmgr -from ...plugin import host -from ...utils import context - - -class ChatCompletionRequest(RequestBase): - """调用ChatCompletion接口的请求类。 - - 此类保证每一次返回的角色为assistant的信息的finish_reason一定为stop。 - 若有函数调用响应,本类的返回瀑布是:函数调用请求->函数调用结果->...->assistant的信息->stop。 - """ - - model: str - messages: list[dict[str, str]] - kwargs: dict - - stopped: bool = False - - pending_func_call: chat_completion_message.FunctionCall = None - - pending_msg: str - - def flush_pending_msg(self): - self.append_message( - role="assistant", - content=self.pending_msg - ) - self.pending_msg = "" - - def append_message(self, role: str, content: str, name: str=None, function_call: dict=None): - msg = { - "role": role, - "content": content - } - - if name is not None: - msg['name'] = name - - if function_call is not None: - msg['function_call'] = function_call - - self.messages.append(msg) - - def __init__( - self, - client: openai.Client, - model: str, - messages: list[dict[str, str]], - **kwargs - ): - self.client = client - self.model = model - self.messages = messages.copy() - - self.kwargs = kwargs - - self.req_func = self.client.chat.completions.create - - self.pending_func_call = None - - self.stopped = False - - self.pending_msg = "" - - def __iter__(self): - return self - - def __next__(self) -> dict: - if self.stopped: - raise StopIteration() - - if self.pending_func_call is None: # 没有待处理的函数调用请求 - - args = { - "model": self.model, - "messages": self.messages, - } - - funcs = funcmgr.get_func_schema_list() - - if len(funcs) > 0: - args['functions'] = funcs - - # 拼接kwargs - args = {**args, **self.kwargs} - - from openai.types.chat import chat_completion - - resp: chat_completion.ChatCompletion = self._req(**args) - - choice0 = resp.choices[0] - - # 如果不是函数调用,且finish_reason为stop,则停止迭代 - if choice0.finish_reason == 'stop': # and choice0["finish_reason"] == "stop" - self.stopped = True - - if hasattr(choice0.message, 'function_call') and choice0.message.function_call is not None: - self.pending_func_call = choice0.message.function_call - - self.append_message( - role="assistant", - content=choice0.message.content, - function_call=choice0.message.function_call - ) - - return { - "id": resp.id, - "choices": [ - { - "index": choice0.index, - "message": { - "role": "assistant", - "type": "function_call", - "content": choice0.message.content, - "function_call": { - "name": choice0.message.function_call.name, - "arguments": choice0.message.function_call.arguments - } - }, - "finish_reason": "function_call" - } - ], - "usage": { - "prompt_tokens": resp.usage.prompt_tokens, - "completion_tokens": resp.usage.completion_tokens, - "total_tokens": resp.usage.total_tokens - } - } - else: - - # self.pending_msg += choice0['message']['content'] - # 普通回复一定处于最后方,故不用再追加进内部messages - - return { - "id": resp.id, - "choices": [ - { - "index": choice0.index, - "message": { - "role": "assistant", - "type": "text", - "content": choice0.message.content - }, - "finish_reason": choice0.finish_reason - } - ], - "usage": { - "prompt_tokens": resp.usage.prompt_tokens, - "completion_tokens": resp.usage.completion_tokens, - "total_tokens": resp.usage.total_tokens - } - } - else: # 处理函数调用请求 - - cp_pending_func_call = self.pending_func_call.copy() - - self.pending_func_call = None - - func_name = cp_pending_func_call.name - arguments = {} - - try: - - try: - arguments = json.loads(cp_pending_func_call.arguments) - # 若不是json格式的异常处理 - except json.decoder.JSONDecodeError: - # 获取函数的参数列表 - func_schema = funcmgr.get_func_schema(func_name) - - arguments = { - func_schema['parameters']['required'][0]: cp_pending_func_call.arguments - } - - logging.info("执行函数调用: name={}, arguments={}".format(func_name, arguments)) - - # 执行函数调用 - ret = "" - try: - ret = funcmgr.execute_function(func_name, arguments) - - logging.info("函数执行完成。") - except Exception as e: - ret = "error: execute function failed: {}".format(str(e)) - logging.error("函数执行失败: {}".format(str(e))) - - # 上报数据 - plugin_info = host.get_plugin_info_for_audit(func_name.split('-')[0]) - audit_func_name = func_name.split('-')[1] - audit_func_desc = funcmgr.get_func_schema(func_name)['description'] - context.get_center_v2_api().usage.post_function_record( - plugin=plugin_info, - function_name=audit_func_name, - function_description=audit_func_desc, - ) - - self.append_message( - role="function", - content=json.dumps(ret, ensure_ascii=False), - name=func_name - ) - - return { - "id": -1, - "choices": [ - { - "index": -1, - "message": { - "role": "function", - "type": "function_return", - "function_name": func_name, - "content": json.dumps(ret, ensure_ascii=False) - }, - "finish_reason": "function_return" - } - ], - "usage": { - "prompt_tokens": 0, - "completion_tokens": 0, - "total_tokens": 0 - } - } - - except funcmgr.ContentFunctionNotFoundError: - raise Exception("没有找到函数: {}".format(func_name)) diff --git a/pkg/provider/api/completion.py b/pkg/provider/api/completion.py deleted file mode 100644 index d14e91f4..00000000 --- a/pkg/provider/api/completion.py +++ /dev/null @@ -1,100 +0,0 @@ -import openai -from openai.types import completion, completion_choice - -from . import model - - -class CompletionRequest(model.RequestBase): - """调用Completion接口的请求类。 - - 调用方可以一直next completion直到finish_reason为stop。 - """ - - model: str - prompt: str - kwargs: dict - - stopped: bool = False - - def __init__( - self, - client: openai.Client, - model: str, - messages: list[dict[str, str]], - **kwargs - ): - self.client = client - self.model = model - self.prompt = "" - - for message in messages: - self.prompt += message["role"] + ": " + message["content"] + "\n" - - self.prompt += "assistant: " - - self.kwargs = kwargs - - self.req_func = self.client.completions.create - - def __iter__(self): - return self - - def __next__(self) -> dict: - """调用Completion接口,返回生成的文本 - - { - "id": "id", - "choices": [ - { - "index": 0, - "message": { - "role": "assistant", - "type": "text", - "content": "message" - }, - "finish_reason": "reason" - } - ], - "usage": { - "prompt_tokens": 10, - "completion_tokens": 20, - "total_tokens": 30 - } - } - """ - - if self.stopped: - raise StopIteration() - - resp: completion.Completion = self._req( - model=self.model, - prompt=self.prompt, - **self.kwargs - ) - - if resp.choices[0].finish_reason == "stop": - self.stopped = True - - choice0: completion_choice.CompletionChoice = resp.choices[0] - - self.prompt += choice0.text - - return { - "id": resp.id, - "choices": [ - { - "index": choice0.index, - "message": { - "role": "assistant", - "type": "text", - "content": choice0.text - }, - "finish_reason": choice0.finish_reason - } - ], - "usage": { - "prompt_tokens": resp.usage.prompt_tokens, - "completion_tokens": resp.usage.completion_tokens, - "total_tokens": resp.usage.total_tokens - } - } diff --git a/pkg/provider/api/model.py b/pkg/provider/api/model.py deleted file mode 100644 index 0a1f6a3a..00000000 --- a/pkg/provider/api/model.py +++ /dev/null @@ -1,40 +0,0 @@ -# 定义不同接口请求的模型 -import logging - -import openai - -from ...utils import context - - -class RequestBase: - - client: openai.Client - - req_func: callable - - def __init__(self, *args, **kwargs): - raise NotImplementedError - - def _next_key(self): - switched, name = context.get_openai_manager().key_mgr.auto_switch() - logging.debug("切换api-key: switched={}, name={}".format(switched, name)) - self.client.api_key = context.get_openai_manager().key_mgr.get_using_key() - - def _req(self, **kwargs): - """处理代理问题""" - logging.debug("请求接口参数: %s", str(kwargs)) - config = context.get_config_manager().data - - ret = self.req_func(**kwargs) - logging.debug("接口请求返回:%s", str(ret)) - - if config['switch_strategy'] == 'active': - self._next_key() - - return ret - - def __iter__(self): - raise self - - def __next__(self): - raise NotImplementedError diff --git a/pkg/provider/requester/api.py b/pkg/provider/requester/api.py index 5dd0abf2..88d500e6 100644 --- a/pkg/provider/requester/api.py +++ b/pkg/provider/requester/api.py @@ -6,7 +6,6 @@ import typing from ...core import app from ...core import entities as core_entities from .. import entities as llm_entities -from ..session import entities as session_entities class LLMAPIRequester(metaclass=abc.ABCMeta): """LLM API请求器 @@ -24,7 +23,7 @@ class LLMAPIRequester(metaclass=abc.ABCMeta): async def request( self, query: core_entities.Query, - conversation: session_entities.Conversation, + conversation: core_entities.Conversation, ) -> typing.AsyncGenerator[llm_entities.Message, None]: """请求 """ diff --git a/pkg/provider/requester/apis/chatcmpl.py b/pkg/provider/requester/apis/chatcmpl.py index 1e3da1ad..a565009e 100644 --- a/pkg/provider/requester/apis/chatcmpl.py +++ b/pkg/provider/requester/apis/chatcmpl.py @@ -10,7 +10,6 @@ import openai.types.chat.chat_completion as chat_completion from .. import api from ....core import entities as core_entities from ... import entities as llm_entities -from ...session import entities as session_entities class OpenAIChatCompletion(api.LLMAPIRequester): @@ -43,41 +42,18 @@ class OpenAIChatCompletion(api.LLMAPIRequester): async def _closure( self, req_messages: list[dict], - conversation: session_entities.Conversation, - user_text: str = None, - function_ret: str = None, + conversation: core_entities.Conversation, ) -> llm_entities.Message: self.client.api_key = conversation.use_model.token_mgr.get_token() args = self.ap.cfg_mgr.data["completion_api_params"].copy() args["model"] = conversation.use_model.name - tools = await self.ap.tool_mgr.generate_tools_for_openai(conversation) - # tools = [ - # { - # "type": "function", - # "function": { - # "name": "get_current_weather", - # "description": "Get the current weather in a given location", - # "parameters": { - # "type": "object", - # "properties": { - # "location": { - # "type": "string", - # "description": "The city and state, e.g. San Francisco, CA", - # }, - # "unit": { - # "type": "string", - # "enum": ["celsius", "fahrenheit"], - # }, - # }, - # "required": ["location"], - # }, - # }, - # } - # ] - if tools: - args["tools"] = tools + if conversation.use_model.tool_call_supported: + tools = await self.ap.tool_mgr.generate_tools_for_openai(conversation) + + if tools: + args["tools"] = tools # 设置此次请求中的messages messages = req_messages @@ -92,7 +68,7 @@ class OpenAIChatCompletion(api.LLMAPIRequester): return message async def request( - self, query: core_entities.Query, conversation: session_entities.Conversation + self, query: core_entities.Query, conversation: core_entities.Conversation ) -> typing.AsyncGenerator[llm_entities.Message, None]: """请求""" diff --git a/pkg/provider/requester/entities.py b/pkg/provider/requester/entities.py index adc86677..c003564f 100644 --- a/pkg/provider/requester/entities.py +++ b/pkg/provider/requester/entities.py @@ -1,9 +1,11 @@ +from __future__ import annotations + import typing import pydantic from . import api -from . import token +from . import token, tokenizer class LLMModelInfo(pydantic.BaseModel): @@ -17,7 +19,9 @@ class LLMModelInfo(pydantic.BaseModel): requester: api.LLMAPIRequester - function_call_supported: typing.Optional[bool] = False + tokenizer: 'tokenizer.LLMTokenizer' + + tool_call_supported: typing.Optional[bool] = False class Config: arbitrary_types_allowed = True diff --git a/pkg/provider/requester/modelmgr.py b/pkg/provider/requester/modelmgr.py index 7e6a3b52..40bf313e 100644 --- a/pkg/provider/requester/modelmgr.py +++ b/pkg/provider/requester/modelmgr.py @@ -5,6 +5,7 @@ from ...core import app from .apis import chatcmpl from . import token +from .tokenizers import tiktoken class ModelManager: @@ -17,25 +18,28 @@ class ModelManager: self.ap = ap self.model_list = [] + async def get_model_by_name(self, name: str) -> entities.LLMModelInfo: + """通过名称获取模型 + """ + for model in self.model_list: + if model.name == name: + return model + raise ValueError(f"Model {name} not found") + async def initialize(self): openai_chat_completion = chatcmpl.OpenAIChatCompletion(self.ap) await openai_chat_completion.initialize() openai_token_mgr = token.TokenManager(self.ap, list(self.ap.cfg_mgr.data['openai_config']['api_key'].values())) + tiktoken_tokenizer = tiktoken.Tiktoken(self.ap) + self.model_list.append( entities.LLMModelInfo( name="gpt-3.5-turbo", provider="openai", token_mgr=openai_token_mgr, requester=openai_chat_completion, - function_call_supported=True + tool_call_supported=True, + tokenizer=tiktoken_tokenizer ) ) - - async def get_model_by_name(self, name: str) -> entities.LLMModelInfo: - """通过名称获取模型 - """ - for model in self.model_list: - if model.name == name: - return model - raise ValueError(f"Model {name} not found") \ No newline at end of file diff --git a/pkg/provider/requester/tokenizer.py b/pkg/provider/requester/tokenizer.py new file mode 100644 index 00000000..5af8a733 --- /dev/null +++ b/pkg/provider/requester/tokenizer.py @@ -0,0 +1,29 @@ +from __future__ import annotations + +import abc +import typing + +from ...core import app +from .. import entities as llm_entities +from . import entities + + +class LLMTokenizer(metaclass=abc.ABCMeta): + + ap: app.Application + + def __init__(self, ap: app.Application): + self.ap = ap + + async def initialize(self): + """初始化分词器 + """ + pass + + @abc.abstractmethod + async def count_token( + self, + messages: list[llm_entities.Message], + model: entities.LLMModelInfo + ) -> int: + pass diff --git a/pkg/provider/requester/tokenizers/__init__.py b/pkg/provider/requester/tokenizers/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/pkg/provider/requester/tokenizers/tiktoken.py b/pkg/provider/requester/tokenizers/tiktoken.py new file mode 100644 index 00000000..3b83a144 --- /dev/null +++ b/pkg/provider/requester/tokenizers/tiktoken.py @@ -0,0 +1,28 @@ +from __future__ import annotations + +import tiktoken + +from .. import tokenizer +from ... import entities as llm_entities +from .. import entities + + +class Tiktoken(tokenizer.LLMTokenizer): + + async def count_token( + self, + messages: list[llm_entities.Message], + model: entities.LLMModelInfo + ) -> int: + try: + encoding = tiktoken.encoding_for_model(model.name) + except KeyError: + print("Warning: model not found. Using cl100k_base encoding.") + encoding = tiktoken.get_encoding("cl100k_base") + + num_tokens = 0 + for message in messages: + num_tokens += len(encoding.encode(message.role)) + num_tokens += len(encoding.encode(message.content)) + num_tokens += 3 # every reply is primed with <|start|>assistant<|message|> + return num_tokens diff --git a/pkg/provider/session/entities.py b/pkg/provider/session/entities.py deleted file mode 100644 index cbeb72a3..00000000 --- a/pkg/provider/session/entities.py +++ /dev/null @@ -1,53 +0,0 @@ -from __future__ import annotations - -import datetime -import asyncio -import typing - -import pydantic - -from ..sysprompt import entities as sysprompt_entities -from .. import entities as llm_entities -from ..requester import entities -from ...core import entities as core_entities -from ..tools import entities as tools_entities - - -class Conversation(pydantic.BaseModel): - """对话""" - - prompt: sysprompt_entities.Prompt - - messages: list[llm_entities.Message] - - create_time: typing.Optional[datetime.datetime] = pydantic.Field(default_factory=datetime.datetime.now) - - update_time: typing.Optional[datetime.datetime] = pydantic.Field(default_factory=datetime.datetime.now) - - use_model: entities.LLMModelInfo - - use_funcs: typing.Optional[list[tools_entities.LLMFunction]] - - -class Session(pydantic.BaseModel): - """会话""" - launcher_type: core_entities.LauncherTypes - - launcher_id: int - - sender_id: typing.Optional[int] = 0 - - use_prompt_name: typing.Optional[str] = 'default' - - using_conversation: typing.Optional[Conversation] = None - - conversations: typing.Optional[list[Conversation]] = [] - - create_time: typing.Optional[datetime.datetime] = pydantic.Field(default_factory=datetime.datetime.now) - - update_time: typing.Optional[datetime.datetime] = pydantic.Field(default_factory=datetime.datetime.now) - - semaphore: typing.Optional[asyncio.Semaphore] = None - - class Config: - arbitrary_types_allowed = True diff --git a/pkg/provider/session/sessionmgr.py b/pkg/provider/session/sessionmgr.py index a1d5d4d9..a20e2b52 100644 --- a/pkg/provider/session/sessionmgr.py +++ b/pkg/provider/session/sessionmgr.py @@ -3,14 +3,13 @@ from __future__ import annotations import asyncio from ...core import app, entities as core_entities -from . import entities class SessionManager: ap: app.Application - session_list: list[entities.Session] + session_list: list[core_entities.Session] def __init__(self, ap: app.Application): self.ap = ap @@ -19,14 +18,14 @@ class SessionManager: async def initialize(self): pass - async def get_session(self, query: core_entities.Query) -> entities.Session: + async def get_session(self, query: core_entities.Query) -> core_entities.Session: """获取会话 """ for session in self.session_list: if query.launcher_type == session.launcher_type and query.launcher_id == session.launcher_id: return session - session = entities.Session( + session = core_entities.Session( launcher_type=query.launcher_type, launcher_id=query.launcher_id, semaphore=asyncio.Semaphore(1) if self.ap.cfg_mgr.data['wait_last_done'] else asyncio.Semaphore(10000), @@ -34,12 +33,12 @@ class SessionManager: self.session_list.append(session) return session - async def get_conversation(self, session: entities.Session) -> entities.Conversation: + async def get_conversation(self, session: core_entities.Session) -> core_entities.Conversation: if not session.conversations: session.conversations = [] if session.using_conversation is None: - conversation = entities.Conversation( + conversation = core_entities.Conversation( prompt=await self.ap.prompt_mgr.get_prompt(session.use_prompt_name), messages=[], use_model=await self.ap.model_mgr.get_model_by_name(self.ap.cfg_mgr.data['completion_api_params']['model']), diff --git a/pkg/provider/tools/entities.py b/pkg/provider/tools/entities.py index b79627e5..52867291 100644 --- a/pkg/provider/tools/entities.py +++ b/pkg/provider/tools/entities.py @@ -6,6 +6,8 @@ import asyncio import pydantic +from ...core import entities as core_entities + class LLMFunction(pydantic.BaseModel): """函数""" diff --git a/pkg/provider/tools/toolmgr.py b/pkg/provider/tools/toolmgr.py index cc160e39..56b1d17b 100644 --- a/pkg/provider/tools/toolmgr.py +++ b/pkg/provider/tools/toolmgr.py @@ -4,7 +4,6 @@ import typing from ...core import app, entities as core_entities from . import entities -from ..session import entities as session_entities class ToolManager: @@ -12,8 +11,6 @@ class ToolManager: """ ap: app.Application - - all_functions: list[entities.LLMFunction] def __init__(self, ap: app.Application): self.ap = ap @@ -22,30 +19,10 @@ class ToolManager: async def initialize(self): pass - def register_legacy_function(self, name: str, description: str, parameters: dict, func: callable): - """注册函数 - """ - async def wrapper(query, **kwargs): - return func(**kwargs) - function = entities.LLMFunction( - name=name, - description=description, - human_desc='', - enable=True, - parameters=parameters, - func=wrapper - ) - self.all_functions.append(function) - - async def register_function(self, function: entities.LLMFunction): - """添加函数 - """ - self.all_functions.append(function) - async def get_function(self, name: str) -> entities.LLMFunction: """获取函数 """ - for function in self.all_functions: + for function in await self.get_all_functions(): if function.name == name: return function return None @@ -53,9 +30,14 @@ class ToolManager: async def get_all_functions(self) -> list[entities.LLMFunction]: """获取所有函数 """ - return self.all_functions + all_functions: list[entities.LLMFunction] = [] - async def generate_tools_for_openai(self, conversation: session_entities.Conversation) -> str: + for plugin in self.ap.plugin_mgr.plugins: + all_functions.extend(plugin.content_functions) + + return all_functions + + async def generate_tools_for_openai(self, conversation: core_entities.Conversation) -> str: """生成函数列表 """ tools = [] diff --git a/pkg/utils/center/groups/main.py b/pkg/utils/center/groups/main.py index a4e5414a..2edbb88e 100644 --- a/pkg/utils/center/groups/main.py +++ b/pkg/utils/center/groups/main.py @@ -1,17 +1,20 @@ from __future__ import annotations from .. import apigroup -from ... import context +from ....core import app class V2MainDataAPI(apigroup.APIGroup): """主程序相关 数据API""" - def __init__(self, prefix: str): - super().__init__(prefix+"/main") + ap: app.Application + + def __init__(self, prefix: str, ap: app.Application): + self.ap = ap + super().__init__(prefix+"/usage") def do(self, *args, **kwargs): - config = context.get_config_manager().data + config = self.ap.cfg_mgr.data if not config['report_usage']: return None return super().do(*args, **kwargs) diff --git a/pkg/utils/center/groups/plugin.py b/pkg/utils/center/groups/plugin.py index c7881b9d..d00e5813 100644 --- a/pkg/utils/center/groups/plugin.py +++ b/pkg/utils/center/groups/plugin.py @@ -1,17 +1,20 @@ from __future__ import annotations +from ....core import app from .. import apigroup -from ... import context class V2PluginDataAPI(apigroup.APIGroup): """插件数据相关 API""" - def __init__(self, prefix: str): - super().__init__(prefix+"/plugin") + ap: app.Application + + def __init__(self, prefix: str, ap: app.Application): + self.ap = ap + super().__init__(prefix+"/usage") def do(self, *args, **kwargs): - config = context.get_config_manager().data + config = self.ap.cfg_mgr.data if not config['report_usage']: return None return super().do(*args, **kwargs) diff --git a/pkg/utils/center/groups/usage.py b/pkg/utils/center/groups/usage.py index f966add4..f98da649 100644 --- a/pkg/utils/center/groups/usage.py +++ b/pkg/utils/center/groups/usage.py @@ -1,17 +1,20 @@ from __future__ import annotations from .. import apigroup -from ... import context +from ....core import app class V2UsageDataAPI(apigroup.APIGroup): """使用量数据相关 API""" - def __init__(self, prefix: str): + ap: app.Application + + def __init__(self, prefix: str, ap: app.Application): + self.ap = ap super().__init__(prefix+"/usage") def do(self, *args, **kwargs): - config = context.get_config_manager().data + config = self.ap.cfg_mgr.data if not config['report_usage']: return None return super().do(*args, **kwargs) diff --git a/pkg/utils/center/v2.py b/pkg/utils/center/v2.py index 53594b51..70d51384 100644 --- a/pkg/utils/center/v2.py +++ b/pkg/utils/center/v2.py @@ -6,7 +6,7 @@ from . import apigroup from .groups import main from .groups import usage from .groups import plugin -from ...utils import context +from ...core import app BACKEND_URL = "https://api.qchatgpt.rockchin.top/api/v2" @@ -23,7 +23,7 @@ class V2CenterAPI: plugin: plugin.V2PluginDataAPI = None """插件 API 组""" - def __init__(self, basic_info: dict = None, runtime_info: dict = None): + def __init__(self, ap: app.Application, basic_info: dict = None, runtime_info: dict = None): """初始化""" logging.debug("basic_info: %s, runtime_info: %s", basic_info, runtime_info) @@ -31,8 +31,7 @@ class V2CenterAPI: apigroup.APIGroup._basic_info = basic_info apigroup.APIGroup._runtime_info = runtime_info - self.main = main.V2MainDataAPI(BACKEND_URL) - self.usage = usage.V2UsageDataAPI(BACKEND_URL) - self.plugin = plugin.V2PluginDataAPI(BACKEND_URL) - - context.set_center_v2_api(self) + self.main = main.V2MainDataAPI(BACKEND_URL, ap) + self.usage = usage.V2UsageDataAPI(BACKEND_URL, ap) + self.plugin = plugin.V2PluginDataAPI(BACKEND_URL, ap) + diff --git a/pkg/utils/network.py b/pkg/utils/network.py deleted file mode 100644 index a4498854..00000000 --- a/pkg/utils/network.py +++ /dev/null @@ -1,11 +0,0 @@ -from . import context - - -def wrapper_proxies() -> dict: - """获取代理""" - config = context.get_config_manager().data - - return { - "http": config['openai_config']['proxy'], - "https": config['openai_config']['proxy'] - } if 'proxy' in config['openai_config'] and (config['openai_config']['proxy'] is not None) else None diff --git a/pkg/utils/pkgmgr.py b/pkg/utils/pkgmgr.py index 741c8f48..8958b38b 100644 --- a/pkg/utils/pkgmgr.py +++ b/pkg/utils/pkgmgr.py @@ -1,27 +1,27 @@ from pip._internal import main as pipmain -from . import log +# from . import log def install(package): pipmain(['install', package]) - log.reset_logging() + # log.reset_logging() def install_upgrade(package): pipmain(['install', '--upgrade', package, "-i", "https://pypi.tuna.tsinghua.edu.cn/simple", "--trusted-host", "pypi.tuna.tsinghua.edu.cn"]) - log.reset_logging() + # log.reset_logging() def run_pip(params: list): pipmain(params) - log.reset_logging() + # log.reset_logging() def install_requirements(file): pipmain(['install', '-r', file, "-i", "https://pypi.tuna.tsinghua.edu.cn/simple", "--trusted-host", "pypi.tuna.tsinghua.edu.cn"]) - log.reset_logging() + # log.reset_logging() def ensure_dulwich(): diff --git a/pkg/utils/proxy.py b/pkg/utils/proxy.py new file mode 100644 index 00000000..1c5ee18c --- /dev/null +++ b/pkg/utils/proxy.py @@ -0,0 +1,30 @@ +from __future__ import annotations + +from ..core import app + + +class ProxyManager: + ap: app.Application + + forward_proxies: dict[str, str] + + def __init__(self, ap: app.Application): + self.ap = ap + + self.forward_proxies = {} + + async def initialize(self): + config = self.ap.cfg_mgr.data + + return ( + { + "http": config["openai_config"]["proxy"], + "https": config["openai_config"]["proxy"], + } + if "proxy" in config["openai_config"] + and (config["openai_config"]["proxy"] is not None) + else None + ) + + def get_forward_proxies(self) -> str: + return self.forward_proxies diff --git a/pkg/utils/updater.py b/pkg/utils/updater.py index ec6e93a8..f2d357f2 100644 --- a/pkg/utils/updater.py +++ b/pkg/utils/updater.py @@ -8,21 +8,6 @@ import time import requests from . import constants -from . import network -from . import context - - -def check_dulwich_closure(): - try: - import pkg.utils.pkgmgr - pkg.utils.pkgmgr.ensure_dulwich() - except: - pass - - try: - import dulwich - except ModuleNotFoundError: - raise Exception("dulwich模块未安装,请查看 https://github.com/RockChinQ/QChatGPT/issues/77") def is_newer(new_tag: str, old_tag: str): @@ -47,28 +32,6 @@ def is_newer(new_tag: str, old_tag: str): return new_tag != old_tag -def get_release_list() -> list: - """获取发行列表""" - rls_list_resp = requests.get( - url="https://api.github.com/repos/RockChinQ/QChatGPT/releases", - proxies=network.wrapper_proxies() - ) - - rls_list = rls_list_resp.json() - - return rls_list - - -def get_current_tag() -> str: - """获取当前tag""" - current_tag = constants.semantic_version - if os.path.exists("current_tag"): - with open("current_tag", "r") as f: - current_tag = f.read() - - return current_tag - - def compare_version_str(v0: str, v1: str) -> int: """比较两个版本号""" @@ -209,79 +172,3 @@ def update_all(cli: bool = False) -> bool: else: print("已更新到最新版本: {}\n更新日志:\n{}\n完整的更新日志请前往 https://github.com/RockChinQ/QChatGPT/releases 查看。请手动重启程序以使用新版本。".format(current_tag, "\n".join(rls_notes[:-1]))) return True - - -def is_repo(path: str) -> bool: - """检查是否是git仓库""" - check_dulwich_closure() - - from dulwich import porcelain - try: - porcelain.open_repo(path) - return True - except: - return False - - -def get_remote_url(repo_path: str) -> str: - """获取远程仓库地址""" - check_dulwich_closure() - - from dulwich import porcelain - repo = porcelain.open_repo(repo_path) - return str(porcelain.get_remote_repo(repo, "origin")[1]) - - -def get_current_version_info() -> str: - """获取当前版本信息""" - rls_list = get_release_list() - current_tag = get_current_tag() - for rls in rls_list: - if rls['tag_name'] == current_tag: - return rls['name'] + "\n" + rls['body'] - return "未知版本" - - -def is_new_version_available() -> bool: - """检查是否有新版本""" - # 从github获取release列表 - rls_list = get_release_list() - if rls_list is None: - return False - - # 获取当前版本 - current_tag = get_current_tag() - - # 检查是否有新版本 - latest_tag_name = "" - for rls in rls_list: - if latest_tag_name == "": - latest_tag_name = rls['tag_name'] - break - - return is_newer(latest_tag_name, current_tag) - - -def get_rls_notes() -> list: - """获取更新日志""" - # 从github获取release列表 - rls_list = get_release_list() - if rls_list is None: - return None - - # 获取当前版本 - current_tag = get_current_tag() - - # 检查是否有新版本 - rls_notes = [] - for rls in rls_list: - if rls['tag_name'] == current_tag: - break - - rls_notes.append(rls['name']) - - return rls_notes - - -if __name__ == "__main__": - update_all() diff --git a/pkg/utils/version.py b/pkg/utils/version.py new file mode 100644 index 00000000..7427ce9f --- /dev/null +++ b/pkg/utils/version.py @@ -0,0 +1,130 @@ +from __future__ import annotations + +import os + +import requests + +from ..core import app +from . import constants + + +class VersionManager: + + ap: app.Application + + def __init__( + self, + ap: app.Application + ): + self.ap = ap + + async def initialize( + self + ): + pass + + def get_current_version( + self + ) -> str: + current_tag = constants.semantic_version + if os.path.exists("current_tag"): + with open("current_tag", "r") as f: + current_tag = f.read() + + return current_tag + + async def get_current_version_info( + self + ) -> str: + + """获取当前版本信息""" + rls_list = await self.get_release_list() + current_tag = self.get_current_version() + for rls in rls_list: + if rls['tag_name'] == current_tag: + return rls['name'] + "\n" + rls['body'] + return "未知版本" + + async def get_release_list(self) -> list: + """获取发行列表""" + rls_list_resp = requests.get( + url="https://api.github.com/repos/RockChinQ/QChatGPT/releases", + proxies=self.ap.proxy_mgr.get_forward_proxies() + ) + + rls_list = rls_list_resp.json() + + return rls_list + + async def update_all(self): + pass + + async def is_new_version_available(self) -> bool: + """检查是否有新版本""" + # 从github获取release列表 + rls_list = await self.get_release_list() + if rls_list is None: + return False + + # 获取当前版本 + current_tag = self.get_current_version() + + # 检查是否有新版本 + latest_tag_name = "" + for rls in rls_list: + if latest_tag_name == "": + latest_tag_name = rls['tag_name'] + break + + return self.is_newer(latest_tag_name, current_tag) + + + def is_newer(self, new_tag: str, old_tag: str): + """判断版本是否更新,忽略第四位版本和第一位版本""" + if new_tag == old_tag: + return False + + new_tag = new_tag.split(".") + old_tag = old_tag.split(".") + + # 判断主版本是否相同 + if new_tag[0] != old_tag[0]: + return False + + if len(new_tag) < 4: + return True + + # 合成前三段,判断是否相同 + new_tag = ".".join(new_tag[:3]) + old_tag = ".".join(old_tag[:3]) + + return new_tag != old_tag + + + def compare_version_str(v0: str, v1: str) -> int: + """比较两个版本号""" + + # 删除版本号前的v + if v0.startswith("v"): + v0 = v0[1:] + if v1.startswith("v"): + v1 = v1[1:] + + v0:list = v0.split(".") + v1:list = v1.split(".") + + # 如果两个版本号节数不同,把短的后面用0补齐 + if len(v0) < len(v1): + v0.extend(["0"]*(len(v1)-len(v0))) + elif len(v0) > len(v1): + v1.extend(["0"]*(len(v0)-len(v1))) + + # 从高位向低位比较 + for i in range(len(v0)): + if int(v0[i]) > int(v1[i]): + return 1 + elif int(v0[i]) < int(v1[i]): + return -1 + + return 0 + diff --git a/res/templates/plugin-setting-template.json b/res/templates/plugin-setting-template.json new file mode 100644 index 00000000..1d807ed1 --- /dev/null +++ b/res/templates/plugin-setting-template.json @@ -0,0 +1,3 @@ +{ + "plugins": [] +} \ No newline at end of file