From eb4d63dd236e562660a999923bc750b0d995535e Mon Sep 17 00:00:00 2001 From: Rock Chin <1010553892@qq.com> Date: Thu, 15 Dec 2022 17:52:41 +0800 Subject: [PATCH] =?UTF-8?q?feat:=20=E6=94=AF=E6=8C=81=E5=A4=9A=E4=B8=AAapi?= =?UTF-8?q?-key=E5=B9=B6=E6=94=AF=E6=8C=81=E8=AE=BE=E7=BD=AE=E6=96=87?= =?UTF-8?q?=E5=AD=97=E9=87=8F=E9=98=88=E5=80=BC=E5=AF=B9=E5=85=B6=E8=BF=9B?= =?UTF-8?q?=E8=A1=8C=E8=87=AA=E5=8A=A8=E5=88=87=E6=8D=A2?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- config-template.py | 15 ++++++++++++- main.py | 9 ++++---- pkg/database/manager.py | 49 +++++++++++++++++++++++++++++++++++++++++ pkg/openai/manager.py | 15 ++++++++++--- 4 files changed, 80 insertions(+), 8 deletions(-) diff --git a/config-template.py b/config-template.py index 3820459a..62f5783d 100644 --- a/config-template.py +++ b/config-template.py @@ -16,10 +16,23 @@ mirai_http_api_config = { # [必需] OpenAI的配置 # api_key: OpenAI的API Key +# 若只有一个api-key,请直接修改以下内容中的"openai_api_key"为你的api-key +# 如准备了多个api-key,可以以字典的形式填写,程序会自动选择可用的api-key +# 例如{ +# "api0": "sk-xxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxx", +# "api1": "sk-xxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxx", +# } openai_config = { - "api_key": "openai_api_key", + "api_key": { + "default": "openai_api_key" + }, } +# 单个api-key的使用量警告阈值 +# 当使用此api-key进行请求的文字量达到此阈值时,会在控制台输出警告并通知管理员 +# 若之后还有未使用超过此值的api-key,则会切换到新的api-key进行请求 +api_key_usage_threshold = 895000 + # 管理员QQ号,用于接收报错等通知,为0时不发送通知 admin_qq = 0 diff --git a/main.py b/main.py index 38e1d5e9..988626fc 100644 --- a/main.py +++ b/main.py @@ -8,8 +8,8 @@ import logging import colorlog import sys -sys.path.append(".") +sys.path.append(".") log_colors_config = { 'DEBUG': 'green', # cyan white @@ -55,12 +55,12 @@ def main(): import pkg.qqbot.manager # 主启动流程 - openai_interact = pkg.openai.manager.OpenAIInteract(config.openai_config['api_key'], config.completion_api_params) - database = pkg.database.manager.DatabaseManager() database.initialize_database() + openai_interact = pkg.openai.manager.OpenAIInteract(config.openai_config['api_key'], config.completion_api_params) + # 加载所有未超时的session pkg.openai.session.load_sessions() @@ -78,6 +78,7 @@ def main(): time.sleep(86400) except KeyboardInterrupt: try: + pkg.openai.manager.get_inst().key_mgr.dump_usage() for session in pkg.openai.session.sessions: logging.info('持久化session: %s', session) pkg.openai.session.sessions[session].persistence() @@ -85,7 +86,7 @@ def main(): if not isinstance(e, KeyboardInterrupt): raise e print("程序退出") - break + sys.exit(0) if __name__ == '__main__': diff --git a/pkg/database/manager.py b/pkg/database/manager.py index 30e508f5..8b04218c 100644 --- a/pkg/database/manager.py +++ b/pkg/database/manager.py @@ -1,3 +1,4 @@ +import hashlib import logging import time from sqlite3 import Cursor @@ -48,6 +49,15 @@ class DatabaseManager: `prompt` text not null ) """) + + self.execute(""" + create table if not exists `api_key_usage`( + `id` INTEGER PRIMARY KEY AUTOINCREMENT, + `key_md5` varchar(255) not null, + `timestamp` bigint not null, + `usage` bigint not null + ) + """) print('Database initialized.') # session持久化 @@ -214,6 +224,45 @@ class DatabaseManager: return sessions + # 将apikey的使用量存进数据库 + def dump_api_key_usage(self, api_keys: dict, usage: dict): + logging.debug('dumping api key usage...') + logging.debug(api_keys) + logging.debug(usage) + for api_key in api_keys: + # 计算key的md5值 + key_md5 = hashlib.md5(api_keys[api_key].encode('utf-8')).hexdigest() + # 获取使用量 + usage_count = 0 + if key_md5 in usage: + usage_count = usage[key_md5] + # 将使用量存进数据库 + # 先检查是否已存在 + self.execute(""" + select count(*) from `api_key_usage` where `key_md5` = '{}'""".format(key_md5)) + result = self.cursor.fetchone() + if result[0] == 0: + # 不存在则插入 + self.execute(""" + insert into `api_key_usage` (`key_md5`, `usage`,`timestamp`) values ('{}', {}, {}) + """.format(key_md5, usage_count, int(time.time()))) + else: + # 存在则更新,timestamp设置为当前 + self.execute(""" + update `api_key_usage` set `usage` = {}, `timestamp` = {} where `key_md5` = '{}' + """.format(usage_count, int(time.time()), key_md5)) + + def load_api_key_usage(self): + self.execute(""" + select `key_md5`, `usage` from `api_key_usage` + """) + results = self.cursor.fetchall() + usage = {} + for result in results: + key_md5 = result[0] + usage_count = result[1] + usage[key_md5] = usage_count + return usage def get_inst() -> DatabaseManager: global inst diff --git a/pkg/openai/manager.py b/pkg/openai/manager.py index fed9fd2c..3e9daec0 100644 --- a/pkg/openai/manager.py +++ b/pkg/openai/manager.py @@ -2,19 +2,24 @@ import openai import config +import pkg.openai.keymgr + inst = None # 为其他模块提供与OpenAI交互的接口 class OpenAIInteract: - api_key = '' api_params = {} + key_mgr = None + def __init__(self, api_key: str, api_params: dict): - self.api_key = api_key + # self.api_key = api_key self.api_params = api_params - openai.api_key = self.api_key + self.key_mgr = pkg.openai.keymgr.KeysManager(api_key) + + openai.api_key = self.key_mgr.get_using_key() global inst inst = self @@ -27,6 +32,10 @@ class OpenAIInteract: timeout=config.process_message_timeout, **self.api_params ) + switched = self.key_mgr.report_usage(prompt + response['choices'][0]['text']) + if switched: + openai.api_key = self.key_mgr.get_using_key() + return response