feat: 使用费用估算替代字数额度估算 #47

This commit is contained in:
Rock Chin
2022-12-28 00:05:25 +08:00
parent 7b5d47a2ca
commit 7ed558056f
7 changed files with 171 additions and 46 deletions

View File

@@ -52,10 +52,10 @@ response_rules = {
"regexp": [] # "为什么.*", "怎么?样.*", "怎么.*", "如何.*", "[Hh]ow to.*", "[Ww]hy not.*", "[Ww]hat is.*", ".*怎么办", ".*咋办" "regexp": [] # "为什么.*", "怎么?样.*", "怎么.*", "如何.*", "[Hh]ow to.*", "[Ww]hy not.*", "[Ww]hat is.*", ".*怎么办", ".*咋办"
} }
# 单个api-key的使用量警告阈值 # 单个api-key的费用警告阈值
# 当使用此api-key进行请求的文字量达到此阈值时,会在控制台输出警告并通知管理员 # 当使用此api-key进行请求所消耗的费用估算达到此阈值时,会在控制台输出警告并通知管理员
# 若之后还有未使用超过此值的api-key则会切换到新的api-key进行请求 # 若之后还有未使用超过此值的api-key则会切换到新的api-key进行请求
api_key_usage_threshold = 900000 api_key_fee_threshold = 18.0
# 敏感词过滤开关,以同样数量的*代替敏感词回复 # 敏感词过滤开关,以同样数量的*代替敏感词回复
# 请在sensitive.json中添加敏感词 # 请在sensitive.json中添加敏感词
@@ -80,7 +80,7 @@ completion_api_params = {
# OpenAI的Image API的参数 # OpenAI的Image API的参数
# 具体请查看OpenAI的文档: https://beta.openai.com/docs/api-reference/images/create # 具体请查看OpenAI的文档: https://beta.openai.com/docs/api-reference/images/create
image_api_params = { image_api_params = {
"size": "256x256", "size": "256x256", # 图片尺寸支持256x256, 512x512, 1024x1024
} }
# 消息处理的超时时间,单位为秒 # 消息处理的超时时间,单位为秒

View File

@@ -78,7 +78,7 @@ def main():
time.sleep(86400) time.sleep(86400)
except KeyboardInterrupt: except KeyboardInterrupt:
try: try:
pkg.openai.manager.get_inst().key_mgr.dump_usage() pkg.openai.manager.get_inst().key_mgr.dump_fee()
for session in pkg.openai.session.sessions: for session in pkg.openai.session.sessions:
logging.info('持久化session: %s', session) logging.info('持久化session: %s', session)
pkg.openai.session.sessions[session].persistence() pkg.openai.session.sessions[session].persistence()

View File

@@ -58,6 +58,15 @@ class DatabaseManager:
`usage` bigint not null `usage` bigint not null
) )
""") """)
self.execute("""
create table if not exists `account_fee`(
`id` INTEGER PRIMARY KEY AUTOINCREMENT,
`key_md5` varchar(255) not null,
`timestamp` bigint not null,
`fee` DECIMAL(9,3) not null
)
""")
print('Database initialized.') print('Database initialized.')
# session持久化 # session持久化
@@ -264,6 +273,45 @@ class DatabaseManager:
usage[key_md5] = usage_count usage[key_md5] = usage_count
return usage return usage
def dump_api_key_fee(self, api_keys: dict, fee: dict):
logging.debug("dumping api key fee...")
logging.debug(api_keys)
logging.debug(fee)
for api_key in api_keys:
# 计算key的md5值
key_md5 = hashlib.md5(api_keys[api_key].encode('utf-8')).hexdigest()
# 获取使用量
fee_count = 0
if key_md5 in fee:
fee_count = fee[key_md5]
# 将使用量存进数据库
# 先检查是否已存在
self.execute("""
select count(*) from `account_fee` where `key_md5` = '{}'""".format(key_md5))
result = self.cursor.fetchone()
if result[0] == 0:
# 不存在则插入
self.execute("""
insert into `account_fee` (`key_md5`, `fee`,`timestamp`) values ('{}', {}, {})
""".format(key_md5, fee_count, int(time.time())))
else:
# 存在则更新timestamp设置为当前
self.execute("""
update `account_fee` set `fee` = {}, `timestamp` = {} where `key_md5` = '{}'
""".format(fee_count, int(time.time()), key_md5))
def load_api_key_fee(self):
self.execute("""
select `key_md5`, `fee` from `account_fee`
""")
results = self.cursor.fetchall()
fee = {}
for result in results:
key_md5 = result[0]
fee_count = result[1]
fee[key_md5] = fee_count
return fee
def get_inst() -> DatabaseManager: def get_inst() -> DatabaseManager:
global inst global inst
return inst return inst

