From e06a2535dd913547a460087225385acca00665e4 Mon Sep 17 00:00:00 2001 From: Rock Chin <1010553892@qq.com> Date: Tue, 3 Jan 2023 17:50:13 +0800 Subject: [PATCH] =?UTF-8?q?feat:=20=E6=96=B0=E5=A2=9E=E4=BD=BF=E7=94=A8?= =?UTF-8?q?=E9=87=8F=E7=BB=9F=E8=AE=A1=E5=8A=9F=E8=83=BD=EF=BC=8C=E4=B8=8E?= =?UTF-8?q?key=E5=88=87=E6=8D=A2=E8=84=B1=E9=92=A9?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- main.py | 1 + pkg/audit/__init__.py | 0 pkg/audit/gatherer.py | 77 +++++++++++++++++++++++++++++++++++++++++ pkg/database/manager.py | 33 ++++++++++++++++++ pkg/openai/keymgr.py | 3 ++ pkg/openai/manager.py | 11 +++++- pkg/qqbot/process.py | 15 ++++++-- pkg/utils/context.py | 4 +++ 8 files changed, 141 insertions(+), 3 deletions(-) create mode 100644 pkg/audit/__init__.py create mode 100644 pkg/audit/gatherer.py diff --git a/main.py b/main.py index 9805b984..525a4389 100644 --- a/main.py +++ b/main.py @@ -94,6 +94,7 @@ def main(first_time_init=False): "请查看 https://github.com/RockChinQ/QChatGPT/issues/5") logging.info("如报错 \"server rejected WebSocket connection: HTTP 404\" ," "请查看 https://github.com/RockChinQ/QChatGPT/issues/22") + logging.info("其他异常请前往仓库issue搜索或提issue") else: logging.info('热重载完成') diff --git a/pkg/audit/__init__.py b/pkg/audit/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/pkg/audit/gatherer.py b/pkg/audit/gatherer.py new file mode 100644 index 00000000..292f2c16 --- /dev/null +++ b/pkg/audit/gatherer.py @@ -0,0 +1,77 @@ +import hashlib +import json + +import pkg.utils.context + + +class DataGatherer: + """数据收集器""" + usage = {} + """以key值md5为key,{ + "text": { + "text-davinci-003": 文字量:int, + }, + "image": { + "256x256": 图片数量:int, + } + }为值的字典""" + + def __init__(self): + self.load_from_db() + + def report_text_model_usage(self, model, text): + key_md5 = pkg.utils.context.get_openai_manager().key_mgr.get_using_key_md5() + + if key_md5 not in self.usage: + self.usage[key_md5] = {} + + if "text" not in self.usage[key_md5]: + self.usage[key_md5]["text"] = {} + + if model not in self.usage[key_md5]["text"]: + self.usage[key_md5]["text"][model] = 0 + + length = ((len(text.encode('utf-8')) - len(text)) / 2 + len(text)) + self.usage[key_md5]["text"][model] += length + self.dump_to_db() + + def report_image_model_usage(self, size): + key_md5 = pkg.utils.context.get_openai_manager().key_mgr.get_using_key_md5() + + if key_md5 not in self.usage: + self.usage[key_md5] = {} + + if "image" not in self.usage[key_md5]: + self.usage[key_md5]["image"] = {} + + if size not in self.usage[key_md5]["image"]: + self.usage[key_md5]["image"][size] = 0 + + self.usage[key_md5]["image"][size] += 1 + self.dump_to_db() + + def get_text_length_of_key(self, key): + key_md5 = hashlib.md5(key.encode('utf-8')).hexdigest() + if key_md5 not in self.usage: + return 0 + if "text" not in self.usage[key_md5]: + return 0 + # 遍历其中所有模型,求和 + return sum(self.usage[key_md5]["text"].values()) + + def get_image_count_of_key(self, key): + key_md5 = hashlib.md5(key.encode('utf-8')).hexdigest() + if key_md5 not in self.usage: + return 0 + if "image" not in self.usage[key_md5]: + return 0 + # 遍历其中所有模型,求和 + return sum(self.usage[key_md5]["image"].values()) + + def dump_to_db(self): + pkg.utils.context.get_database_manager().dump_usage_json(self.usage) + + def load_from_db(self): + json_str = pkg.utils.context.get_database_manager().load_usage_json() + if json_str is not None: + self.usage = json.loads(json_str) diff --git a/pkg/database/manager.py b/pkg/database/manager.py index 5c9b700b..048df1d7 100644 --- a/pkg/database/manager.py +++ b/pkg/database/manager.py @@ -1,4 +1,5 @@ import hashlib +import json import logging import time from sqlite3 import Cursor @@ -65,6 +66,13 @@ class DatabaseManager: `fee` DECIMAL(12,6) not null ) """) + + self.execute(""" + create table if not exists `account_usage`( + `id` INTEGER PRIMARY KEY AUTOINCREMENT, + `json` text not null + ) + """) print('Database initialized.') # session持久化 @@ -310,3 +318,28 @@ class DatabaseManager: fee[key_md5] = fee_count return fee + def dump_usage_json(self, usage: dict): + json_str = json.dumps(usage) + self.execute(""" + select count(*) from `account_usage`""") + result = self.cursor.fetchone() + if result[0] == 0: + # 不存在则插入 + self.execute(""" + insert into `account_usage` (`json`) values ('{}') + """.format(json_str)) + else: + # 存在则更新 + self.execute(""" + update `account_usage` set `json` = '{}' where `id` = 1 + """.format(json_str)) + + def load_usage_json(self): + self.execute(""" + select `json` from `account_usage` order by id desc limit 1 + """) + result = self.cursor.fetchone() + if result is None: + return None + else: + return result[0] diff --git a/pkg/openai/keymgr.py b/pkg/openai/keymgr.py index 03ce5c3f..1e8ebae2 100644 --- a/pkg/openai/keymgr.py +++ b/pkg/openai/keymgr.py @@ -28,6 +28,9 @@ class KeysManager: def get_using_key(self): return self.using_key + def get_using_key_md5(self): + return hashlib.md5(self.using_key.encode('utf-8')).hexdigest() + def __init__(self, api_key): # if hasattr(config, 'api_key_usage_threshold'): # self.api_key_usage_threshold = config.api_key_usage_threshold diff --git a/pkg/openai/manager.py b/pkg/openai/manager.py index 2d0c1575..607b73b6 100644 --- a/pkg/openai/manager.py +++ b/pkg/openai/manager.py @@ -7,13 +7,16 @@ import config import pkg.openai.keymgr import pkg.openai.pricing as pricing import pkg.utils.context +import pkg.audit.gatherer # 为其他模块提供与OpenAI交互的接口 class OpenAIInteract: api_params = {} - key_mgr = None + key_mgr: pkg.openai.keymgr.KeysManager = None + + audit_mgr: pkg.audit.gatherer.DataGatherer = None default_image_api_params = { "size": "256x256", @@ -23,6 +26,7 @@ class OpenAIInteract: # self.api_key = api_key self.key_mgr = pkg.openai.keymgr.KeysManager(api_key) + self.audit_mgr = pkg.audit.gatherer.DataGatherer() openai.api_key = self.key_mgr.get_using_key() @@ -37,6 +41,9 @@ class OpenAIInteract: **config.completion_api_params ) + self.audit_mgr.report_text_model_usage(config.completion_api_params['model'], + prompt + response['choices'][0]['text']) + switched = self.key_mgr.report_fee(pricing.language_base_price(config.completion_api_params['model'], prompt + response['choices'][0]['text'])) if switched: @@ -54,6 +61,8 @@ class OpenAIInteract: **params ) + self.audit_mgr.report_image_model_usage(params['size']) + switched = self.key_mgr.report_fee(pricing.image_price(params['size'])) if switched: diff --git a/pkg/qqbot/process.py b/pkg/qqbot/process.py index 5074b4ce..99842582 100644 --- a/pkg/qqbot/process.py +++ b/pkg/qqbot/process.py @@ -124,9 +124,9 @@ def process_message(launcher_type: str, launcher_id: int, text_message: str, mes reply_str += ",当前处于全新会话或不在此页" reply = [reply_str] - elif cmd == 'usage': + elif cmd == 'fee': api_keys = pkg.utils.context.get_openai_manager().key_mgr.api_key - reply_str = "[bot]api-key使用情况:(阈值:{})\n\n".format( + reply_str = "[bot]api-key费用情况(估算):(阈值:{})\n\n".format( pkg.utils.context.get_openai_manager().key_mgr.api_key_fee_threshold) using_key_name = "" @@ -145,7 +145,18 @@ def process_message(launcher_type: str, launcher_id: int, text_message: str, mes reply_str += "\n当前使用:{}".format(using_key_name) reply = [reply_str] + elif cmd == 'usage': + reply_str = "[bot]各api-key使用情况:" + api_keys = pkg.utils.context.get_openai_manager().key_mgr.api_key + for key_name in api_keys: + text_length = pkg.utils.context.get_openai_manager().audit_mgr\ + .get_text_length_of_key(api_keys[key_name]) + image_count = pkg.utils.context.get_openai_manager().audit_mgr\ + .get_image_count_of_key(api_keys[key_name]) + reply_str += "{}:\n - 文本长度:{}\n - 图片数量:{}\n".format(key_name, int(text_length), int(image_count)) + + reply = [reply_str] elif cmd == 'draw': if len(params) == 0: reply = ["[bot]err:请输入图片描述文字"] diff --git a/pkg/utils/context.py b/pkg/utils/context.py index 2576609b..449168eb 100644 --- a/pkg/utils/context.py +++ b/pkg/utils/context.py @@ -1,3 +1,7 @@ +import pkg.database.manager +import pkg.openai.manager +import pkg.qqbot.manager + context = { 'inst': { 'database.manager.DatabaseManager': None,