Compare commits

...

31 Commits

Author SHA1 Message Date
Rock Chin
a1bfbad24e Release v2.1.3 2023-03-06 12:41:35 +08:00
Rock Chin
8af4918048 Merge pull request #230 from LINSTCL/config_integrity_check
添加配置文件完整性校验
2023-03-06 12:35:59 +08:00
Rock Chin
49f4ab0ec8 perf: 完整性检查忽略__开头的属性 2023-03-06 12:34:08 +08:00
LINSTCL
85c623fb0f 修改提示逻辑 2023-03-06 11:27:16 +08:00
Rock Chin
9e28298250 perf: 完善未启动情况下的自动更新 2023-03-06 11:18:31 +08:00
Rock Chin
7a04ef0985 feat: 未启动状态下的自动更新 (#223) 2023-03-06 11:04:25 +08:00
LINSTCL
83005e9ba9 添加配置文件完整性校验 2023-03-06 09:40:33 +08:00
Rock Chin
f0c78f0529 Merge pull request #222 from LINSTCL/threadpool-optimization
使用线程池控制线程数量,防止高并发崩溃
2023-03-06 08:51:47 +08:00
Rock Chin
3f638adcf9 perf(qqbot/manager.py): 优化控制台日志显示 2023-03-06 08:50:28 +08:00
Rock Chin
d9405d8d5d fix: main.py的字段版本兼容性问题 2023-03-06 08:48:50 +08:00
Rock Chin
606713a418 Merge pull request #228 from yichuxue/patch-1
启动时,更新openai和pillow库超时问题
2023-03-06 08:44:29 +08:00
Rock Chin
52102f0d0a feat(deps): trusted-host参数 2023-03-06 08:43:51 +08:00
Rock Chin
61c29829ed Release v2.1.2 2023-03-06 08:35:04 +08:00
依初雪
df30931aad 启动openai和pillow库超时问题
主要改动如下:
1、在ensure_dependencies函数更更新包时,出现超时的情况,指定更新源 https://pypi.douban.com/simple/
2023-03-06 00:32:46 +08:00
Rock Chin
5afcc03e8b fix: 错误的!version指令处理逻辑 2023-03-05 20:07:08 +08:00
Rock Chin
fbeb4673f4 Merge pull request #226 from RockChinQ/text2img-perf
[Feat] 不再自带字体文件
2023-03-05 19:59:16 +08:00
Rock Chin
4aba319560 fix: 错误的加载过程 2023-03-05 19:57:39 +08:00
Rock Chin
74f79e002c perf: 优化字体加载过程 2023-03-05 19:54:51 +08:00
Rock Chin
2668ef2b3f feat: 不再自带字体文件 2023-03-05 19:36:09 +08:00
Rock Chin
74c018e271 Merge pull request #225 from RockChinQ/fix-switch-exce
[Fix] 修复插件开关问题
2023-03-05 17:36:03 +08:00
Rock Chin
64776fd601 doc: OpenAI注册教程链接 2023-03-05 16:47:42 +08:00
LINSTCL
59877bf71d 添加日志输出 2023-03-05 16:47:07 +08:00
LINSTCL
d2800ac58b 使用线程池控制线程数量,防止高并发崩溃 2023-03-05 16:41:12 +08:00
Rock Chin
ffef944119 fix: 热重载后插件开关状态被重置 (#177) 2023-03-05 16:04:45 +08:00
Rock Chin
651b291ef6 doc: 添加部分注释 2023-03-05 15:39:13 +08:00
Rock Chin
e4b581f197 doc: 致谢添加贡献者 2023-03-05 14:37:14 +08:00
Rock Chin
4f3939e2d9 Merge pull request #219 from LINSTCL/modelmgr_optimization
优化模型接口底层的异常处理
2023-03-05 14:18:24 +08:00
LINSTCL
1048ca612d 补充错误情况 2023-03-05 14:06:07 +08:00
LINSTCL
b1a2d21ee9 优化异常处理 2023-03-05 13:52:43 +08:00
Rock Chin
dd4e8bdc8b perf: 优化版本识别逻辑 2023-03-05 12:26:51 +08:00
Rock Chin
e28c9bae0c feat: 修改上报功能识别版本的逻辑 2023-03-05 12:21:28 +08:00
22 changed files with 314 additions and 141 deletions

View File

@@ -115,10 +115,9 @@
### - 注册OpenAI账号 ### - 注册OpenAI账号
**可以直接进群找群主购买** 参考以下文章自行注册
或参考以下文章自行注册
> ~~[只需 1 元搞定 ChatGPT 注册](https://zhuanlan.zhihu.com/p/589470082)~~(已失效) > [国内注册ChatGPT的方法(100%可用)](https://www.pythonthree.com/register-openai-chatgpt/)
> [手把手教你如何注册ChatGPT超级详细](https://guxiaobei.com/51461) > [手把手教你如何注册ChatGPT超级详细](https://guxiaobei.com/51461)
注册成功后请前往[个人中心查看](https://beta.openai.com/account/api-keys)api_key 注册成功后请前往[个人中心查看](https://beta.openai.com/account/api-keys)api_key
@@ -227,6 +226,7 @@ python3 main.py
- [@dominoar](https://github.com/dominoar) 为本项目开发多种插件 - [@dominoar](https://github.com/dominoar) 为本项目开发多种插件
- [@hissincn](https://github.com/hissincn) 本项目贡献者 - [@hissincn](https://github.com/hissincn) 本项目贡献者
- [@LINSTCL](https://github.com/LINSTCL) GPT-3.5官方模型适配贡献者 - [@LINSTCL](https://github.com/LINSTCL) GPT-3.5官方模型适配贡献者
- [@Haibersut](https://github.com/Haibersut) 本项目贡献者
以及其他所有为本项目提供支持的朋友们。 以及其他所有为本项目提供支持的朋友们。

View File

@@ -183,6 +183,12 @@ blob_message_threshold = 256
# - "forward": 将长消息转换为转发消息组件发送 # - "forward": 将长消息转换为转发消息组件发送
blob_message_strategy = "forward" blob_message_strategy = "forward"
# 文字转图片时使用的字体文件路径
# 当策略为"image"时生效
# 若在Windows系统下程序会自动使用Windows自带的微软雅黑字体
# 若未填写或不存在且不是Windows将禁用文字转图片功能改为使用转发消息组件
font_path = ""
# 消息处理超时重试次数 # 消息处理超时重试次数
retry_times = 3 retry_times = 3
@@ -196,6 +202,11 @@ hide_exce_info_to_user = False
# 设置为空字符串时,不发送提示信息 # 设置为空字符串时,不发送提示信息
alter_tip_message = '出错了,请稍后再试' alter_tip_message = '出错了,请稍后再试'
# 机器人线程池大小
# 该参数决定机器人可以同时处理几个人的消息,超出线程池数量的请求会被阻塞,不会被丢弃
# 如果你不清楚该参数的意义,请不要更改
pool_num = 10
# 每个会话的过期时间,单位为秒 # 每个会话的过期时间,单位为秒
# 默认值20分钟 # 默认值20分钟
session_expire_time = 60 * 20 session_expire_time = 60 * 20

41
main.py
View File

@@ -45,7 +45,9 @@ def init_db():
def ensure_dependencies(): def ensure_dependencies():
import pkg.utils.pkgmgr as pkgmgr import pkg.utils.pkgmgr as pkgmgr
pkgmgr.run_pip(["install", "openai", "Pillow", "--upgrade"]) pkgmgr.run_pip(["install", "openai", "Pillow", "--upgrade",
"-i", "https://pypi.douban.com/simple/",
"--trusted-host", "pypi.douban.com"])
known_exception_caught = False known_exception_caught = False
@@ -127,13 +129,26 @@ def main(first_time_init=False):
config = importlib.import_module('config') config = importlib.import_module('config')
import pkg.utils.context
pkg.utils.context.set_config(config)
init_runtime_log_file() init_runtime_log_file()
sh = reset_logging() sh = reset_logging()
# 配置完整性校验
is_integrity = True
config_template = importlib.import_module('config-template')
for key in dir(config_template):
if not key.startswith("__") and not hasattr(config, key):
setattr(config, key, getattr(config_template, key))
logging.warning("[{}]不存在".format(key))
is_integrity = False
if not is_integrity:
logging.warning("配置文件不完整请依据config-template.py检查config.py")
logging.warning("以上配置已被设为默认值将在5秒后继续启动... ")
time.sleep(5)
import pkg.utils.context
pkg.utils.context.set_config(config)
# 检查是否设置了管理员 # 检查是否设置了管理员
if not (hasattr(config, 'admin_qq') and config.admin_qq != 0): if not (hasattr(config, 'admin_qq') and config.admin_qq != 0):
# logging.warning("未设置管理员QQ,管理员权限指令及运行告警将无法使用,如需设置请修改config.py中的admin_qq字段") # logging.warning("未设置管理员QQ,管理员权限指令及运行告警将无法使用,如需设置请修改config.py中的admin_qq字段")
@@ -180,7 +195,7 @@ def main(first_time_init=False):
# 初始化qq机器人 # 初始化qq机器人
qqbot = pkg.qqbot.manager.QQBotManager(mirai_http_api_config=config.mirai_http_api_config, qqbot = pkg.qqbot.manager.QQBotManager(mirai_http_api_config=config.mirai_http_api_config,
timeout=config.process_message_timeout, retry=config.retry_times, timeout=config.process_message_timeout, retry=config.retry_times,
first_time_init=first_time_init) first_time_init=first_time_init, pool_num=config.pool_num)
# 加载插件 # 加载插件
import pkg.plugin.host import pkg.plugin.host
@@ -340,19 +355,9 @@ if __name__ == '__main__':
sys.exit(0) sys.exit(0)
elif len(sys.argv) > 1 and sys.argv[1] == 'update': elif len(sys.argv) > 1 and sys.argv[1] == 'update':
try: print("正在进行程序更新...")
try: import pkg.utils.updater as updater
import pkg.utils.pkgmgr updater.update_all(cli=True)
pkg.utils.pkgmgr.ensure_dulwich()
except:
pass
from dulwich import porcelain
repo = porcelain.open_repo('.')
porcelain.pull(repo)
except ModuleNotFoundError:
print("dulwich模块未安装,请查看 https://github.com/RockChinQ/QChatGPT/issues/77")
sys.exit(0) sys.exit(0)
# import pkg.utils.configmgr # import pkg.utils.configmgr

View File

@@ -0,0 +1,3 @@
"""
审计相关操作
"""

View File

@@ -1,3 +1,7 @@
"""
使用量统计以及数据上报功能实现
"""
import hashlib import hashlib
import json import json
import logging import logging
@@ -10,8 +14,11 @@ import pkg.utils.updater
class DataGatherer: class DataGatherer:
"""数据收集器""" """数据收集器"""
usage = {} usage = {}
"""以key值md5为key,{ """各api-key的使用量
以key值md5为key,{
"text": { "text": {
"text-davinci-003": 文字量:int, "text-davinci-003": 文字量:int,
}, },
@@ -25,11 +32,16 @@ class DataGatherer:
def __init__(self): def __init__(self):
self.load_from_db() self.load_from_db()
try: try:
self.version_str = pkg.utils.updater.get_commit_id_and_time_and_msg()[:40 if len(pkg.utils.updater.get_commit_id_and_time_and_msg()) > 40 else len(pkg.utils.updater.get_commit_id_and_time_and_msg())] self.version_str = pkg.utils.updater.get_current_tag() # 从updater模块获取版本号
except: except:
pass pass
def report_to_server(self, subservice_name: str, count: int): def report_to_server(self, subservice_name: str, count: int):
"""向中央服务器报告使用量
只会报告此次请求的使用量,不会报告总量。
不包含除版本号、使用类型、使用量以外的任何信息,仅供开发者分析使用情况。
"""
try: try:
config = pkg.utils.context.get_config() config = pkg.utils.context.get_config()
if hasattr(config, "report_usage") and not config.report_usage: if hasattr(config, "report_usage") and not config.report_usage:
@@ -44,7 +56,9 @@ class DataGatherer:
return self.usage[key_md5] if key_md5 in self.usage else {} return self.usage[key_md5] if key_md5 in self.usage else {}
def report_text_model_usage(self, model, total_tokens): def report_text_model_usage(self, model, total_tokens):
key_md5 = pkg.utils.context.get_openai_manager().key_mgr.get_using_key_md5() """调用方报告文字模型请求文字使用量"""
key_md5 = pkg.utils.context.get_openai_manager().key_mgr.get_using_key_md5() # 以key的md5进行储存
if key_md5 not in self.usage: if key_md5 not in self.usage:
self.usage[key_md5] = {} self.usage[key_md5] = {}
@@ -62,6 +76,8 @@ class DataGatherer:
self.report_to_server("text", length) self.report_to_server("text", length)
def report_image_model_usage(self, size): def report_image_model_usage(self, size):
"""调用方报告图片模型请求图片使用量"""
key_md5 = pkg.utils.context.get_openai_manager().key_mgr.get_using_key_md5() key_md5 = pkg.utils.context.get_openai_manager().key_mgr.get_using_key_md5()
if key_md5 not in self.usage: if key_md5 not in self.usage:
@@ -79,6 +95,7 @@ class DataGatherer:
self.report_to_server("image", 1) self.report_to_server("image", 1)
def get_text_length_of_key(self, key): def get_text_length_of_key(self, key):
"""获取指定api-key (明文) 的文字总使用量(本地记录)"""
key_md5 = hashlib.md5(key.encode('utf-8')).hexdigest() key_md5 = hashlib.md5(key.encode('utf-8')).hexdigest()
if key_md5 not in self.usage: if key_md5 not in self.usage:
return 0 return 0
@@ -88,6 +105,8 @@ class DataGatherer:
return sum(self.usage[key_md5]["text"].values()) return sum(self.usage[key_md5]["text"].values())
def get_image_count_of_key(self, key): def get_image_count_of_key(self, key):
"""获取指定api-key (明文) 的图片总使用量(本地记录)"""
key_md5 = hashlib.md5(key.encode('utf-8')).hexdigest() key_md5 = hashlib.md5(key.encode('utf-8')).hexdigest()
if key_md5 not in self.usage: if key_md5 not in self.usage:
return 0 return 0
@@ -97,6 +116,7 @@ class DataGatherer:
return sum(self.usage[key_md5]["image"].values()) return sum(self.usage[key_md5]["image"].values())
def get_total_text_length(self): def get_total_text_length(self):
"""获取所有api-key的文字总使用量(本地记录)"""
total = 0 total = 0
for key in self.usage: for key in self.usage:
if "text" not in self.usage[key]: if "text" not in self.usage[key]:

View File

@@ -0,0 +1,3 @@
"""
数据库操作封装
"""

View File

@@ -1,3 +1,6 @@
"""
数据库管理模块
"""
import hashlib import hashlib
import json import json
import logging import logging
@@ -9,9 +12,9 @@ import sqlite3
import pkg.utils.context import pkg.utils.context
# 数据库管理
# 为其他模块提供数据库操作接口
class DatabaseManager: class DatabaseManager:
"""封装数据库底层操作,并提供方法给上层使用"""
conn = None conn = None
cursor = None cursor = None
@@ -23,13 +26,14 @@ class DatabaseManager:
# 连接到数据库文件 # 连接到数据库文件
def reconnect(self): def reconnect(self):
"""连接到数据库"""
self.conn = sqlite3.connect('database.db', check_same_thread=False) self.conn = sqlite3.connect('database.db', check_same_thread=False)
self.cursor = self.conn.cursor() self.cursor = self.conn.cursor()
def close(self): def close(self):
self.conn.close() self.conn.close()
def execute(self, *args, **kwargs) -> Cursor: def __execute__(self, *args, **kwargs) -> Cursor:
# logging.debug('SQL: {}'.format(sql)) # logging.debug('SQL: {}'.format(sql))
c = self.cursor.execute(*args, **kwargs) c = self.cursor.execute(*args, **kwargs)
self.conn.commit() self.conn.commit()
@@ -37,7 +41,9 @@ class DatabaseManager:
# 初始化数据库的函数 # 初始化数据库的函数
def initialize_database(self): def initialize_database(self):
self.execute(""" """创建数据表"""
self.__execute__("""
create table if not exists `sessions` ( create table if not exists `sessions` (
`id` INTEGER PRIMARY KEY AUTOINCREMENT, `id` INTEGER PRIMARY KEY AUTOINCREMENT,
`name` varchar(255) not null, `name` varchar(255) not null,
@@ -50,7 +56,7 @@ class DatabaseManager:
) )
""") """)
self.execute(""" self.__execute__("""
create table if not exists `account_fee`( create table if not exists `account_fee`(
`id` INTEGER PRIMARY KEY AUTOINCREMENT, `id` INTEGER PRIMARY KEY AUTOINCREMENT,
`key_md5` varchar(255) not null, `key_md5` varchar(255) not null,
@@ -59,7 +65,7 @@ class DatabaseManager:
) )
""") """)
self.execute(""" self.__execute__("""
create table if not exists `account_usage`( create table if not exists `account_usage`(
`id` INTEGER PRIMARY KEY AUTOINCREMENT, `id` INTEGER PRIMARY KEY AUTOINCREMENT,
`json` text not null `json` text not null
@@ -70,10 +76,12 @@ class DatabaseManager:
# session持久化 # session持久化
def persistence_session(self, subject_type: str, subject_number: int, create_timestamp: int, def persistence_session(self, subject_type: str, subject_number: int, create_timestamp: int,
last_interact_timestamp: int, prompt: str): last_interact_timestamp: int, prompt: str):
"""持久化指定session"""
# 检查是否已经有了此name和create_timestamp的session # 检查是否已经有了此name和create_timestamp的session
# 如果有就更新prompt和last_interact_timestamp # 如果有就更新prompt和last_interact_timestamp
# 如果没有,就插入一条新的记录 # 如果没有,就插入一条新的记录
self.execute(""" self.__execute__("""
select count(*) from `sessions` where `type` = '{}' and `number` = {} and `create_timestamp` = {} select count(*) from `sessions` where `type` = '{}' and `number` = {} and `create_timestamp` = {}
""".format(subject_type, subject_number, create_timestamp)) """.format(subject_type, subject_number, create_timestamp))
count = self.cursor.fetchone()[0] count = self.cursor.fetchone()[0]
@@ -84,8 +92,8 @@ class DatabaseManager:
values (?, ?, ?, ?, ?, ?) values (?, ?, ?, ?, ?, ?)
""" """
self.execute(sql, self.__execute__(sql,
("{}_{}".format(subject_type, subject_number), subject_type, subject_number, create_timestamp, ("{}_{}".format(subject_type, subject_number), subject_type, subject_number, create_timestamp,
last_interact_timestamp, prompt)) last_interact_timestamp, prompt))
else: else:
sql = """ sql = """
@@ -93,23 +101,23 @@ class DatabaseManager:
where `type` = ? and `number` = ? and `create_timestamp` = ? where `type` = ? and `number` = ? and `create_timestamp` = ?
""" """
self.execute(sql, (last_interact_timestamp, prompt, subject_type, self.__execute__(sql, (last_interact_timestamp, prompt, subject_type,
subject_number, create_timestamp)) subject_number, create_timestamp))
# 显式关闭一个session # 显式关闭一个session
def explicit_close_session(self, session_name: str, create_timestamp: int): def explicit_close_session(self, session_name: str, create_timestamp: int):
self.execute(""" self.__execute__("""
update `sessions` set `status` = 'explicitly_closed' where `name` = '{}' and `create_timestamp` = {} update `sessions` set `status` = 'explicitly_closed' where `name` = '{}' and `create_timestamp` = {}
""".format(session_name, create_timestamp)) """.format(session_name, create_timestamp))
def set_session_ongoing(self, session_name: str, create_timestamp: int): def set_session_ongoing(self, session_name: str, create_timestamp: int):
self.execute(""" self.__execute__("""
update `sessions` set `status` = 'on_going' where `name` = '{}' and `create_timestamp` = {} update `sessions` set `status` = 'on_going' where `name` = '{}' and `create_timestamp` = {}
""".format(session_name, create_timestamp)) """.format(session_name, create_timestamp))
# 设置session为过期 # 设置session为过期
def set_session_expired(self, session_name: str, create_timestamp: int): def set_session_expired(self, session_name: str, create_timestamp: int):
self.execute(""" self.__execute__("""
update `sessions` set `status` = 'expired' where `name` = '{}' and `create_timestamp` = {} update `sessions` set `status` = 'expired' where `name` = '{}' and `create_timestamp` = {}
""".format(session_name, create_timestamp)) """.format(session_name, create_timestamp))
@@ -117,7 +125,7 @@ class DatabaseManager:
def load_valid_sessions(self) -> dict: def load_valid_sessions(self) -> dict:
# 从数据库中加载所有还没过期的session # 从数据库中加载所有还没过期的session
config = pkg.utils.context.get_config() config = pkg.utils.context.get_config()
self.execute(""" self.__execute__("""
select `name`, `type`, `number`, `create_timestamp`, `last_interact_timestamp`, `prompt`, `status` select `name`, `type`, `number`, `create_timestamp`, `last_interact_timestamp`, `prompt`, `status`
from `sessions` where `last_interact_timestamp` > {} from `sessions` where `last_interact_timestamp` > {}
""".format(int(time.time()) - config.session_expire_time)) """.format(int(time.time()) - config.session_expire_time))
@@ -150,7 +158,7 @@ class DatabaseManager:
# 获取此session_name前一个session的数据 # 获取此session_name前一个session的数据
def last_session(self, session_name: str, cursor_timestamp: int): def last_session(self, session_name: str, cursor_timestamp: int):
self.execute(""" self.__execute__("""
select `name`, `type`, `number`, `create_timestamp`, `last_interact_timestamp`, `prompt`, `status` select `name`, `type`, `number`, `create_timestamp`, `last_interact_timestamp`, `prompt`, `status`
from `sessions` where `name` = '{}' and `last_interact_timestamp` < {} order by `last_interact_timestamp` desc from `sessions` where `name` = '{}' and `last_interact_timestamp` < {} order by `last_interact_timestamp` desc
limit 1 limit 1
@@ -179,7 +187,7 @@ class DatabaseManager:
# 获取此session_name后一个session的数据 # 获取此session_name后一个session的数据
def next_session(self, session_name: str, cursor_timestamp: int): def next_session(self, session_name: str, cursor_timestamp: int):
self.execute(""" self.__execute__("""
select `name`, `type`, `number`, `create_timestamp`, `last_interact_timestamp`, `prompt`, `status` select `name`, `type`, `number`, `create_timestamp`, `last_interact_timestamp`, `prompt`, `status`
from `sessions` where `name` = '{}' and `last_interact_timestamp` > {} order by `last_interact_timestamp` asc from `sessions` where `name` = '{}' and `last_interact_timestamp` > {} order by `last_interact_timestamp` asc
limit 1 limit 1
@@ -207,7 +215,7 @@ class DatabaseManager:
# 列出与某个对象的所有对话session # 列出与某个对象的所有对话session
def list_history(self, session_name: str, capacity: int, page: int): def list_history(self, session_name: str, capacity: int, page: int):
self.execute(""" self.__execute__("""
select `name`, `type`, `number`, `create_timestamp`, `last_interact_timestamp`, `prompt`, `status` select `name`, `type`, `number`, `create_timestamp`, `last_interact_timestamp`, `prompt`, `status`
from `sessions` where `name` = '{}' order by `last_interact_timestamp` desc limit {} offset {} from `sessions` where `name` = '{}' order by `last_interact_timestamp` desc limit {} offset {}
""".format(session_name, capacity, capacity * page)) """.format(session_name, capacity, capacity * page))
@@ -246,22 +254,22 @@ class DatabaseManager:
usage_count = usage[key_md5] usage_count = usage[key_md5]
# 将使用量存进数据库 # 将使用量存进数据库
# 先检查是否已存在 # 先检查是否已存在
self.execute(""" self.__execute__("""
select count(*) from `api_key_usage` where `key_md5` = '{}'""".format(key_md5)) select count(*) from `api_key_usage` where `key_md5` = '{}'""".format(key_md5))
result = self.cursor.fetchone() result = self.cursor.fetchone()
if result[0] == 0: if result[0] == 0:
# 不存在则插入 # 不存在则插入
self.execute(""" self.__execute__("""
insert into `api_key_usage` (`key_md5`, `usage`,`timestamp`) values ('{}', {}, {}) insert into `api_key_usage` (`key_md5`, `usage`,`timestamp`) values ('{}', {}, {})
""".format(key_md5, usage_count, int(time.time()))) """.format(key_md5, usage_count, int(time.time())))
else: else:
# 存在则更新timestamp设置为当前 # 存在则更新timestamp设置为当前
self.execute(""" self.__execute__("""
update `api_key_usage` set `usage` = {}, `timestamp` = {} where `key_md5` = '{}' update `api_key_usage` set `usage` = {}, `timestamp` = {} where `key_md5` = '{}'
""".format(usage_count, int(time.time()), key_md5)) """.format(usage_count, int(time.time()), key_md5))
def load_api_key_usage(self): def load_api_key_usage(self):
self.execute(""" self.__execute__("""
select `key_md5`, `usage` from `api_key_usage` select `key_md5`, `usage` from `api_key_usage`
""") """)
results = self.cursor.fetchall() results = self.cursor.fetchall()
@@ -273,23 +281,24 @@ class DatabaseManager:
return usage return usage
def dump_usage_json(self, usage: dict): def dump_usage_json(self, usage: dict):
json_str = json.dumps(usage) json_str = json.dumps(usage)
self.execute(""" self.__execute__("""
select count(*) from `account_usage`""") select count(*) from `account_usage`""")
result = self.cursor.fetchone() result = self.cursor.fetchone()
if result[0] == 0: if result[0] == 0:
# 不存在则插入 # 不存在则插入
self.execute(""" self.__execute__("""
insert into `account_usage` (`json`) values ('{}') insert into `account_usage` (`json`) values ('{}')
""".format(json_str)) """.format(json_str))
else: else:
# 存在则更新 # 存在则更新
self.execute(""" self.__execute__("""
update `account_usage` set `json` = '{}' where `id` = 1 update `account_usage` set `json` = '{}' where `id` = 1
""".format(json_str)) """.format(json_str))
def load_usage_json(self): def load_usage_json(self):
self.execute(""" self.__execute__("""
select `json` from `account_usage` order by id desc limit 1 select `json` from `account_usage` order by id desc limit 1
""") """)
result = self.cursor.fetchone() result = self.cursor.fetchone()

View File

@@ -0,0 +1,2 @@
"""OpenAI 接口处理及会话管理相关
"""

View File

@@ -1,8 +1,13 @@
# 多情景预设值管理 # 多情景预设值管理
__current__ = "default" __current__ = "default"
"""当前默认使用的情景预设的名称
由管理员使用`!default <名称>`指令切换
"""
__prompts_from_files__ = {} __prompts_from_files__ = {}
"""从文件中读取的情景预设值"""
def read_prompt_from_file() -> str: def read_prompt_from_file() -> str:

View File

@@ -5,18 +5,26 @@ import logging
import pkg.plugin.host as plugin_host import pkg.plugin.host as plugin_host
import pkg.plugin.models as plugin_models import pkg.plugin.models as plugin_models
class KeysManager: class KeysManager:
api_key = {} api_key = {}
"""所有api-key"""
# api-key的使用量
# 其中键为api-key的md5值值为使用量
using_key = "" using_key = ""
"""当前使用的api-key
"""
alerted = [] alerted = []
"""已提示过超额的key
记录在此以避免重复提示
"""
# 在此list中的都是经超额报错标记过的api-key
# 记录的是key值仅在运行时有效
exceeded = [] exceeded = []
"""已超额的key
供自动切换功能识别
"""
def get_using_key(self): def get_using_key(self):
return self.using_key return self.using_key
@@ -25,8 +33,6 @@ class KeysManager:
return hashlib.md5(self.using_key.encode('utf-8')).hexdigest() return hashlib.md5(self.using_key.encode('utf-8')).hexdigest()
def __init__(self, api_key): def __init__(self, api_key):
# if hasattr(config, 'api_key_usage_threshold'):
# self.api_key_usage_threshold = config.api_key_usage_threshold
if type(api_key) is dict: if type(api_key) is dict:
self.api_key = api_key self.api_key = api_key
@@ -42,9 +48,13 @@ class KeysManager:
self.auto_switch() self.auto_switch()
# 根据tested自动切换到可用的api-key
# 返回是否切换成功, 切换后的api-key的别名
def auto_switch(self) -> (bool, str): def auto_switch(self) -> (bool, str):
"""尝试切换api-key
Returns:
是否切换成功, 切换后的api-key的别名
"""
for key_name in self.api_key: for key_name in self.api_key:
if self.api_key[key_name] not in self.exceeded: if self.api_key[key_name] not in self.exceeded:
self.using_key = self.api_key[key_name] self.using_key = self.api_key[key_name]
@@ -68,12 +78,9 @@ class KeysManager:
def add(self, key_name, key): def add(self, key_name, key):
self.api_key[key_name] = key self.api_key[key_name] = key
# 设置当前使用的api-key使用量超限
# 这是在尝试调用api时发生超限异常时调用的
def set_current_exceeded(self): def set_current_exceeded(self):
# md5 = hashlib.md5(self.using_key.encode('utf-8')).hexdigest() """设置当前使用的api-key使用量超限
# self.usage[md5] = self.api_key_usage_threshold """
# self.fee[md5] = self.api_key_fee_threshold
self.exceeded.append(self.using_key) self.exceeded.append(self.using_key)
def get_key_name(self, api_key): def get_key_name(self, api_key):

View File

@@ -7,9 +7,12 @@ import pkg.utils.context
import pkg.audit.gatherer import pkg.audit.gatherer
from pkg.openai.modelmgr import ModelRequest, create_openai_model_request from pkg.openai.modelmgr import ModelRequest, create_openai_model_request
# 为其他模块提供与OpenAI交互的接口
class OpenAIInteract: class OpenAIInteract:
api_params = {} """OpenAI 接口封装
将文字接口和图片接口封装供调用方使用
"""
key_mgr: pkg.openai.keymgr.KeysManager = None key_mgr: pkg.openai.keymgr.KeysManager = None
@@ -20,7 +23,6 @@ class OpenAIInteract:
} }
def __init__(self, api_key: str): def __init__(self, api_key: str):
# self.api_key = api_key
self.key_mgr = pkg.openai.keymgr.KeysManager(api_key) self.key_mgr = pkg.openai.keymgr.KeysManager(api_key)
self.audit_mgr = pkg.audit.gatherer.DataGatherer() self.audit_mgr = pkg.audit.gatherer.DataGatherer()
@@ -32,7 +34,16 @@ class OpenAIInteract:
pkg.utils.context.set_openai_manager(self) pkg.utils.context.set_openai_manager(self)
# 请求OpenAI Completion # 请求OpenAI Completion
def request_completion(self, prompts): def request_completion(self, prompts) -> str:
"""请求补全接口回复
Parameters:
prompts (str): 提示语
Returns:
str: 回复
"""
config = pkg.utils.context.get_config() config = pkg.utils.context.get_config()
# 根据模型选择使用的接口 # 根据模型选择使用的接口
@@ -58,8 +69,15 @@ class OpenAIInteract:
return ai.get_message() return ai.get_message()
def request_image(self, prompt): def request_image(self, prompt) -> dict:
"""请求图片接口回复
Parameters:
prompt (str): 提示语
Returns:
dict: 响应
"""
config = pkg.utils.context.get_config() config = pkg.utils.context.get_config()
params = config.image_api_params if hasattr(config, "image_api_params") else self.default_image_api_params params = config.image_api_params if hasattr(config, "image_api_params") else self.default_image_api_params

View File

@@ -1,5 +1,12 @@
# 提供与模型交互的抽象接口 """OpenAI 接口底层封装
目前使用的对话接口有:
ChatCompletion - gpt-3.5-turbo 等模型
Completion - text-davinci-003 等模型
此模块封装此两个接口的请求实现,为上层提供统一的调用方式
"""
import openai, logging, threading, asyncio import openai, logging, threading, asyncio
import openai.error as aiE
COMPLETION_MODELS = { COMPLETION_MODELS = {
'text-davinci-003', 'text-davinci-003',
@@ -25,26 +32,46 @@ IMAGE_MODELS = {
} }
class ModelRequest(): class ModelRequest:
"""GPT父类""" """模型接口请求父类"""
can_chat = False
runtime:threading.Thread = None
ret = ""
proxy:str = None
def __init__(self, model_name, user_name, request_fun, http_proxy:str = None): can_chat = False
runtime: threading.Thread = None
ret = {}
proxy: str = None
request_ready = True
error_info: str = "若在没有任何错误的情况下看到这句话请带着配置文件上报Issues"
def __init__(self, model_name, user_name, request_fun, http_proxy:str = None, time_out = None):
self.model_name = model_name self.model_name = model_name
self.user_name = user_name self.user_name = user_name
self.request_fun = request_fun self.request_fun = request_fun
self.time_out = time_out
if http_proxy != None: if http_proxy != None:
self.proxy = http_proxy self.proxy = http_proxy
openai.proxy = self.proxy openai.proxy = self.proxy
self.request_ready = False
async def __a_request__(self, **kwargs): async def __a_request__(self, **kwargs):
self.ret = await self.request_fun(**kwargs) """异步请求"""
try:
self.ret:dict = await self.request_fun(**kwargs)
self.request_ready = True
except aiE.APIConnectionError as e:
self.error_info = "{}\n请检查网络连接或代理是否正常".format(e)
raise ConnectionError(self.error_info)
except ValueError as e:
self.error_info = "{}\n该错误可能是由于http_proxy格式设置错误引起的"
except Exception as e:
self.error_info = "{}\n由于请求异常产生的未知错误,请查看日志".format(e)
raise Exception(self.error_info)
def request(self, **kwargs): def request(self, **kwargs):
"""向接口发起请求"""
if self.proxy != None: #异步请求 if self.proxy != None: #异步请求
self.request_ready = False
loop = asyncio.new_event_loop() loop = asyncio.new_event_loop()
self.runtime = threading.Thread( self.runtime = threading.Thread(
target=loop.run_until_complete, target=loop.run_until_complete,
@@ -64,13 +91,15 @@ class ModelRequest():
若重写该方法应检查异步线程状态或在需要检查处super该方法 若重写该方法应检查异步线程状态或在需要检查处super该方法
''' '''
if self.runtime != None and isinstance(self.runtime, threading.Thread): if self.runtime != None and isinstance(self.runtime, threading.Thread):
self.runtime.join() self.runtime.join(self.time_out)
return if self.request_ready:
return
raise Exception(self.error_info)
def get_total_tokens(self): def get_total_tokens(self):
try: try:
return self.ret['usage']['total_tokens'] return self.ret['usage']['total_tokens']
except Exception: except:
return 0 return 0
def get_message(self): def get_message(self):
@@ -79,8 +108,10 @@ class ModelRequest():
def get_response(self): def get_response(self):
return self.ret return self.ret
class ChatCompletionModel(ModelRequest): class ChatCompletionModel(ModelRequest):
"""ChatCompletion类模型""" """ChatCompletion接口的请求实现"""
Chat_role = ['system', 'user', 'assistant'] Chat_role = ['system', 'user', 'assistant']
def __init__(self, model_name, user_name, http_proxy:str = None, **kwargs): def __init__(self, model_name, user_name, http_proxy:str = None, **kwargs):
if http_proxy == None: if http_proxy == None:
@@ -108,7 +139,8 @@ class ChatCompletionModel(ModelRequest):
class CompletionModel(ModelRequest): class CompletionModel(ModelRequest):
"""Completion类模型""" """Completion接口的请求实现"""
def __init__(self, model_name, user_name, http_proxy:str = None, **kwargs): def __init__(self, model_name, user_name, http_proxy:str = None, **kwargs):
if http_proxy == None: if http_proxy == None:
request_fun = openai.Completion.create request_fun = openai.Completion.create

View File

@@ -1,3 +1,8 @@
"""主线使用的会话管理模块
每个人、每个群单独一个sessionsession内部保留了对话的上下文
"""
import logging import logging
import threading import threading
import time import time
@@ -19,6 +24,7 @@ class SessionOfflineStatus:
ON_GOING = 'on_going' ON_GOING = 'on_going'
EXPLICITLY_CLOSED = 'explicitly_closed' EXPLICITLY_CLOSED = 'explicitly_closed'
# 重置session.prompt # 重置session.prompt
def reset_session_prompt(session_name, prompt): def reset_session_prompt(session_name, prompt):
# 备份原始数据 # 备份原始数据
@@ -43,11 +49,14 @@ def reset_session_prompt(session_name, prompt):
用户[{}]的数据已被重置,有可能是因为数据版本过旧或存储错误 用户[{}]的数据已被重置,有可能是因为数据版本过旧或存储错误
原始数据将备份在: 原始数据将备份在:
{}""".format(session_name, bak_path) {}""".format(session_name, bak_path)
) ) # 为保证多行文本格式正确故无缩进
return prompt return prompt
# 从数据加载session # 从数据加载session
def load_sessions(): def load_sessions():
"""从数据库加载sessions"""
global sessions global sessions
db_inst = pkg.utils.context.get_database_manager() db_inst = pkg.utils.context.get_database_manager()
@@ -93,10 +102,13 @@ class Session:
name = '' name = ''
prompt = [] prompt = []
"""使用list来保存会话中的回合"""
create_timestamp = 0 create_timestamp = 0
"""会话创建时间"""
last_interact_timestamp = 0 last_interact_timestamp = 0
"""上次交互(产生回复)时间"""
just_switched_to_exist_session = False just_switched_to_exist_session = False
@@ -116,7 +128,7 @@ class Session:
logging.debug('{},lock release successfully,{}'.format(self.name, self.response_lock)) logging.debug('{},lock release successfully,{}'.format(self.name, self.response_lock))
# 从配置文件获取会话预设信息 # 从配置文件获取会话预设信息
def get_default_prompt(self, use_default: str=None): def get_default_prompt(self, use_default: str = None):
config = pkg.utils.context.get_config() config = pkg.utils.context.get_config()
import pkg.openai.dprompt as dprompt import pkg.openai.dprompt as dprompt
@@ -130,7 +142,7 @@ class Session:
{ {
'role': 'user', 'role': 'user',
'content': current_default_prompt 'content': current_default_prompt
},{ }, {
'role': 'assistant', 'role': 'assistant',
'content': 'ok' 'content': 'ok'
} }
@@ -182,6 +194,8 @@ class Session:
# 请求回复 # 请求回复
# 这个函数是阻塞的 # 这个函数是阻塞的
def append(self, text: str) -> str: def append(self, text: str) -> str:
"""向session中添加一条消息返回接口回复"""
self.last_interact_timestamp = int(time.time()) self.last_interact_timestamp = int(time.time())
# 触发插件事件 # 触发插件事件
@@ -215,14 +229,14 @@ class Session:
res_ans = '\n\n'.join(res_ans_spt) res_ans = '\n\n'.join(res_ans_spt)
# 将此次对话的双方内容加入到prompt中 # 将此次对话的双方内容加入到prompt中
self.prompt.append({'role':'user', 'content':text}) self.prompt.append({'role': 'user', 'content': text})
self.prompt.append({'role':'assistant', 'content':res_ans}) self.prompt.append({'role': 'assistant', 'content': res_ans})
if self.just_switched_to_exist_session: if self.just_switched_to_exist_session:
self.just_switched_to_exist_session = False self.just_switched_to_exist_session = False
self.set_ongoing() self.set_ongoing()
return res_ans if res_ans[0]!='\n' else res_ans[1:] return res_ans if res_ans[0] != '\n' else res_ans[1:]
# 删除上一回合并返回上一回合的问题 # 删除上一回合并返回上一回合的问题
def undo(self) -> str: def undo(self) -> str:
@@ -231,10 +245,10 @@ class Session:
# 删除最后两个消息 # 删除最后两个消息
if len(self.prompt) < 2: if len(self.prompt) < 2:
raise Exception('之前无对话,无法撤销') raise Exception('之前无对话,无法撤销')
question = self.prompt[-2]['content'] question = self.prompt[-2]['content']
self.prompt = self.prompt[:-2] self.prompt = self.prompt[:-2]
# 返回上一回合的问题 # 返回上一回合的问题
return question return question
@@ -242,13 +256,13 @@ class Session:
def cut_out(self, msg: str, max_tokens: int) -> list: def cut_out(self, msg: str, max_tokens: int) -> list:
"""将现有prompt进行切割处理使得新的prompt长度不超过max_tokens""" """将现有prompt进行切割处理使得新的prompt长度不超过max_tokens"""
# 如果用户消息长度超过max_tokens直接返回 # 如果用户消息长度超过max_tokens直接返回
temp_prompt = [ temp_prompt = [
{ {
'role': 'user', 'role': 'user',
'content': msg 'content': msg
} }
] ]
token_count = len(msg) token_count = len(msg)
# 倒序遍历prompt # 倒序遍历prompt

View File

@@ -0,0 +1,4 @@
"""插件支持包
包含插件基类、插件宿主以及部分API接口
"""

View File

@@ -116,7 +116,9 @@ def initialize_plugins():
def unload_plugins(): def unload_plugins():
""" 卸载插件 """ """ 卸载插件
"""
# 不再显式卸载插件,因为当程序结束时,插件的析构函数会被系统执行
# for plugin in __plugins__.values(): # for plugin in __plugins__.values():
# if plugin['enabled'] and plugin['instance'] is not None: # if plugin['enabled'] and plugin['instance'] is not None:
# if not hasattr(plugin['instance'], '__del__'): # if not hasattr(plugin['instance'], '__del__'):

View File

@@ -145,6 +145,7 @@ __current_registering_plugin__ = ""
class Plugin: class Plugin:
"""插件基类"""
host: host.PluginHost host: host.PluginHost
"""插件宿主,提供插件的一些基础功能""" """插件宿主,提供插件的一些基础功能"""

View File

@@ -2,6 +2,7 @@ import asyncio
import json import json
import os import os
import threading import threading
from concurrent.futures import ThreadPoolExecutor
import mirai.models.bus import mirai.models.bus
from mirai import At, GroupMessage, MessageEvent, Mirai, StrangerMessage, WebSocketAdapter, HTTPAdapter, \ from mirai import At, GroupMessage, MessageEvent, Mirai, StrangerMessage, WebSocketAdapter, HTTPAdapter, \
@@ -20,13 +21,6 @@ import pkg.utils.context
import pkg.plugin.host as plugin_host import pkg.plugin.host as plugin_host
import pkg.plugin.models as plugin_models import pkg.plugin.models as plugin_models
# 并行运行
def go(func, args=()):
thread = threading.Thread(target=func, args=args, daemon=True)
thread.start()
# 检查消息是否符合泛响应匹配机制 # 检查消息是否符合泛响应匹配机制
def check_response_rule(text: str): def check_response_rule(text: str):
config = pkg.utils.context.get_config() config = pkg.utils.context.get_config()
@@ -55,6 +49,9 @@ def check_response_rule(text: str):
class QQBotManager: class QQBotManager:
retry = 3 retry = 3
#线程池控制
pool = None
bot: Mirai = None bot: Mirai = None
reply_filter = None reply_filter = None
@@ -64,11 +61,14 @@ class QQBotManager:
ban_person = [] ban_person = []
ban_group = [] ban_group = []
def __init__(self, mirai_http_api_config: dict, timeout: int = 60, retry: int = 3, first_time_init=True): def __init__(self, mirai_http_api_config: dict, timeout: int = 60, retry: int = 3, pool_num: int = 10, first_time_init=True):
self.timeout = timeout self.timeout = timeout
self.retry = retry self.retry = retry
self.pool_num = pool_num
self.pool = ThreadPoolExecutor(max_workers=self.pool_num)
logging.debug("Registered thread pool Size:{}".format(pool_num))
# 加载禁用列表 # 加载禁用列表
if os.path.exists("banlist.py"): if os.path.exists("banlist.py"):
import banlist import banlist
@@ -116,7 +116,7 @@ class QQBotManager:
self.on_person_message(event) self.on_person_message(event)
go(friend_message_handler, (event,)) self.go(friend_message_handler, event)
@self.bot.on(StrangerMessage) @self.bot.on(StrangerMessage)
async def on_stranger_message(event: StrangerMessage): async def on_stranger_message(event: StrangerMessage):
@@ -136,7 +136,7 @@ class QQBotManager:
self.on_person_message(event) self.on_person_message(event)
go(stranger_message_handler, (event,)) self.go(stranger_message_handler, event)
@self.bot.on(GroupMessage) @self.bot.on(GroupMessage)
async def on_group_message(event: GroupMessage): async def on_group_message(event: GroupMessage):
@@ -156,7 +156,7 @@ class QQBotManager:
self.on_group_message(event) self.on_group_message(event)
go(group_message_handler, (event,)) self.go(group_message_handler, event)
def unsubscribe_all(): def unsubscribe_all():
"""取消所有订阅 """取消所有订阅
@@ -173,6 +173,9 @@ class QQBotManager:
self.unsubscribe_all = unsubscribe_all self.unsubscribe_all = unsubscribe_all
def go(self, func, *args, **kwargs):
self.pool.submit(func, *args, **kwargs)
def first_time_init(self, mirai_http_api_config: dict): def first_time_init(self, mirai_http_api_config: dict):
"""热重载后不再运行此函数""" """热重载后不再运行此函数"""

File diff suppressed because one or more lines are too long

View File

@@ -7,13 +7,15 @@ import pkg.utils.context
import pkg.plugin.host import pkg.plugin.host
def walk(module, prefix=''): def walk(module, prefix='', path_prefix=''):
"""遍历并重载所有模块""" """遍历并重载所有模块"""
for item in pkgutil.iter_modules(module.__path__): for item in pkgutil.iter_modules(module.__path__):
if item.ispkg: if item.ispkg:
walk(__import__(module.__name__ + '.' + item.name, fromlist=['']), prefix + item.name + '.')
walk(__import__(module.__name__ + '.' + item.name, fromlist=['']), prefix + item.name + '.', path_prefix + item.name + '/')
else: else:
logging.info('reload module: {}'.format(prefix + item.name)) logging.info('reload module: {}, path: {}'.format(prefix + item.name, path_prefix + item.name + '.py'))
pkg.plugin.host.__current_module_path__ = "plugins/" + path_prefix + item.name + '.py'
importlib.reload(__import__(module.__name__ + '.' + item.name, fromlist=[''])) importlib.reload(__import__(module.__name__ + '.' + item.name, fromlist=['']))

View File

@@ -1,8 +1,37 @@
import logging
from PIL import Image, ImageDraw, ImageFont from PIL import Image, ImageDraw, ImageFont
import re import re
import os import os
import config
import traceback
text_render_font = ImageFont.truetype("res/simhei.ttf", 32, encoding="utf-8") text_render_font: ImageFont = None
if hasattr(config, "blob_message_strategy") and config.blob_message_strategy == "image": # 仅在启用了image时才加载字体
use_font = config.font_path if hasattr(config, "font_path") else ""
try:
# 检查是否存在
if not os.path.exists(use_font):
# 若是windows系统使用微软雅黑
if os.name == "nt":
use_font = "C:/Windows/Fonts/msyh.ttc"
if not os.path.exists(use_font):
logging.warn("未找到字体文件且无法使用Windows自带字体更换为转发消息组件以发送长消息您可以在config.py中调整相关设置。")
config.blob_message_strategy = "forward"
else:
logging.info("使用Windows自带字体" + use_font)
text_render_font = ImageFont.truetype(use_font, 32, encoding="utf-8")
else:
logging.warn("未找到字体文件且无法使用Windows自带字体更换为转发消息组件以发送长消息您可以在config.py中调整相关设置。")
config.blob_message_strategy = "forward"
else:
text_render_font = ImageFont.truetype(use_font, 32, encoding="utf-8")
except:
traceback.print_exc()
logging.error("加载字体文件失败({})更换为转发消息组件以发送长消息您可以在config.py中调整相关设置。".format(use_font))
config.blob_message_strategy = "forward"
def indexNumber(path=''): def indexNumber(path=''):
@@ -123,7 +152,7 @@ def text_to_image(text_str: str, save_as="temp.png", width=800):
else: else:
continue continue
# 准备画布 # 准备画布
img = Image.new('RGBA', (width, max(280, len(final_lines) * 35 + 45)), (255, 255, 255, 255)) img = Image.new('RGBA', (width, max(280, len(final_lines) * 35 + 65)), (255, 255, 255, 255))
draw = ImageDraw.Draw(img, mode='RGBA') draw = ImageDraw.Draw(img, mode='RGBA')

View File

@@ -5,7 +5,7 @@ import os.path
import requests import requests
import json import json
import pkg.utils.context import pkg.utils.constants
def check_dulwich_closure(): def check_dulwich_closure():
@@ -46,7 +46,7 @@ def get_release_list() -> list:
def get_current_tag() -> str: def get_current_tag() -> str:
"""获取当前tag""" """获取当前tag"""
current_tag = "v0.1.0" current_tag = pkg.utils.constants.semantic_version
if os.path.exists("current_tag"): if os.path.exists("current_tag"):
with open("current_tag", "r") as f: with open("current_tag", "r") as f:
current_tag = f.read() current_tag = f.read()
@@ -54,7 +54,7 @@ def get_current_tag() -> str:
return current_tag return current_tag
def update_all() -> bool: def update_all(cli: bool = False) -> bool:
"""检查更新并下载源码""" """检查更新并下载源码"""
current_tag = get_current_tag() current_tag = get_current_tag()
@@ -69,12 +69,19 @@ def update_all() -> bool:
if latest_rls == {}: if latest_rls == {}:
latest_rls = rls latest_rls = rls
logging.info("更新日志: {}".format(rls_notes)) if not cli:
logging.info("更新日志: {}".format(rls_notes))
else:
print("更新日志: {}".format(rls_notes))
if latest_rls == {}: # 没有新版本 if latest_rls == {}: # 没有新版本
return False return False
# 下载最新版本的zip到temp目录 # 下载最新版本的zip到temp目录
logging.info("开始下载最新版本: {}".format(latest_rls['zipball_url'])) if not cli:
logging.info("开始下载最新版本: {}".format(latest_rls['zipball_url']))
else:
print("开始下载最新版本: {}".format(latest_rls['zipball_url']))
zip_url = latest_rls['zipball_url'] zip_url = latest_rls['zipball_url']
zip_resp = requests.get(url=zip_url) zip_resp = requests.get(url=zip_url)
zip_data = zip_resp.content zip_data = zip_resp.content
@@ -87,7 +94,10 @@ def update_all() -> bool:
with open("temp/updater/{}.zip".format(latest_rls['tag_name']), "wb") as f: with open("temp/updater/{}.zip".format(latest_rls['tag_name']), "wb") as f:
f.write(zip_data) f.write(zip_data)
logging.info("下载最新版本完成: {}".format("temp/updater/{}.zip".format(latest_rls['tag_name']))) if not cli:
logging.info("下载最新版本完成: {}".format("temp/updater/{}.zip".format(latest_rls['tag_name'])))
else:
print("下载最新版本完成: {}".format("temp/updater/{}.zip".format(latest_rls['tag_name'])))
# 解压zip到temp/updater/<tag_name>/ # 解压zip到temp/updater/<tag_name>/
import zipfile import zipfile
@@ -124,8 +134,11 @@ def update_all() -> bool:
f.write(current_tag) f.write(current_tag)
# 通知管理员 # 通知管理员
import pkg.utils.context if not cli:
pkg.utils.context.get_qqbot_manager().notify_admin("已更新到最新版本: {}\n更新日志:\n{}\n新功能通常可以在config-template.py中看到完整的更新日志请前往 https://github.com/RockChinQ/QChatGPT/releases 查看".format(current_tag, "\n".join(rls_notes))) import pkg.utils.context
pkg.utils.context.get_qqbot_manager().notify_admin("已更新到最新版本: {}\n更新日志:\n{}\n新功能通常可以在config-template.py中看到完整的更新日志请前往 https://github.com/RockChinQ/QChatGPT/releases 查看".format(current_tag, "\n".join(rls_notes)))
else:
print("已更新到最新版本: {}\n更新日志:\n{}\n新功能通常可以在config-template.py中看到完整的更新日志请前往 https://github.com/RockChinQ/QChatGPT/releases 查看".format(current_tag, "\n".join(rls_notes)))
return True return True
@@ -152,24 +165,12 @@ def get_remote_url(repo_path: str) -> str:
def get_current_version_info() -> str: def get_current_version_info() -> str:
"""获取当前版本信息""" """获取当前版本信息"""
check_dulwich_closure() rls_list = get_release_list()
current_tag = get_current_tag()
from dulwich import porcelain for rls in rls_list:
if rls['tag_name'] == current_tag:
repo = porcelain.open_repo('.') return rls['name'] + "\n" + rls['body']
return "未知版本"
version_str = ""
for entry in repo.get_walker():
version_str += "提交编号: "+str(entry.commit.id)[2:9] + "\n"
tz = datetime.timezone(datetime.timedelta(hours=entry.commit.commit_timezone // 3600))
dt = datetime.datetime.fromtimestamp(entry.commit.commit_time, tz)
version_str += "时间: "+dt.strftime('%m-%d %H:%M:%S') + "\n"
version_str += "说明: "+str(entry.commit.message, encoding="utf-8").strip() + "\n"
version_str += "提交作者: '" + str(entry.commit.author)[2:-1] + "'"
break
return version_str
def get_commit_id_and_time_and_msg() -> str: def get_commit_id_and_time_and_msg() -> str:

Binary file not shown.