View File

@@ -14,7 +14,11 @@ class KeysManager:
# 其中键为api-key的md5值值为使用量 # 其中键为api-key的md5值值为使用量
usage = {} usage = {}
api_key_usage_threshold = 900000 fee = {}
api_key_usage_threshold = 900000 # 已弃用
api_key_fee_threshold = 18.0
using_key = "" using_key = ""
@@ -24,9 +28,11 @@ class KeysManager:
return self.using_key return self.using_key
def __init__(self, api_key): def __init__(self, api_key):
if hasattr(config, 'api_key_usage_threshold'): # if hasattr(config, 'api_key_usage_threshold'):
self.api_key_usage_threshold = config.api_key_usage_threshold # self.api_key_usage_threshold = config.api_key_usage_threshold
self.load_usage() if hasattr(config, 'api_key_fee_threshold'):
self.api_key_fee_threshold = config.api_key_fee_threshold
self.load_fee()
if type(api_key) is dict: if type(api_key) is dict:
self.api_key = api_key self.api_key = api_key
@@ -45,9 +51,9 @@ class KeysManager:
# 根据使用量自动切换到可用的api-key # 根据使用量自动切换到可用的api-key
# 返回是否切换成功, 切换后的api-key的别名 # 返回是否切换成功, 切换后的api-key的别名
def auto_switch(self) -> (bool, str): def auto_switch(self) -> (bool, str):
self.dump_usage() self.dump_fee()
for key_name in self.api_key: for key_name in self.api_key:
if self.get_usage(self.api_key[key_name]) < self.api_key_usage_threshold: if self.get_fee(self.api_key[key_name]) < self.api_key_fee_threshold:
self.using_key = self.api_key[key_name] self.using_key = self.api_key[key_name]
logging.info("使用api-key:" + key_name) logging.info("使用api-key:" + key_name)
return True, key_name return True, key_name
@@ -57,30 +63,76 @@ class KeysManager:
return False, "" return False, ""
def get_usage(self, api_key):
md5 = hashlib.md5(api_key.encode('utf-8')).hexdigest()
if md5 not in self.usage:
self.usage[md5] = 0
return self.usage[md5]
def add(self, key_name, key): def add(self, key_name, key):
self.api_key[key_name] = key self.api_key[key_name] = key
# def get_usage(self, api_key):
# md5 = hashlib.md5(api_key.encode('utf-8')).hexdigest()
# if md5 not in self.usage:
# self.usage[md5] = 0
# return self.usage[md5]
# 报告使用 # 报告使用
# 返回是否需要将openai的api-key切换 # 返回是否需要将openai的api-key切换
def report_usage(self, new_content: str) -> bool: # def report_usage(self, new_content: str) -> bool:
# md5 = hashlib.md5(self.using_key.encode('utf-8')).hexdigest()
# if md5 not in self.usage:
# self.usage[md5] = 0
#
# # 经测算得出的理论与实际的偏差比例
# salt_rate = 0.91
#
# self.usage[md5] += ( (len(new_content.encode('utf-8')) - len(new_content)) / 2 + len(new_content) )*salt_rate
#
# self.usage[md5] = int(self.usage[md5])
#
# if self.usage[md5] >= self.api_key_usage_threshold:
# switch_result, key_name = self.auto_switch()
#
# # 检查是否切换到新的
# if switch_result:
# if key_name not in self.alerted:
# # 通知管理员
# pkg.qqbot.manager.get_inst().notify_admin("api-key已切换到:" + key_name)
# self.alerted.append(key_name)
# return True
# else:
# if key_name not in self.alerted:
# # 通知管理员
# pkg.qqbot.manager.get_inst().notify_admin("api-key已用完无未使用的api-key可供切换")
# self.alerted.append(key_name)
# return False
# 设置当前使用的api-key使用量超限
# 这是在尝试调用api时发生超限异常时调用的
def set_current_exceeded(self):
md5 = hashlib.md5(self.using_key.encode('utf-8')).hexdigest() md5 = hashlib.md5(self.using_key.encode('utf-8')).hexdigest()
if md5 not in self.usage: # self.usage[md5] = self.api_key_usage_threshold
self.usage[md5] = 0 self.fee[md5] = self.api_key_fee_threshold
self.dump_fee()
# 经测算得出的理论与实际的偏差比例 # def dump_usage(self):
salt_rate = 0.91 # pkg.database.manager.get_inst().dump_api_key_usage(api_keys=self.api_key, usage=self.usage)
self.usage[md5] += ( (len(new_content.encode('utf-8')) - len(new_content)) / 2 + len(new_content) )*salt_rate # def load_usage(self):
# self.usage = pkg.database.manager.get_inst().load_api_key_usage()
# logging.debug("load usage:" + str(self.usage))
# print("load usage:" + str(self.usage))
self.usage[md5] = int(self.usage[md5]) def get_fee(self, api_key):
md5 = hashlib.md5(api_key.encode('utf-8')).hexdigest()
if md5 not in self.fee:
self.fee[md5] = 0
return self.fee[md5]
if self.usage[md5] >= self.api_key_usage_threshold: def report_fee(self, fee: float) -> bool:
md5 = hashlib.md5(self.using_key.encode('utf-8')).hexdigest()
if md5 not in self.fee:
self.fee[md5] = 0
self.fee[md5] += fee
if self.fee[md5] >= self.api_key_fee_threshold:
switch_result, key_name = self.auto_switch() switch_result, key_name = self.auto_switch()
# 检查是否切换到新的 # 检查是否切换到新的
@@ -97,17 +149,9 @@ class KeysManager:
self.alerted.append(key_name) self.alerted.append(key_name)
return False return False
# 设置当前使用的api-key使用量超限 def dump_fee(self):
# 这是在尝试调用api时发生超限异常时调用的 pkg.database.manager.get_inst().dump_api_key_fee(api_keys=self.api_key, fee=self.fee)
def set_current_exceeded(self):
md5 = hashlib.md5(self.using_key.encode('utf-8')).hexdigest()
self.usage[md5] = self.api_key_usage_threshold
self.dump_usage()
def dump_usage(self): def load_fee(self):
pkg.database.manager.get_inst().dump_api_key_usage(api_keys=self.api_key, usage=self.usage) self.fee = pkg.database.manager.get_inst().load_api_key_fee()
logging.info("load fee:" + str(self.fee))
def load_usage(self):
self.usage = pkg.database.manager.get_inst().load_api_key_usage()
logging.debug("load usage:" + str(self.usage))
print("load usage:" + str(self.usage))

