diff --git a/main.py b/main.py deleted file mode 100644 index 53e40853..00000000 --- a/main.py +++ /dev/null @@ -1,483 +0,0 @@ -import importlib -import json -import os -import shutil -import threading -import time - -import logging -import sys -import traceback -import asyncio - -sys.path.append(".") - - -def check_file(): - # 检查是否有banlist.py,如果没有就把banlist-template.py复制一份 - if not os.path.exists('banlist.py'): - shutil.copy('banlist-template.py', 'banlist.py') - - # 检查是否有sensitive.json - if not os.path.exists("sensitive.json"): - shutil.copy("res/templates/sensitive-template.json", "sensitive.json") - - # 检查是否有scenario/default.json - if not os.path.exists("scenario/default.json"): - shutil.copy("scenario/default-template.json", "scenario/default.json") - - # 检查cmdpriv.json - if not os.path.exists("cmdpriv.json"): - shutil.copy("res/templates/cmdpriv-template.json", "cmdpriv.json") - - # 检查tips_custom - if not os.path.exists("tips.py"): - shutil.copy("tips-custom-template.py", "tips.py") - - # 检查temp目录 - if not os.path.exists("temp/"): - os.mkdir("temp/") - - # 检查并创建plugins、prompts目录 - check_path = ["plugins", "prompts"] - for path in check_path: - if not os.path.exists(path): - os.mkdir(path) - - # 配置文件存在性校验 - if not os.path.exists('config.py'): - shutil.copy('config-template.py', 'config.py') - print('请先在config.py中填写配置') - sys.exit(0) - - -# 初始化相关文件 -check_file() - -from pkg.utils.log import init_runtime_log_file, reset_logging -from pkg.config import manager as config_mgr -from pkg.config.impls import pymodule as pymodule_cfg - - -try: - import colorlog -except ImportError: - # 尝试安装 - import pkg.utils.pkgmgr as pkgmgr - try: - pkgmgr.install_requirements("requirements.txt") - import colorlog - except ImportError: - print("依赖不满足,请查看 https://github.com/RockChinQ/qcg-installer/issues/15") - sys.exit(1) -import colorlog - -import requests -import websockets.exceptions -from urllib3.exceptions import InsecureRequestWarning -import pkg.utils.context - - -# 是否使用override.json覆盖配置 -# 仅在启动时提供 --override 或 -r 参数时生效 -use_override = False - -def ensure_dependencies(): - import pkg.utils.pkgmgr as pkgmgr - pkgmgr.run_pip(["install", "openai", "Pillow", "nakuru-project-idk", "CallingGPT", "tiktoken", "--upgrade", - "-i", "https://pypi.tuna.tsinghua.edu.cn/simple", - "--trusted-host", "pypi.tuna.tsinghua.edu.cn"]) - - -known_exception_caught = False - - -def override_config_manager(): - config = pkg.utils.context.get_config_manager().data - - if os.path.exists("override.json") and use_override: - override_json = json.load(open("override.json", "r", encoding="utf-8")) - overrided = [] - for key in override_json: - if key in config: - config[key] = override_json[key] - # logging.info("覆写配置[{}]为[{}]".format(key, override_json[key])) - overrided.append(key) - else: - logging.error("无法覆写配置[{}]为[{}],该配置不存在,请检查override.json是否正确".format(key, override_json[key])) - if len(overrided) > 0: - logging.info("已根据override.json覆写配置项: {}".format(", ".join(overrided))) - - -def complete_tips(): - """根据tips-custom-template模块补全tips模块的属性""" - non_exist_keys = [] - - is_integrity = True - logging.debug("检查tips模块完整性.") - tips_template = importlib.import_module('tips-custom-template') - tips = importlib.import_module('tips') - for key in dir(tips_template): - if not key.startswith("__") and not hasattr(tips, key): - setattr(tips, key, getattr(tips_template, key)) - # logging.warning("[{}]不存在".format(key)) - non_exist_keys.append(key) - is_integrity = False - - if not is_integrity: - logging.warning("以下提示语字段不存在: {}".format(", ".join(non_exist_keys))) - logging.warning("tips模块不完整,您可以依据tips-custom-template.py检查tips.py") - logging.warning("以上配置已被设为默认值,将在3秒后继续启动... ") - time.sleep(3) - - -async def start_process(first_time_init=False): - """启动流程,reload之后会被执行""" - - global known_exception_caught - import pkg.utils.context - - # 计算host和instance标识符 - import pkg.audit.identifier - pkg.audit.identifier.init() - - # 加载配置 - cfg_inst: pymodule_cfg.PythonModuleConfigFile = pymodule_cfg.PythonModuleConfigFile( - 'config.py', - 'config-template.py' - ) - await config_mgr.ConfigManager(cfg_inst).load_config() - - override_config_manager() - - # 检查tips模块 - complete_tips() - - cfg = pkg.utils.context.get_config_manager().data - - # 更新openai库到最新版本 - if 'upgrade_dependencies' not in cfg or cfg['upgrade_dependencies']: - print("正在更新依赖库,请等待...") - if 'upgrade_dependencies' not in cfg: - print("这个操作不是必须的,如果不想更新,请在config.py中添加upgrade_dependencies=False") - else: - print("这个操作不是必须的,如果不想更新,请在config.py中将upgrade_dependencies设置为False") - try: - ensure_dependencies() - except Exception as e: - print("更新openai库失败:{}, 请忽略或自行更新".format(e)) - - known_exception_caught = False - try: - try: - - sh = reset_logging() - pkg.utils.context.context['logger_handler'] = sh - - # 初始化文字转图片 - from pkg.utils import text2img - text2img.initialize() - - # 检查是否设置了管理员 - if cfg['admin_qq'] == 0: - # logging.warning("未设置管理员QQ,管理员权限命令及运行告警将无法使用,如需设置请修改config.py中的admin_qq字段") - while True: - try: - cfg['admin_qq'] = int(input("未设置管理员QQ,管理员权限命令及运行告警将无法使用,请输入管理员QQ号: ")) - # 写入到文件 - - # 读取文件 - config_file_str = "" - with open("config.py", "r", encoding="utf-8") as f: - config_file_str = f.read() - # 替换 - config_file_str = config_file_str.replace("admin_qq = 0", "admin_qq = " + str(cfg['admin_qq'])) - # 写入 - with open("config.py", "w", encoding="utf-8") as f: - f.write(config_file_str) - - print("管理员QQ已设置,如需修改请修改config.py中的admin_qq字段") - time.sleep(4) - break - except ValueError: - print("请输入数字") - - # 初始化中央服务器 API 交互实例 - from pkg.utils.center import apigroup - from pkg.utils.center import v2 as center_v2 - - center_v2_api = center_v2.V2CenterAPI( - basic_info={ - "host_id": pkg.audit.identifier.identifier['host_id'], - "instance_id": pkg.audit.identifier.identifier['instance_id'], - "semantic_version": pkg.utils.updater.get_current_tag(), - "platform": sys.platform, - }, - runtime_info={ - "admin_id": "{}".format(cfg['admin_qq']), - "msg_source": cfg['msg_source_adapter'], - } - ) - pkg.utils.context.set_center_v2_api(center_v2_api) - - import pkg.openai.manager - import pkg.database.manager - import pkg.openai.session - import pkg.qqbot.manager - import pkg.openai.dprompt - import pkg.qqbot.cmds.aamgr - - try: - pkg.openai.dprompt.register_all() - pkg.qqbot.cmds.aamgr.register_all() - pkg.qqbot.cmds.aamgr.apply_privileges() - except Exception as e: - logging.error(e) - traceback.print_exc() - - # 配置OpenAI proxy - import openai - openai.proxies = None # 先重置,因为重载后可能需要清除proxy - if "http_proxy" in cfg['openai_config'] and cfg['openai_config']["http_proxy"] is not None: - openai.proxies = { - "http": cfg['openai_config']["http_proxy"], - "https": cfg['openai_config']["http_proxy"] - } - - # 配置openai api_base - if "reverse_proxy" in cfg['openai_config'] and cfg['openai_config']["reverse_proxy"] is not None: - logging.debug("设置反向代理: "+cfg['openai_config']['reverse_proxy']) - openai.base_url = cfg['openai_config']["reverse_proxy"] - - # 主启动流程 - database = pkg.database.manager.DatabaseManager() - - database.initialize_database() - - openai_interact = pkg.openai.manager.OpenAIInteract(cfg['openai_config']['api_key']) - - # 加载所有未超时的session - pkg.openai.session.load_sessions() - - # 初始化qq机器人 - qqbot = pkg.qqbot.manager.QQBotManager(first_time_init=first_time_init) - - # 加载插件 - import pkg.plugin.host - pkg.plugin.host.load_plugins() - - pkg.plugin.host.initialize_plugins() - - if first_time_init: # 不是热重载之后的启动,则启动新的bot线程 - - import mirai.exceptions - - def run_bot_wrapper(): - global known_exception_caught - try: - logging.debug("使用账号: {}".format(qqbot.bot_account_id)) - qqbot.adapter.run_sync() - except TypeError as e: - if str(e).__contains__("argument 'debug'"): - logging.error( - "连接bot失败:{}, 解决方案: https://github.com/RockChinQ/QChatGPT/issues/82".format(e)) - known_exception_caught = True - elif str(e).__contains__("As of 3.10, the *loop*"): - logging.error( - "Websockets版本过低:{}, 解决方案: https://github.com/RockChinQ/QChatGPT/issues/5".format(e)) - known_exception_caught = True - - except websockets.exceptions.InvalidStatus as e: - logging.error( - "mirai-api-http端口无法使用:{}, 解决方案: https://github.com/RockChinQ/QChatGPT/issues/22".format( - e)) - known_exception_caught = True - except mirai.exceptions.NetworkError as e: - logging.error("连接mirai-api-http失败:{}, 请检查是否已按照文档启动mirai".format(e)) - known_exception_caught = True - except Exception as e: - if str(e).__contains__("404"): - logging.error( - "mirai-api-http端口无法使用:{}, 解决方案: https://github.com/RockChinQ/QChatGPT/issues/22".format( - e)) - known_exception_caught = True - elif str(e).__contains__("signal only works in main thread"): - logging.error( - "hypercorn异常:{}, 解决方案: https://github.com/RockChinQ/QChatGPT/issues/86".format( - e)) - known_exception_caught = True - elif str(e).__contains__("did not receive a valid HTTP"): - logging.error( - "mirai-api-http端口无法使用:{}, 解决方案: https://github.com/RockChinQ/QChatGPT/issues/22".format( - e)) - else: - import traceback - traceback.print_exc() - logging.error( - "捕捉到未知异常:{}, 请前往 https://github.com/RockChinQ/QChatGPT/issues 查找或提issue".format(e)) - known_exception_caught = True - raise e - finally: - time.sleep(12) - threading.Thread( - target=run_bot_wrapper - ).start() - except Exception as e: - traceback.print_exc() - if isinstance(e, KeyboardInterrupt): - logging.info("程序被用户中止") - sys.exit(0) - elif isinstance(e, SyntaxError): - logging.error("配置文件存在语法错误,请检查配置文件:\n1. 是否存在中文符号\n2. 是否已按照文件中的说明填写正确") - sys.exit(1) - else: - logging.error("初始化失败:{}".format(e)) - sys.exit(1) - finally: - # 判断若是Windows,输出选择模式可能会暂停程序的警告 - if os.name == 'nt': - time.sleep(2) - logging.info("您正在使用Windows系统,若命令行窗口处于“选择”模式,程序可能会被暂停,此时请右键点击窗口空白区域使其取消选择模式。") - - time.sleep(12) - - if first_time_init: - if not known_exception_caught: - if cfg['msg_source_adapter'] == "yirimirai": - logging.info("QQ: {}, MAH: {}".format(cfg['mirai_http_api_config']['qq'], cfg['mirai_http_api_config']['host']+":"+str(cfg['mirai_http_api_config']['port']))) - logging.critical('程序启动完成,如长时间未显示 "成功登录到账号xxxxx" ,并且不回复消息,解决办法(请勿到群里问): ' - 'https://github.com/RockChinQ/QChatGPT/issues/37') - elif cfg['msg_source_adapter'] == 'nakuru': - logging.info("host: {}, port: {}, http_port: {}".format(cfg['nakuru_config']['host'], cfg['nakuru_config']['port'], cfg['nakuru_config']['http_port'])) - logging.critical('程序启动完成,如长时间未显示 "Protocol: connected" ,并且不回复消息,请检查config.py中的nakuru_config是否正确') - else: - sys.exit(1) - else: - logging.info('热重载完成') - - # 发送赞赏码 - if cfg['encourage_sponsor_at_start'] \ - and pkg.utils.context.get_openai_manager().audit_mgr.get_total_text_length() >= 2048: - - logging.info("发送赞赏码") - from mirai import MessageChain, Plain, Image - import pkg.utils.constants - message_chain = MessageChain([ - Plain("自2022年12月初以来,开发者已经花费了大量时间和精力来维护本项目,如果您觉得本项目对您有帮助,欢迎赞赏开发者," - "以支持项目稳定运行😘"), - Image(base64=pkg.utils.constants.alipay_qr_b64), - Image(base64=pkg.utils.constants.wechat_qr_b64), - Plain("BTC: 3N4Azee63vbBB9boGv9Rjf4N5SocMe5eCq\nXMR: 89LS21EKQuDGkyQoe2nDupiuWXk4TVD6FALvSKv5owfmeJEPFpHeMsZLYtLiJ6GxLrhsRe5gMs6MyMSDn4GNQAse2Mae4KE\n\n"), - Plain("(本消息仅在启动时发送至管理员,如果您不想再看到此消息,请在config.py中将encourage_sponsor_at_start设置为False)") - ]) - pkg.utils.context.get_qqbot_manager().notify_admin_message_chain(message_chain) - - time.sleep(5) - import pkg.utils.updater - try: - if pkg.utils.updater.is_new_version_available(): - logging.info("新版本可用,请发送 !update 进行自动更新\n更新日志:\n{}".format("\n".join(pkg.utils.updater.get_rls_notes()))) - else: - # logging.info("当前已是最新版本") - pass - - except Exception as e: - logging.warning("检查更新失败:{}".format(e)) - - try: - import pkg.utils.announcement as announcement - new_announcement = announcement.fetch_new() - if len(new_announcement) > 0: - for announcement in new_announcement: - logging.critical("[公告]<{}> {}".format(announcement['time'], announcement['content'])) - - # 发送统计数据 - pkg.utils.context.get_center_v2_api().main.post_announcement_showed( - [announcement['id'] for announcement in new_announcement] - ) - - except Exception as e: - logging.warning("获取公告失败:{}".format(e)) - - return qqbot - -def stop(): - import pkg.qqbot.manager - import pkg.openai.session - try: - import pkg.plugin.host - pkg.plugin.host.unload_plugins() - - qqbot_inst = pkg.utils.context.get_qqbot_manager() - assert isinstance(qqbot_inst, pkg.qqbot.manager.QQBotManager) - - for session in pkg.openai.session.sessions: - logging.info('持久化session: %s', session) - pkg.openai.session.sessions[session].persistence() - pkg.utils.context.get_database_manager().close() - except Exception as e: - if not isinstance(e, KeyboardInterrupt): - raise e - - -def main(): - global use_override - # 检查是否携带了 --override 或 -r 参数 - if '--override' in sys.argv or '-r' in sys.argv: - use_override = True - - # 初始化logging - init_runtime_log_file() - pkg.utils.context.context['logger_handler'] = reset_logging() - - # 配置线程池 - from pkg.utils import ThreadCtl - thread_ctl = ThreadCtl( - sys_pool_num=8, - admin_pool_num=4, - user_pool_num=8 - ) - # 存进上下文 - pkg.utils.context.set_thread_ctl(thread_ctl) - - if len(sys.argv) > 1 and sys.argv[1] == 'update': - print("正在进行程序更新...") - import pkg.utils.updater as updater - updater.update_all(cli=True) - sys.exit(0) - - # 关闭urllib的http警告 - requests.packages.urllib3.disable_warnings(InsecureRequestWarning) - - def run_wrapper(): - asyncio.run(start_process(True)) - - pkg.utils.context.get_thread_ctl().submit_sys_task( - run_wrapper - ) - - # 主线程循环 - while True: - try: - time.sleep(0xFF) - except: - stop() - pkg.utils.context.get_thread_ctl().shutdown() - - launch_args = sys.argv.copy() - - if "--cov-report" not in launch_args: - import platform - if platform.system() == 'Windows': - cmd = "taskkill /F /PID {}".format(os.getpid()) - elif platform.system() in ['Linux', 'Darwin']: - cmd = "kill -9 {}".format(os.getpid()) - os.system(cmd) - else: - print("正常退出以生成覆盖率报告") - sys.exit(0) - - -if __name__ == '__main__': - main() - diff --git a/pkg/command/cmdmgr.py b/pkg/command/cmdmgr.py index 73e14584..a9feafdf 100644 --- a/pkg/command/cmdmgr.py +++ b/pkg/command/cmdmgr.py @@ -3,8 +3,8 @@ from __future__ import annotations import typing from ..core import app, entities as core_entities -from ..openai import entities as llm_entities -from ..openai.session import entities as session_entities +from ..gai import entities as llm_entities +from ..gai.session import entities as session_entities from . import entities, operator, errors from .operators import func, plugin, default, reset, list as list_cmd, last, next, delc, resend, prompt, cfg, cmd, help, version, update diff --git a/pkg/command/entities.py b/pkg/command/entities.py index 7fba96e5..98312bda 100644 --- a/pkg/command/entities.py +++ b/pkg/command/entities.py @@ -6,7 +6,7 @@ import pydantic import mirai from ..core import app, entities as core_entities -from ..openai.session import entities as session_entities +from ..gai.session import entities as session_entities from . import errors, operator diff --git a/pkg/command/operator.py b/pkg/command/operator.py index af1a5d6e..299bb6c0 100644 --- a/pkg/command/operator.py +++ b/pkg/command/operator.py @@ -4,7 +4,7 @@ import typing import abc from ..core import app, entities as core_entities -from ..openai.session import entities as session_entities +from ..gai.session import entities as session_entities from . import entities diff --git a/pkg/core/app.py b/pkg/core/app.py index 77b7fa6d..9fd2830d 100644 --- a/pkg/core/app.py +++ b/pkg/core/app.py @@ -3,11 +3,11 @@ from __future__ import annotations import logging import asyncio -from ..qqbot import manager as qqbot_mgr -from ..openai.session import sessionmgr as llm_session_mgr -from ..openai.requester import modelmgr as llm_model_mgr -from ..openai.sysprompt import sysprompt as llm_prompt_mgr -from ..openai.tools import toolmgr as llm_tool_mgr +from ..im import manager as qqbot_mgr +from ..gai.session import sessionmgr as llm_session_mgr +from ..gai.requester import modelmgr as llm_model_mgr +from ..gai.sysprompt import sysprompt as llm_prompt_mgr +from ..gai.tools import toolmgr as llm_tool_mgr from ..config import manager as config_mgr from ..database import manager as database_mgr from ..utils.center import v2 as center_mgr diff --git a/pkg/core/boot.py b/pkg/core/boot.py index 8a07b130..9153573d 100644 --- a/pkg/core/boot.py +++ b/pkg/core/boot.py @@ -14,11 +14,11 @@ from . import controller from ..pipeline import stagemgr from ..audit import identifier from ..database import manager as db_mgr -from ..openai.session import sessionmgr as llm_session_mgr -from ..openai.requester import modelmgr as llm_model_mgr -from ..openai.sysprompt import sysprompt as llm_prompt_mgr -from ..openai.tools import toolmgr as llm_tool_mgr -from ..qqbot import manager as im_mgr +from ..gai.session import sessionmgr as llm_session_mgr +from ..gai.requester import modelmgr as llm_model_mgr +from ..gai.sysprompt import sysprompt as llm_prompt_mgr +from ..gai.tools import toolmgr as llm_tool_mgr +from ..im import manager as im_mgr from ..command import cmdmgr from ..plugin import host as plugin_host from ..utils.center import v2 as center_v2 diff --git a/pkg/database/__init__.py b/pkg/database/__init__.py deleted file mode 100644 index c40dc210..00000000 --- a/pkg/database/__init__.py +++ /dev/null @@ -1,3 +0,0 @@ -""" -数据库操作封装 -""" \ No newline at end of file diff --git a/pkg/database/manager.py b/pkg/database/manager.py deleted file mode 100644 index c1153e8f..00000000 --- a/pkg/database/manager.py +++ /dev/null @@ -1,365 +0,0 @@ -""" -数据库管理模块 -""" -import hashlib -import json -import logging -import time - -import sqlite3 - -from ..utils import context - - -class DatabaseManager: - """封装数据库底层操作,并提供方法给上层使用""" - - conn = None - cursor = None - - def __init__(self, *args, **kwargs): - - self.reconnect() - - context.set_database_manager(self) - - # 连接到数据库文件 - def reconnect(self): - """连接到数据库""" - self.conn = sqlite3.connect('database.db', check_same_thread=False) - self.cursor = self.conn.cursor() - - def close(self): - self.conn.close() - - def __execute__(self, *args, **kwargs) -> sqlite3.Cursor: - # logging.debug('SQL: {}'.format(sql)) - logging.debug('SQL: {}'.format(args)) - c = self.cursor.execute(*args, **kwargs) - self.conn.commit() - return c - - # 初始化数据库的函数 - def initialize_database(self): - """创建数据表""" - - self.__execute__(""" - create table if not exists `sessions` ( - `id` INTEGER PRIMARY KEY AUTOINCREMENT, - `name` varchar(255) not null, - `type` varchar(255) not null, - `number` bigint not null, - `create_timestamp` bigint not null, - `last_interact_timestamp` bigint not null, - `status` varchar(255) not null default 'on_going', - `default_prompt` text not null default '', - `prompt` text not null, - `token_counts` text not null default '[]' - ) - """) - - # 检查sessions表是否存在`default_prompt`字段, 检查是否存在`token_counts`字段 - self.__execute__("PRAGMA table_info('sessions')") - columns = self.cursor.fetchall() - has_default_prompt = False - has_token_counts = False - for field in columns: - if field[1] == 'default_prompt': - has_default_prompt = True - if field[1] == 'token_counts': - has_token_counts = True - if has_default_prompt and has_token_counts: - break - if not has_default_prompt: - self.__execute__("alter table `sessions` add column `default_prompt` text not null default ''") - if not has_token_counts: - self.__execute__("alter table `sessions` add column `token_counts` text not null default '[]'") - - - self.__execute__(""" - create table if not exists `account_fee`( - `id` INTEGER PRIMARY KEY AUTOINCREMENT, - `key_md5` varchar(255) not null, - `timestamp` bigint not null, - `fee` DECIMAL(12,6) not null - ) - """) - - self.__execute__(""" - create table if not exists `account_usage`( - `id` INTEGER PRIMARY KEY AUTOINCREMENT, - `json` text not null - ) - """) - # print('Database initialized.') - - # session持久化 - def persistence_session(self, subject_type: str, subject_number: int, create_timestamp: int, - last_interact_timestamp: int, prompt: str, default_prompt: str = '', token_counts: str = ''): - """持久化指定session""" - - # 检查是否已经有了此name和create_timestamp的session - # 如果有,就更新prompt和last_interact_timestamp - # 如果没有,就插入一条新的记录 - self.__execute__(""" - select count(*) from `sessions` where `type` = '{}' and `number` = {} and `create_timestamp` = {} - """.format(subject_type, subject_number, create_timestamp)) - count = self.cursor.fetchone()[0] - if count == 0: - - sql = """ - insert into `sessions` (`name`, `type`, `number`, `create_timestamp`, `last_interact_timestamp`, `prompt`, `default_prompt`, `token_counts`) - values (?, ?, ?, ?, ?, ?, ?, ?) - """ - - self.__execute__(sql, - ("{}_{}".format(subject_type, subject_number), subject_type, subject_number, create_timestamp, - last_interact_timestamp, prompt, default_prompt, token_counts)) - else: - sql = """ - update `sessions` set `last_interact_timestamp` = ?, `prompt` = ?, `token_counts` = ? - where `type` = ? and `number` = ? and `create_timestamp` = ? - """ - - self.__execute__(sql, (last_interact_timestamp, prompt, token_counts, subject_type, - subject_number, create_timestamp)) - - # 显式关闭一个session - def explicit_close_session(self, session_name: str, create_timestamp: int): - self.__execute__(""" - update `sessions` set `status` = 'explicitly_closed' where `name` = '{}' and `create_timestamp` = {} - """.format(session_name, create_timestamp)) - - def set_session_ongoing(self, session_name: str, create_timestamp: int): - self.__execute__(""" - update `sessions` set `status` = 'on_going' where `name` = '{}' and `create_timestamp` = {} - """.format(session_name, create_timestamp)) - - # 设置session为过期 - def set_session_expired(self, session_name: str, create_timestamp: int): - self.__execute__(""" - update `sessions` set `status` = 'expired' where `name` = '{}' and `create_timestamp` = {} - """.format(session_name, create_timestamp)) - - # 从数据库加载还没过期的session数据 - def load_valid_sessions(self) -> dict: - # 从数据库中加载所有还没过期的session - config = context.get_config_manager().data - self.__execute__(""" - select `name`, `type`, `number`, `create_timestamp`, `last_interact_timestamp`, `prompt`, `status`, `default_prompt`, `token_counts` - from `sessions` where `last_interact_timestamp` > {} - """.format(int(time.time()) - config['session_expire_time'])) - results = self.cursor.fetchall() - sessions = {} - for result in results: - session_name = result[0] - subject_type = result[1] - subject_number = result[2] - create_timestamp = result[3] - last_interact_timestamp = result[4] - prompt = result[5] - status = result[6] - default_prompt = result[7] - token_counts = result[8] - - # 当且仅当最后一个该对象的会话是on_going状态时,才会被加载 - if status == 'on_going': - sessions[session_name] = { - 'subject_type': subject_type, - 'subject_number': subject_number, - 'create_timestamp': create_timestamp, - 'last_interact_timestamp': last_interact_timestamp, - 'prompt': prompt, - 'default_prompt': default_prompt, - 'token_counts': token_counts - } - else: - if session_name in sessions: - del sessions[session_name] - - return sessions - - # 获取此session_name前一个session的数据 - def last_session(self, session_name: str, cursor_timestamp: int): - - self.__execute__(""" - select `name`, `type`, `number`, `create_timestamp`, `last_interact_timestamp`, `prompt`, `status`, `default_prompt`, `token_counts` - from `sessions` where `name` = '{}' and `last_interact_timestamp` < {} order by `last_interact_timestamp` desc - limit 1 - """.format(session_name, cursor_timestamp)) - results = self.cursor.fetchall() - if len(results) == 0: - return None - result = results[0] - - session_name = result[0] - subject_type = result[1] - subject_number = result[2] - create_timestamp = result[3] - last_interact_timestamp = result[4] - prompt = result[5] - status = result[6] - default_prompt = result[7] - token_counts = result[8] - - return { - 'subject_type': subject_type, - 'subject_number': subject_number, - 'create_timestamp': create_timestamp, - 'last_interact_timestamp': last_interact_timestamp, - 'prompt': prompt, - 'default_prompt': default_prompt, - 'token_counts': token_counts - } - - # 获取此session_name后一个session的数据 - def next_session(self, session_name: str, cursor_timestamp: int): - - self.__execute__(""" - select `name`, `type`, `number`, `create_timestamp`, `last_interact_timestamp`, `prompt`, `status`, `default_prompt`, `token_counts` - from `sessions` where `name` = '{}' and `last_interact_timestamp` > {} order by `last_interact_timestamp` asc - limit 1 - """.format(session_name, cursor_timestamp)) - results = self.cursor.fetchall() - if len(results) == 0: - return None - result = results[0] - - session_name = result[0] - subject_type = result[1] - subject_number = result[2] - create_timestamp = result[3] - last_interact_timestamp = result[4] - prompt = result[5] - status = result[6] - default_prompt = result[7] - token_counts = result[8] - - return { - 'subject_type': subject_type, - 'subject_number': subject_number, - 'create_timestamp': create_timestamp, - 'last_interact_timestamp': last_interact_timestamp, - 'prompt': prompt, - 'default_prompt': default_prompt, - 'token_counts': token_counts - } - - # 列出与某个对象的所有对话session - def list_history(self, session_name: str, capacity: int, page: int): - self.__execute__(""" - select `name`, `type`, `number`, `create_timestamp`, `last_interact_timestamp`, `prompt`, `status`, `default_prompt`, `token_counts` - from `sessions` where `name` = '{}' order by `last_interact_timestamp` desc limit {} offset {} - """.format(session_name, capacity, capacity * page)) - results = self.cursor.fetchall() - sessions = [] - for result in results: - session_name = result[0] - subject_type = result[1] - subject_number = result[2] - create_timestamp = result[3] - last_interact_timestamp = result[4] - prompt = result[5] - status = result[6] - default_prompt = result[7] - token_counts = result[8] - - sessions.append({ - 'subject_type': subject_type, - 'subject_number': subject_number, - 'create_timestamp': create_timestamp, - 'last_interact_timestamp': last_interact_timestamp, - 'prompt': prompt, - 'default_prompt': default_prompt, - 'token_counts': token_counts - }) - - return sessions - - def delete_history(self, session_name: str, index: int) -> bool: - # 删除倒序第index个session - # 查找其id再删除 - self.__execute__(""" - delete from `sessions` where `id` in (select `id` from `sessions` where `name` = '{}' order by `last_interact_timestamp` desc limit 1 offset {}) - """.format(session_name, index)) - - return self.cursor.rowcount == 1 - - def delete_all_history(self, session_name: str) -> bool: - self.__execute__(""" - delete from `sessions` where `name` = '{}' - """.format(session_name)) - return self.cursor.rowcount > 0 - - def delete_all_session_history(self) -> bool: - self.__execute__(""" - delete from `sessions` - """) - return self.cursor.rowcount > 0 - - # 将apikey的使用量存进数据库 - def dump_api_key_usage(self, api_keys: dict, usage: dict): - logging.debug('dumping api key usage...') - logging.debug(api_keys) - logging.debug(usage) - for api_key in api_keys: - # 计算key的md5值 - key_md5 = hashlib.md5(api_keys[api_key].encode('utf-8')).hexdigest() - # 获取使用量 - usage_count = 0 - if key_md5 in usage: - usage_count = usage[key_md5] - # 将使用量存进数据库 - # 先检查是否已存在 - self.__execute__(""" - select count(*) from `api_key_usage` where `key_md5` = '{}'""".format(key_md5)) - result = self.cursor.fetchone() - if result[0] == 0: - # 不存在则插入 - self.__execute__(""" - insert into `api_key_usage` (`key_md5`, `usage`,`timestamp`) values ('{}', {}, {}) - """.format(key_md5, usage_count, int(time.time()))) - else: - # 存在则更新,timestamp设置为当前 - self.__execute__(""" - update `api_key_usage` set `usage` = {}, `timestamp` = {} where `key_md5` = '{}' - """.format(usage_count, int(time.time()), key_md5)) - - def load_api_key_usage(self): - self.__execute__(""" - select `key_md5`, `usage` from `api_key_usage` - """) - results = self.cursor.fetchall() - usage = {} - for result in results: - key_md5 = result[0] - usage_count = result[1] - usage[key_md5] = usage_count - return usage - - def dump_usage_json(self, usage: dict): - - json_str = json.dumps(usage) - self.__execute__(""" - select count(*) from `account_usage`""") - result = self.cursor.fetchone() - if result[0] == 0: - # 不存在则插入 - self.__execute__(""" - insert into `account_usage` (`json`) values ('{}') - """.format(json_str)) - else: - # 存在则更新 - self.__execute__(""" - update `account_usage` set `json` = '{}' where `id` = 1 - """.format(json_str)) - - def load_usage_json(self): - self.__execute__(""" - select `json` from `account_usage` order by id desc limit 1 - """) - result = self.cursor.fetchone() - if result is None: - return None - else: - return result[0] diff --git a/pkg/openai/__init__.py b/pkg/gai/__init__.py similarity index 100% rename from pkg/openai/__init__.py rename to pkg/gai/__init__.py diff --git a/pkg/openai/api/__init__.py b/pkg/gai/api/__init__.py similarity index 100% rename from pkg/openai/api/__init__.py rename to pkg/gai/api/__init__.py diff --git a/pkg/openai/api/chat_completion.py b/pkg/gai/api/chat_completion.py similarity index 100% rename from pkg/openai/api/chat_completion.py rename to pkg/gai/api/chat_completion.py diff --git a/pkg/openai/api/completion.py b/pkg/gai/api/completion.py similarity index 100% rename from pkg/openai/api/completion.py rename to pkg/gai/api/completion.py diff --git a/pkg/openai/api/model.py b/pkg/gai/api/model.py similarity index 100% rename from pkg/openai/api/model.py rename to pkg/gai/api/model.py diff --git a/pkg/openai/entities.py b/pkg/gai/entities.py similarity index 100% rename from pkg/openai/entities.py rename to pkg/gai/entities.py diff --git a/pkg/openai/modelmgr.py b/pkg/gai/modelmgr.py similarity index 96% rename from pkg/openai/modelmgr.py rename to pkg/gai/modelmgr.py index 0abd2d16..69e64bed 100644 --- a/pkg/openai/modelmgr.py +++ b/pkg/gai/modelmgr.py @@ -8,9 +8,9 @@ Completion - text-davinci-003 等模型 import tiktoken import openai -from ..openai.api import model as api_model -from ..openai.api import completion as api_completion -from ..openai.api import chat_completion as api_chat_completion +from ..gai.api import model as api_model +from ..gai.api import completion as api_completion +from ..gai.api import chat_completion as api_chat_completion COMPLETION_MODELS = { "gpt-3.5-turbo-instruct", diff --git a/pkg/openai/requester/__init__.py b/pkg/gai/requester/__init__.py similarity index 100% rename from pkg/openai/requester/__init__.py rename to pkg/gai/requester/__init__.py diff --git a/pkg/openai/requester/api.py b/pkg/gai/requester/api.py similarity index 100% rename from pkg/openai/requester/api.py rename to pkg/gai/requester/api.py diff --git a/pkg/openai/requester/apis/__init__.py b/pkg/gai/requester/apis/__init__.py similarity index 100% rename from pkg/openai/requester/apis/__init__.py rename to pkg/gai/requester/apis/__init__.py diff --git a/pkg/openai/requester/apis/chatcmpl.py b/pkg/gai/requester/apis/chatcmpl.py similarity index 100% rename from pkg/openai/requester/apis/chatcmpl.py rename to pkg/gai/requester/apis/chatcmpl.py diff --git a/pkg/openai/requester/entities.py b/pkg/gai/requester/entities.py similarity index 100% rename from pkg/openai/requester/entities.py rename to pkg/gai/requester/entities.py diff --git a/pkg/openai/requester/modelmgr.py b/pkg/gai/requester/modelmgr.py similarity index 100% rename from pkg/openai/requester/modelmgr.py rename to pkg/gai/requester/modelmgr.py diff --git a/pkg/openai/requester/token.py b/pkg/gai/requester/token.py similarity index 100% rename from pkg/openai/requester/token.py rename to pkg/gai/requester/token.py diff --git a/pkg/openai/session/__init__.py b/pkg/gai/session/__init__.py similarity index 100% rename from pkg/openai/session/__init__.py rename to pkg/gai/session/__init__.py diff --git a/pkg/openai/session/entities.py b/pkg/gai/session/entities.py similarity index 100% rename from pkg/openai/session/entities.py rename to pkg/gai/session/entities.py diff --git a/pkg/openai/session/sessionmgr.py b/pkg/gai/session/sessionmgr.py similarity index 100% rename from pkg/openai/session/sessionmgr.py rename to pkg/gai/session/sessionmgr.py diff --git a/pkg/openai/sysprompt/__init__.py b/pkg/gai/sysprompt/__init__.py similarity index 100% rename from pkg/openai/sysprompt/__init__.py rename to pkg/gai/sysprompt/__init__.py diff --git a/pkg/openai/sysprompt/entities.py b/pkg/gai/sysprompt/entities.py similarity index 85% rename from pkg/openai/sysprompt/entities.py rename to pkg/gai/sysprompt/entities.py index 43cd3bf7..af190259 100644 --- a/pkg/openai/sysprompt/entities.py +++ b/pkg/gai/sysprompt/entities.py @@ -3,7 +3,7 @@ from __future__ import annotations import typing import pydantic -from ...openai import entities +from ...gai import entities class Prompt(pydantic.BaseModel): diff --git a/pkg/openai/sysprompt/loader.py b/pkg/gai/sysprompt/loader.py similarity index 100% rename from pkg/openai/sysprompt/loader.py rename to pkg/gai/sysprompt/loader.py diff --git a/pkg/openai/sysprompt/loaders/__init__.py b/pkg/gai/sysprompt/loaders/__init__.py similarity index 100% rename from pkg/openai/sysprompt/loaders/__init__.py rename to pkg/gai/sysprompt/loaders/__init__.py diff --git a/pkg/openai/sysprompt/loaders/scenario.py b/pkg/gai/sysprompt/loaders/scenario.py similarity index 95% rename from pkg/openai/sysprompt/loaders/scenario.py rename to pkg/gai/sysprompt/loaders/scenario.py index 4d54f30f..e0991ca7 100644 --- a/pkg/openai/sysprompt/loaders/scenario.py +++ b/pkg/gai/sysprompt/loaders/scenario.py @@ -5,7 +5,7 @@ import os from .. import loader from .. import entities -from ....openai import entities as llm_entities +from ....gai import entities as llm_entities class ScenarioPromptLoader(loader.PromptLoader): diff --git a/pkg/openai/sysprompt/loaders/single.py b/pkg/gai/sysprompt/loaders/single.py similarity index 96% rename from pkg/openai/sysprompt/loaders/single.py rename to pkg/gai/sysprompt/loaders/single.py index 1fff5a69..9a3df6b7 100644 --- a/pkg/openai/sysprompt/loaders/single.py +++ b/pkg/gai/sysprompt/loaders/single.py @@ -3,7 +3,7 @@ import os from .. import loader from .. import entities -from ....openai import entities as llm_entities +from ....gai import entities as llm_entities class SingleSystemPromptLoader(loader.PromptLoader): diff --git a/pkg/openai/sysprompt/sysprompt.py b/pkg/gai/sysprompt/sysprompt.py similarity index 100% rename from pkg/openai/sysprompt/sysprompt.py rename to pkg/gai/sysprompt/sysprompt.py diff --git a/pkg/openai/tools/__init__.py b/pkg/gai/tools/__init__.py similarity index 100% rename from pkg/openai/tools/__init__.py rename to pkg/gai/tools/__init__.py diff --git a/pkg/openai/tools/entities.py b/pkg/gai/tools/entities.py similarity index 100% rename from pkg/openai/tools/entities.py rename to pkg/gai/tools/entities.py diff --git a/pkg/openai/tools/toolmgr.py b/pkg/gai/tools/toolmgr.py similarity index 100% rename from pkg/openai/tools/toolmgr.py rename to pkg/gai/tools/toolmgr.py diff --git a/pkg/qqbot/__init__.py b/pkg/im/__init__.py similarity index 100% rename from pkg/qqbot/__init__.py rename to pkg/im/__init__.py diff --git a/pkg/qqbot/adapter.py b/pkg/im/adapter.py similarity index 100% rename from pkg/qqbot/adapter.py rename to pkg/im/adapter.py diff --git a/pkg/qqbot/manager.py b/pkg/im/manager.py similarity index 95% rename from pkg/qqbot/manager.py rename to pkg/im/manager.py index 12868b94..f6c8efe4 100644 --- a/pkg/qqbot/manager.py +++ b/pkg/im/manager.py @@ -10,11 +10,11 @@ from mirai import At, GroupMessage, MessageEvent, StrangerMessage, \ import mirai import func_timeout -from ..openai import session as openai_session +from ..gai import session as openai_session from ..utils import context import tips as tips_custom -from ..qqbot import adapter as msadapter +from ..im import adapter as msadapter from .ratelim import ratelim from ..core import app, entities as core_entities @@ -44,13 +44,13 @@ class QQBotManager: logging.debug("Use adapter:" + config['msg_source_adapter']) if config['msg_source_adapter'] == 'yirimirai': - from pkg.qqbot.sources.yirimirai import YiriMiraiAdapter + from pkg.im.sources.yirimirai import YiriMiraiAdapter mirai_http_api_config = config['mirai_http_api_config'] self.bot_account_id = config['mirai_http_api_config']['qq'] self.adapter = YiriMiraiAdapter(mirai_http_api_config) elif config['msg_source_adapter'] == 'nakuru': - from pkg.qqbot.sources.nakuru import NakuruProjectAdapter + from pkg.im.sources.nakuru import NakuruProjectAdapter self.adapter = NakuruProjectAdapter(config['nakuru_config']) self.bot_account_id = self.adapter.bot_account_id diff --git a/pkg/qqbot/ratelim/__init__.py b/pkg/im/ratelim/__init__.py similarity index 100% rename from pkg/qqbot/ratelim/__init__.py rename to pkg/im/ratelim/__init__.py diff --git a/pkg/qqbot/ratelim/algo.py b/pkg/im/ratelim/algo.py similarity index 100% rename from pkg/qqbot/ratelim/algo.py rename to pkg/im/ratelim/algo.py diff --git a/pkg/qqbot/ratelim/algos/__init__.py b/pkg/im/ratelim/algos/__init__.py similarity index 100% rename from pkg/qqbot/ratelim/algos/__init__.py rename to pkg/im/ratelim/algos/__init__.py diff --git a/pkg/qqbot/ratelim/algos/fixedwin.py b/pkg/im/ratelim/algos/fixedwin.py similarity index 100% rename from pkg/qqbot/ratelim/algos/fixedwin.py rename to pkg/im/ratelim/algos/fixedwin.py diff --git a/pkg/qqbot/ratelim/ratelim.py b/pkg/im/ratelim/ratelim.py similarity index 100% rename from pkg/qqbot/ratelim/ratelim.py rename to pkg/im/ratelim/ratelim.py diff --git a/pkg/qqbot/sources/__init__.py b/pkg/im/sources/__init__.py similarity index 100% rename from pkg/qqbot/sources/__init__.py rename to pkg/im/sources/__init__.py diff --git a/pkg/qqbot/sources/nakuru.py b/pkg/im/sources/nakuru.py similarity index 99% rename from pkg/qqbot/sources/nakuru.py rename to pkg/im/sources/nakuru.py index fd278f70..f0df8665 100644 --- a/pkg/qqbot/sources/nakuru.py +++ b/pkg/im/sources/nakuru.py @@ -9,7 +9,7 @@ import nakuru import nakuru.entities.components as nkc from .. import adapter as adapter_model -from ...qqbot import blob +from ...im import blob from ...utils import context diff --git a/pkg/qqbot/sources/yirimirai.py b/pkg/im/sources/yirimirai.py similarity index 100% rename from pkg/qqbot/sources/yirimirai.py rename to pkg/im/sources/yirimirai.py diff --git a/pkg/pipeline/process/handlers/chat.py b/pkg/pipeline/process/handlers/chat.py index 889b3bb6..a7ac9f05 100644 --- a/pkg/pipeline/process/handlers/chat.py +++ b/pkg/pipeline/process/handlers/chat.py @@ -7,7 +7,7 @@ import mirai from .. import handler from ... import entities from ....core import entities as core_entities -from ....openai import entities as llm_entities +from ....gai import entities as llm_entities class ChatMessageHandler(handler.MessageHandler): diff --git a/pkg/plugin/host.py b/pkg/plugin/host.py index bf41003a..5e9ec9d7 100644 --- a/pkg/plugin/host.py +++ b/pkg/plugin/host.py @@ -15,7 +15,7 @@ from ..utils import network as network from ..utils import context as context from ..plugin import switch as switch from ..plugin import settings as settings -from ..qqbot import adapter as msadapter +from ..im import adapter as msadapter from ..plugin import metadata as metadata from mirai import Mirai diff --git a/pkg/utils/context.py b/pkg/utils/context.py deleted file mode 100644 index 9f201b81..00000000 --- a/pkg/utils/context.py +++ /dev/null @@ -1,128 +0,0 @@ -from __future__ import annotations - -import threading - -from ..database import manager as db_mgr -from ..qqbot import manager as qqbot_mgr -from ..config import manager as config_mgr -from ..plugin import host as plugin_host -from .center import v2 as center_v2 - - -context = { - 'inst': { - 'database.manager.DatabaseManager': None, - 'openai.manager.OpenAIInteract': None, - 'qqbot.manager.QQBotManager': None, - 'config.manager.ConfigManager': None, - }, - 'pool_ctl': None, - 'logger_handler': None, - 'config': None, - 'plugin_host': None, -} -context_lock = threading.Lock() - -### context耦合度非常高,需要大改 ### -def set_config(inst): - context_lock.acquire() - context['config'] = inst - context_lock.release() - - -def get_config(): - context_lock.acquire() - t = context['config'] - context_lock.release() - return t - - -def set_database_manager(inst: db_mgr.DatabaseManager): - context_lock.acquire() - context['inst']['database.manager.DatabaseManager'] = inst - context_lock.release() - - -def get_database_manager() -> db_mgr.DatabaseManager: - context_lock.acquire() - t = context['inst']['database.manager.DatabaseManager'] - context_lock.release() - return t - - -def set_openai_manager(inst: openai_mgr.OpenAIInteract): - context_lock.acquire() - context['inst']['openai.manager.OpenAIInteract'] = inst - context_lock.release() - - -def get_openai_manager() -> openai_mgr.OpenAIInteract: - context_lock.acquire() - t = context['inst']['openai.manager.OpenAIInteract'] - context_lock.release() - return t - - -def set_qqbot_manager(inst: qqbot_mgr.QQBotManager): - context_lock.acquire() - context['inst']['qqbot.manager.QQBotManager'] = inst - context_lock.release() - - -def get_qqbot_manager() -> qqbot_mgr.QQBotManager: - context_lock.acquire() - t = context['inst']['qqbot.manager.QQBotManager'] - context_lock.release() - return t - - -def set_config_manager(inst: config_mgr.ConfigManager): - context_lock.acquire() - context['inst']['config.manager.ConfigManager'] = inst - context_lock.release() - - -def get_config_manager() -> config_mgr.ConfigManager: - context_lock.acquire() - t = context['inst']['config.manager.ConfigManager'] - context_lock.release() - return t - - -def set_plugin_host(inst: plugin_host.PluginHost): - context_lock.acquire() - context['plugin_host'] = inst - context_lock.release() - - -def get_plugin_host() -> plugin_host.PluginHost: - context_lock.acquire() - t = context['plugin_host'] - context_lock.release() - return t - - -def set_thread_ctl(inst: threadctl.ThreadCtl): - context_lock.acquire() - context['pool_ctl'] = inst - context_lock.release() - - -def get_thread_ctl() -> threadctl.ThreadCtl: - context_lock.acquire() - t: threadctl.ThreadCtl = context['pool_ctl'] - context_lock.release() - return t - - -def set_center_v2_api(inst: center_v2.V2CenterAPI): - context_lock.acquire() - context['center_v2_api'] = inst - context_lock.release() - - -def get_center_v2_api() -> center_v2.V2CenterAPI: - context_lock.acquire() - t: center_v2.V2CenterAPI = context['center_v2_api'] - context_lock.release() - return t \ No newline at end of file diff --git a/requirements.txt b/requirements.txt index c3e29401..38560047 100644 --- a/requirements.txt +++ b/requirements.txt @@ -11,4 +11,5 @@ nakuru-project-idk CallingGPT tiktoken PyYaml -aiohttp \ No newline at end of file +aiohttp +pydantic