View File

@@ -5,6 +5,7 @@ import openai
import config import config
import pkg.openai.keymgr import pkg.openai.keymgr
import pkg.openai.pricing as pricing
inst = None inst = None
@@ -37,19 +38,30 @@ class OpenAIInteract:
timeout=config.process_message_timeout, timeout=config.process_message_timeout,
**config.completion_api_params **config.completion_api_params
) )
switched = self.key_mgr.report_usage(prompt + response['choices'][0]['text'])
switched = self.key_mgr.report_fee(pricing.language_base_price(config.completion_api_params['model'],
prompt + response['choices'][0]['text']))
if switched: if switched:
openai.api_key = self.key_mgr.get_using_key() openai.api_key = self.key_mgr.get_using_key()
return response return response
def request_image(self, prompt): def request_image(self, prompt):
params = config.image_api_params if hasattr(config, "image_api_params") else self.default_image_api_params
response = openai.Image.create( response = openai.Image.create(
prompt=prompt, prompt=prompt,
n=1, n=1,
**config.image_api_params if hasattr(config, "image_api_params") else self.default_image_api_params **params
) )
switched = self.key_mgr.report_fee(pricing.image_price(params['size']))
if switched:
openai.api_key = self.key_mgr.get_using_key()
return response return response

21
pkg/openai/pricing.py Normal file
View File

@@ -0,0 +1,21 @@
pricing = {
"base": { # 文字模型单位是1000字符
"text-davinci-003": 0.02,
},
"image": {
"256x256": 0.016,
"512x512": 0.018,
"1024x1024": 0.02,
}
}
def language_base_price(model, text):
salt_rate = 0.93
length = ((len(text.encode('utf-8')) - len(text)) / 2 + len(text)) * salt_rate
return pricing["base"][model] * length / 1000
def image_price(size):
return pricing["image"][size]

View File

@@ -113,17 +113,17 @@ def process_message(launcher_type: str, launcher_id: int, text_message: str) ->
elif cmd == 'usage': elif cmd == 'usage':
api_keys = pkg.openai.manager.get_inst().key_mgr.api_key api_keys = pkg.openai.manager.get_inst().key_mgr.api_key
reply_str = "[bot]api-key使用情况:(阈值:{})\n\n".format( reply_str = "[bot]api-key使用情况:(阈值:{})\n\n".format(
pkg.openai.manager.get_inst().key_mgr.api_key_usage_threshold) pkg.openai.manager.get_inst().key_mgr.api_key_fee_threshold)
using_key_name = "" using_key_name = ""
for api_key in api_keys: for api_key in api_keys:
reply_str += "{}:\n - {} {}%\n".format(api_key, reply_str += "{}:\n - {} {}%\n".format(api_key,
pkg.openai.manager.get_inst().key_mgr.get_usage( pkg.openai.manager.get_inst().key_mgr.get_fee(
api_keys[api_key]), api_keys[api_key]),
round( round(
pkg.openai.manager.get_inst().key_mgr.get_usage( pkg.openai.manager.get_inst().key_mgr.get_fee(
api_keys[ api_keys[
api_key]) / pkg.openai.manager.get_inst().key_mgr.api_key_usage_threshold * 100, api_key]) / pkg.openai.manager.get_inst().key_mgr.api_key_fee_threshold * 100,
3)) 3))
if api_keys[api_key] == pkg.openai.manager.get_inst().key_mgr.using_key: if api_keys[api_key] == pkg.openai.manager.get_inst().key_mgr.using_key:
using_key_name = api_key using_key_name = api_key
@@ -158,7 +158,7 @@ def process_message(launcher_type: str, launcher_id: int, text_message: str) ->
reply = ["[bot]err:调用API失败请重试或联系作者或等待修复"] reply = ["[bot]err:调用API失败请重试或联系作者或等待修复"]
except openai.error.RateLimitError as e: except openai.error.RateLimitError as e:
# 尝试切换api-key # 尝试切换api-key
current_tokens_amt = pkg.openai.manager.get_inst().key_mgr.get_usage( current_tokens_amt = pkg.openai.manager.get_inst().key_mgr.get_fee(
pkg.openai.manager.get_inst().key_mgr.get_using_key()) pkg.openai.manager.get_inst().key_mgr.get_using_key())
pkg.openai.manager.get_inst().key_mgr.set_current_exceeded() pkg.openai.manager.get_inst().key_mgr.set_current_exceeded()
switched, name = pkg.openai.manager.get_inst().key_mgr.auto_switch() switched, name = pkg.openai.manager.get_inst().key_mgr.auto_switch()