mirror of
https://github.com/langbot-app/LangBot.git
synced 2026-06-02 03:55:55 +00:00
Compare commits
31 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
a1bfbad24e | ||
|
|
8af4918048 | ||
|
|
49f4ab0ec8 | ||
|
|
85c623fb0f | ||
|
|
9e28298250 | ||
|
|
7a04ef0985 | ||
|
|
83005e9ba9 | ||
|
|
f0c78f0529 | ||
|
|
3f638adcf9 | ||
|
|
d9405d8d5d | ||
|
|
606713a418 | ||
|
|
52102f0d0a | ||
|
|
61c29829ed | ||
|
|
df30931aad | ||
|
|
5afcc03e8b | ||
|
|
fbeb4673f4 | ||
|
|
4aba319560 | ||
|
|
74f79e002c | ||
|
|
2668ef2b3f | ||
|
|
74c018e271 | ||
|
|
64776fd601 | ||
|
|
59877bf71d | ||
|
|
d2800ac58b | ||
|
|
ffef944119 | ||
|
|
651b291ef6 | ||
|
|
e4b581f197 | ||
|
|
4f3939e2d9 | ||
|
|
1048ca612d | ||
|
|
b1a2d21ee9 | ||
|
|
dd4e8bdc8b | ||
|
|
e28c9bae0c |
@@ -115,10 +115,9 @@
|
|||||||
|
|
||||||
### - 注册OpenAI账号
|
### - 注册OpenAI账号
|
||||||
|
|
||||||
**可以直接进群找群主购买**
|
参考以下文章自行注册
|
||||||
或参考以下文章自行注册
|
|
||||||
|
|
||||||
> ~~[只需 1 元搞定 ChatGPT 注册](https://zhuanlan.zhihu.com/p/589470082)~~(已失效)
|
> [国内注册ChatGPT的方法(100%可用)](https://www.pythonthree.com/register-openai-chatgpt/)
|
||||||
> [手把手教你如何注册ChatGPT,超级详细](https://guxiaobei.com/51461)
|
> [手把手教你如何注册ChatGPT,超级详细](https://guxiaobei.com/51461)
|
||||||
|
|
||||||
注册成功后请前往[个人中心查看](https://beta.openai.com/account/api-keys)api_key
|
注册成功后请前往[个人中心查看](https://beta.openai.com/account/api-keys)api_key
|
||||||
@@ -227,6 +226,7 @@ python3 main.py
|
|||||||
- [@dominoar](https://github.com/dominoar) 为本项目开发多种插件
|
- [@dominoar](https://github.com/dominoar) 为本项目开发多种插件
|
||||||
- [@hissincn](https://github.com/hissincn) 本项目贡献者
|
- [@hissincn](https://github.com/hissincn) 本项目贡献者
|
||||||
- [@LINSTCL](https://github.com/LINSTCL) GPT-3.5官方模型适配贡献者
|
- [@LINSTCL](https://github.com/LINSTCL) GPT-3.5官方模型适配贡献者
|
||||||
|
- [@Haibersut](https://github.com/Haibersut) 本项目贡献者
|
||||||
|
|
||||||
以及其他所有为本项目提供支持的朋友们。
|
以及其他所有为本项目提供支持的朋友们。
|
||||||
|
|
||||||
|
|||||||
@@ -183,6 +183,12 @@ blob_message_threshold = 256
|
|||||||
# - "forward": 将长消息转换为转发消息组件发送
|
# - "forward": 将长消息转换为转发消息组件发送
|
||||||
blob_message_strategy = "forward"
|
blob_message_strategy = "forward"
|
||||||
|
|
||||||
|
# 文字转图片时使用的字体文件路径
|
||||||
|
# 当策略为"image"时生效
|
||||||
|
# 若在Windows系统下,程序会自动使用Windows自带的微软雅黑字体
|
||||||
|
# 若未填写或不存在且不是Windows,将禁用文字转图片功能,改为使用转发消息组件
|
||||||
|
font_path = ""
|
||||||
|
|
||||||
# 消息处理超时重试次数
|
# 消息处理超时重试次数
|
||||||
retry_times = 3
|
retry_times = 3
|
||||||
|
|
||||||
@@ -196,6 +202,11 @@ hide_exce_info_to_user = False
|
|||||||
# 设置为空字符串时,不发送提示信息
|
# 设置为空字符串时,不发送提示信息
|
||||||
alter_tip_message = '出错了,请稍后再试'
|
alter_tip_message = '出错了,请稍后再试'
|
||||||
|
|
||||||
|
# 机器人线程池大小
|
||||||
|
# 该参数决定机器人可以同时处理几个人的消息,超出线程池数量的请求会被阻塞,不会被丢弃
|
||||||
|
# 如果你不清楚该参数的意义,请不要更改
|
||||||
|
pool_num = 10
|
||||||
|
|
||||||
# 每个会话的过期时间,单位为秒
|
# 每个会话的过期时间,单位为秒
|
||||||
# 默认值20分钟
|
# 默认值20分钟
|
||||||
session_expire_time = 60 * 20
|
session_expire_time = 60 * 20
|
||||||
|
|||||||
41
main.py
41
main.py
@@ -45,7 +45,9 @@ def init_db():
|
|||||||
|
|
||||||
def ensure_dependencies():
|
def ensure_dependencies():
|
||||||
import pkg.utils.pkgmgr as pkgmgr
|
import pkg.utils.pkgmgr as pkgmgr
|
||||||
pkgmgr.run_pip(["install", "openai", "Pillow", "--upgrade"])
|
pkgmgr.run_pip(["install", "openai", "Pillow", "--upgrade",
|
||||||
|
"-i", "https://pypi.douban.com/simple/",
|
||||||
|
"--trusted-host", "pypi.douban.com"])
|
||||||
|
|
||||||
|
|
||||||
known_exception_caught = False
|
known_exception_caught = False
|
||||||
@@ -127,13 +129,26 @@ def main(first_time_init=False):
|
|||||||
|
|
||||||
config = importlib.import_module('config')
|
config = importlib.import_module('config')
|
||||||
|
|
||||||
import pkg.utils.context
|
|
||||||
pkg.utils.context.set_config(config)
|
|
||||||
|
|
||||||
init_runtime_log_file()
|
init_runtime_log_file()
|
||||||
|
|
||||||
sh = reset_logging()
|
sh = reset_logging()
|
||||||
|
|
||||||
|
# 配置完整性校验
|
||||||
|
is_integrity = True
|
||||||
|
config_template = importlib.import_module('config-template')
|
||||||
|
for key in dir(config_template):
|
||||||
|
if not key.startswith("__") and not hasattr(config, key):
|
||||||
|
setattr(config, key, getattr(config_template, key))
|
||||||
|
logging.warning("[{}]不存在".format(key))
|
||||||
|
is_integrity = False
|
||||||
|
if not is_integrity:
|
||||||
|
logging.warning("配置文件不完整,请依据config-template.py检查config.py")
|
||||||
|
logging.warning("以上配置已被设为默认值,将在5秒后继续启动... ")
|
||||||
|
time.sleep(5)
|
||||||
|
|
||||||
|
import pkg.utils.context
|
||||||
|
pkg.utils.context.set_config(config)
|
||||||
|
|
||||||
# 检查是否设置了管理员
|
# 检查是否设置了管理员
|
||||||
if not (hasattr(config, 'admin_qq') and config.admin_qq != 0):
|
if not (hasattr(config, 'admin_qq') and config.admin_qq != 0):
|
||||||
# logging.warning("未设置管理员QQ,管理员权限指令及运行告警将无法使用,如需设置请修改config.py中的admin_qq字段")
|
# logging.warning("未设置管理员QQ,管理员权限指令及运行告警将无法使用,如需设置请修改config.py中的admin_qq字段")
|
||||||
@@ -180,7 +195,7 @@ def main(first_time_init=False):
|
|||||||
# 初始化qq机器人
|
# 初始化qq机器人
|
||||||
qqbot = pkg.qqbot.manager.QQBotManager(mirai_http_api_config=config.mirai_http_api_config,
|
qqbot = pkg.qqbot.manager.QQBotManager(mirai_http_api_config=config.mirai_http_api_config,
|
||||||
timeout=config.process_message_timeout, retry=config.retry_times,
|
timeout=config.process_message_timeout, retry=config.retry_times,
|
||||||
first_time_init=first_time_init)
|
first_time_init=first_time_init, pool_num=config.pool_num)
|
||||||
|
|
||||||
# 加载插件
|
# 加载插件
|
||||||
import pkg.plugin.host
|
import pkg.plugin.host
|
||||||
@@ -340,19 +355,9 @@ if __name__ == '__main__':
|
|||||||
sys.exit(0)
|
sys.exit(0)
|
||||||
|
|
||||||
elif len(sys.argv) > 1 and sys.argv[1] == 'update':
|
elif len(sys.argv) > 1 and sys.argv[1] == 'update':
|
||||||
try:
|
print("正在进行程序更新...")
|
||||||
try:
|
import pkg.utils.updater as updater
|
||||||
import pkg.utils.pkgmgr
|
updater.update_all(cli=True)
|
||||||
pkg.utils.pkgmgr.ensure_dulwich()
|
|
||||||
except:
|
|
||||||
pass
|
|
||||||
|
|
||||||
from dulwich import porcelain
|
|
||||||
|
|
||||||
repo = porcelain.open_repo('.')
|
|
||||||
porcelain.pull(repo)
|
|
||||||
except ModuleNotFoundError:
|
|
||||||
print("dulwich模块未安装,请查看 https://github.com/RockChinQ/QChatGPT/issues/77")
|
|
||||||
sys.exit(0)
|
sys.exit(0)
|
||||||
|
|
||||||
# import pkg.utils.configmgr
|
# import pkg.utils.configmgr
|
||||||
|
|||||||
@@ -0,0 +1,3 @@
|
|||||||
|
"""
|
||||||
|
审计相关操作
|
||||||
|
"""
|
||||||
@@ -1,3 +1,7 @@
|
|||||||
|
"""
|
||||||
|
使用量统计以及数据上报功能实现
|
||||||
|
"""
|
||||||
|
|
||||||
import hashlib
|
import hashlib
|
||||||
import json
|
import json
|
||||||
import logging
|
import logging
|
||||||
@@ -10,8 +14,11 @@ import pkg.utils.updater
|
|||||||
|
|
||||||
class DataGatherer:
|
class DataGatherer:
|
||||||
"""数据收集器"""
|
"""数据收集器"""
|
||||||
|
|
||||||
usage = {}
|
usage = {}
|
||||||
"""以key值md5为key,{
|
"""各api-key的使用量
|
||||||
|
|
||||||
|
以key值md5为key,{
|
||||||
"text": {
|
"text": {
|
||||||
"text-davinci-003": 文字量:int,
|
"text-davinci-003": 文字量:int,
|
||||||
},
|
},
|
||||||
@@ -25,11 +32,16 @@ class DataGatherer:
|
|||||||
def __init__(self):
|
def __init__(self):
|
||||||
self.load_from_db()
|
self.load_from_db()
|
||||||
try:
|
try:
|
||||||
self.version_str = pkg.utils.updater.get_commit_id_and_time_and_msg()[:40 if len(pkg.utils.updater.get_commit_id_and_time_and_msg()) > 40 else len(pkg.utils.updater.get_commit_id_and_time_and_msg())]
|
self.version_str = pkg.utils.updater.get_current_tag() # 从updater模块获取版本号
|
||||||
except:
|
except:
|
||||||
pass
|
pass
|
||||||
|
|
||||||
def report_to_server(self, subservice_name: str, count: int):
|
def report_to_server(self, subservice_name: str, count: int):
|
||||||
|
"""向中央服务器报告使用量
|
||||||
|
|
||||||
|
只会报告此次请求的使用量,不会报告总量。
|
||||||
|
不包含除版本号、使用类型、使用量以外的任何信息,仅供开发者分析使用情况。
|
||||||
|
"""
|
||||||
try:
|
try:
|
||||||
config = pkg.utils.context.get_config()
|
config = pkg.utils.context.get_config()
|
||||||
if hasattr(config, "report_usage") and not config.report_usage:
|
if hasattr(config, "report_usage") and not config.report_usage:
|
||||||
@@ -44,7 +56,9 @@ class DataGatherer:
|
|||||||
return self.usage[key_md5] if key_md5 in self.usage else {}
|
return self.usage[key_md5] if key_md5 in self.usage else {}
|
||||||
|
|
||||||
def report_text_model_usage(self, model, total_tokens):
|
def report_text_model_usage(self, model, total_tokens):
|
||||||
key_md5 = pkg.utils.context.get_openai_manager().key_mgr.get_using_key_md5()
|
"""调用方报告文字模型请求文字使用量"""
|
||||||
|
|
||||||
|
key_md5 = pkg.utils.context.get_openai_manager().key_mgr.get_using_key_md5() # 以key的md5进行储存
|
||||||
|
|
||||||
if key_md5 not in self.usage:
|
if key_md5 not in self.usage:
|
||||||
self.usage[key_md5] = {}
|
self.usage[key_md5] = {}
|
||||||
@@ -62,6 +76,8 @@ class DataGatherer:
|
|||||||
self.report_to_server("text", length)
|
self.report_to_server("text", length)
|
||||||
|
|
||||||
def report_image_model_usage(self, size):
|
def report_image_model_usage(self, size):
|
||||||
|
"""调用方报告图片模型请求图片使用量"""
|
||||||
|
|
||||||
key_md5 = pkg.utils.context.get_openai_manager().key_mgr.get_using_key_md5()
|
key_md5 = pkg.utils.context.get_openai_manager().key_mgr.get_using_key_md5()
|
||||||
|
|
||||||
if key_md5 not in self.usage:
|
if key_md5 not in self.usage:
|
||||||
@@ -79,6 +95,7 @@ class DataGatherer:
|
|||||||
self.report_to_server("image", 1)
|
self.report_to_server("image", 1)
|
||||||
|
|
||||||
def get_text_length_of_key(self, key):
|
def get_text_length_of_key(self, key):
|
||||||
|
"""获取指定api-key (明文) 的文字总使用量(本地记录)"""
|
||||||
key_md5 = hashlib.md5(key.encode('utf-8')).hexdigest()
|
key_md5 = hashlib.md5(key.encode('utf-8')).hexdigest()
|
||||||
if key_md5 not in self.usage:
|
if key_md5 not in self.usage:
|
||||||
return 0
|
return 0
|
||||||
@@ -88,6 +105,8 @@ class DataGatherer:
|
|||||||
return sum(self.usage[key_md5]["text"].values())
|
return sum(self.usage[key_md5]["text"].values())
|
||||||
|
|
||||||
def get_image_count_of_key(self, key):
|
def get_image_count_of_key(self, key):
|
||||||
|
"""获取指定api-key (明文) 的图片总使用量(本地记录)"""
|
||||||
|
|
||||||
key_md5 = hashlib.md5(key.encode('utf-8')).hexdigest()
|
key_md5 = hashlib.md5(key.encode('utf-8')).hexdigest()
|
||||||
if key_md5 not in self.usage:
|
if key_md5 not in self.usage:
|
||||||
return 0
|
return 0
|
||||||
@@ -97,6 +116,7 @@ class DataGatherer:
|
|||||||
return sum(self.usage[key_md5]["image"].values())
|
return sum(self.usage[key_md5]["image"].values())
|
||||||
|
|
||||||
def get_total_text_length(self):
|
def get_total_text_length(self):
|
||||||
|
"""获取所有api-key的文字总使用量(本地记录)"""
|
||||||
total = 0
|
total = 0
|
||||||
for key in self.usage:
|
for key in self.usage:
|
||||||
if "text" not in self.usage[key]:
|
if "text" not in self.usage[key]:
|
||||||
|
|||||||
@@ -0,0 +1,3 @@
|
|||||||
|
"""
|
||||||
|
数据库操作封装
|
||||||
|
"""
|
||||||
@@ -1,3 +1,6 @@
|
|||||||
|
"""
|
||||||
|
数据库管理模块
|
||||||
|
"""
|
||||||
import hashlib
|
import hashlib
|
||||||
import json
|
import json
|
||||||
import logging
|
import logging
|
||||||
@@ -9,9 +12,9 @@ import sqlite3
|
|||||||
import pkg.utils.context
|
import pkg.utils.context
|
||||||
|
|
||||||
|
|
||||||
# 数据库管理
|
|
||||||
# 为其他模块提供数据库操作接口
|
|
||||||
class DatabaseManager:
|
class DatabaseManager:
|
||||||
|
"""封装数据库底层操作,并提供方法给上层使用"""
|
||||||
|
|
||||||
conn = None
|
conn = None
|
||||||
cursor = None
|
cursor = None
|
||||||
|
|
||||||
@@ -23,13 +26,14 @@ class DatabaseManager:
|
|||||||
|
|
||||||
# 连接到数据库文件
|
# 连接到数据库文件
|
||||||
def reconnect(self):
|
def reconnect(self):
|
||||||
|
"""连接到数据库"""
|
||||||
self.conn = sqlite3.connect('database.db', check_same_thread=False)
|
self.conn = sqlite3.connect('database.db', check_same_thread=False)
|
||||||
self.cursor = self.conn.cursor()
|
self.cursor = self.conn.cursor()
|
||||||
|
|
||||||
def close(self):
|
def close(self):
|
||||||
self.conn.close()
|
self.conn.close()
|
||||||
|
|
||||||
def execute(self, *args, **kwargs) -> Cursor:
|
def __execute__(self, *args, **kwargs) -> Cursor:
|
||||||
# logging.debug('SQL: {}'.format(sql))
|
# logging.debug('SQL: {}'.format(sql))
|
||||||
c = self.cursor.execute(*args, **kwargs)
|
c = self.cursor.execute(*args, **kwargs)
|
||||||
self.conn.commit()
|
self.conn.commit()
|
||||||
@@ -37,7 +41,9 @@ class DatabaseManager:
|
|||||||
|
|
||||||
# 初始化数据库的函数
|
# 初始化数据库的函数
|
||||||
def initialize_database(self):
|
def initialize_database(self):
|
||||||
self.execute("""
|
"""创建数据表"""
|
||||||
|
|
||||||
|
self.__execute__("""
|
||||||
create table if not exists `sessions` (
|
create table if not exists `sessions` (
|
||||||
`id` INTEGER PRIMARY KEY AUTOINCREMENT,
|
`id` INTEGER PRIMARY KEY AUTOINCREMENT,
|
||||||
`name` varchar(255) not null,
|
`name` varchar(255) not null,
|
||||||
@@ -50,7 +56,7 @@ class DatabaseManager:
|
|||||||
)
|
)
|
||||||
""")
|
""")
|
||||||
|
|
||||||
self.execute("""
|
self.__execute__("""
|
||||||
create table if not exists `account_fee`(
|
create table if not exists `account_fee`(
|
||||||
`id` INTEGER PRIMARY KEY AUTOINCREMENT,
|
`id` INTEGER PRIMARY KEY AUTOINCREMENT,
|
||||||
`key_md5` varchar(255) not null,
|
`key_md5` varchar(255) not null,
|
||||||
@@ -59,7 +65,7 @@ class DatabaseManager:
|
|||||||
)
|
)
|
||||||
""")
|
""")
|
||||||
|
|
||||||
self.execute("""
|
self.__execute__("""
|
||||||
create table if not exists `account_usage`(
|
create table if not exists `account_usage`(
|
||||||
`id` INTEGER PRIMARY KEY AUTOINCREMENT,
|
`id` INTEGER PRIMARY KEY AUTOINCREMENT,
|
||||||
`json` text not null
|
`json` text not null
|
||||||
@@ -70,10 +76,12 @@ class DatabaseManager:
|
|||||||
# session持久化
|
# session持久化
|
||||||
def persistence_session(self, subject_type: str, subject_number: int, create_timestamp: int,
|
def persistence_session(self, subject_type: str, subject_number: int, create_timestamp: int,
|
||||||
last_interact_timestamp: int, prompt: str):
|
last_interact_timestamp: int, prompt: str):
|
||||||
|
"""持久化指定session"""
|
||||||
|
|
||||||
# 检查是否已经有了此name和create_timestamp的session
|
# 检查是否已经有了此name和create_timestamp的session
|
||||||
# 如果有,就更新prompt和last_interact_timestamp
|
# 如果有,就更新prompt和last_interact_timestamp
|
||||||
# 如果没有,就插入一条新的记录
|
# 如果没有,就插入一条新的记录
|
||||||
self.execute("""
|
self.__execute__("""
|
||||||
select count(*) from `sessions` where `type` = '{}' and `number` = {} and `create_timestamp` = {}
|
select count(*) from `sessions` where `type` = '{}' and `number` = {} and `create_timestamp` = {}
|
||||||
""".format(subject_type, subject_number, create_timestamp))
|
""".format(subject_type, subject_number, create_timestamp))
|
||||||
count = self.cursor.fetchone()[0]
|
count = self.cursor.fetchone()[0]
|
||||||
@@ -84,8 +92,8 @@ class DatabaseManager:
|
|||||||
values (?, ?, ?, ?, ?, ?)
|
values (?, ?, ?, ?, ?, ?)
|
||||||
"""
|
"""
|
||||||
|
|
||||||
self.execute(sql,
|
self.__execute__(sql,
|
||||||
("{}_{}".format(subject_type, subject_number), subject_type, subject_number, create_timestamp,
|
("{}_{}".format(subject_type, subject_number), subject_type, subject_number, create_timestamp,
|
||||||
last_interact_timestamp, prompt))
|
last_interact_timestamp, prompt))
|
||||||
else:
|
else:
|
||||||
sql = """
|
sql = """
|
||||||
@@ -93,23 +101,23 @@ class DatabaseManager:
|
|||||||
where `type` = ? and `number` = ? and `create_timestamp` = ?
|
where `type` = ? and `number` = ? and `create_timestamp` = ?
|
||||||
"""
|
"""
|
||||||
|
|
||||||
self.execute(sql, (last_interact_timestamp, prompt, subject_type,
|
self.__execute__(sql, (last_interact_timestamp, prompt, subject_type,
|
||||||
subject_number, create_timestamp))
|
subject_number, create_timestamp))
|
||||||
|
|
||||||
# 显式关闭一个session
|
# 显式关闭一个session
|
||||||
def explicit_close_session(self, session_name: str, create_timestamp: int):
|
def explicit_close_session(self, session_name: str, create_timestamp: int):
|
||||||
self.execute("""
|
self.__execute__("""
|
||||||
update `sessions` set `status` = 'explicitly_closed' where `name` = '{}' and `create_timestamp` = {}
|
update `sessions` set `status` = 'explicitly_closed' where `name` = '{}' and `create_timestamp` = {}
|
||||||
""".format(session_name, create_timestamp))
|
""".format(session_name, create_timestamp))
|
||||||
|
|
||||||
def set_session_ongoing(self, session_name: str, create_timestamp: int):
|
def set_session_ongoing(self, session_name: str, create_timestamp: int):
|
||||||
self.execute("""
|
self.__execute__("""
|
||||||
update `sessions` set `status` = 'on_going' where `name` = '{}' and `create_timestamp` = {}
|
update `sessions` set `status` = 'on_going' where `name` = '{}' and `create_timestamp` = {}
|
||||||
""".format(session_name, create_timestamp))
|
""".format(session_name, create_timestamp))
|
||||||
|
|
||||||
# 设置session为过期
|
# 设置session为过期
|
||||||
def set_session_expired(self, session_name: str, create_timestamp: int):
|
def set_session_expired(self, session_name: str, create_timestamp: int):
|
||||||
self.execute("""
|
self.__execute__("""
|
||||||
update `sessions` set `status` = 'expired' where `name` = '{}' and `create_timestamp` = {}
|
update `sessions` set `status` = 'expired' where `name` = '{}' and `create_timestamp` = {}
|
||||||
""".format(session_name, create_timestamp))
|
""".format(session_name, create_timestamp))
|
||||||
|
|
||||||
@@ -117,7 +125,7 @@ class DatabaseManager:
|
|||||||
def load_valid_sessions(self) -> dict:
|
def load_valid_sessions(self) -> dict:
|
||||||
# 从数据库中加载所有还没过期的session
|
# 从数据库中加载所有还没过期的session
|
||||||
config = pkg.utils.context.get_config()
|
config = pkg.utils.context.get_config()
|
||||||
self.execute("""
|
self.__execute__("""
|
||||||
select `name`, `type`, `number`, `create_timestamp`, `last_interact_timestamp`, `prompt`, `status`
|
select `name`, `type`, `number`, `create_timestamp`, `last_interact_timestamp`, `prompt`, `status`
|
||||||
from `sessions` where `last_interact_timestamp` > {}
|
from `sessions` where `last_interact_timestamp` > {}
|
||||||
""".format(int(time.time()) - config.session_expire_time))
|
""".format(int(time.time()) - config.session_expire_time))
|
||||||
@@ -150,7 +158,7 @@ class DatabaseManager:
|
|||||||
# 获取此session_name前一个session的数据
|
# 获取此session_name前一个session的数据
|
||||||
def last_session(self, session_name: str, cursor_timestamp: int):
|
def last_session(self, session_name: str, cursor_timestamp: int):
|
||||||
|
|
||||||
self.execute("""
|
self.__execute__("""
|
||||||
select `name`, `type`, `number`, `create_timestamp`, `last_interact_timestamp`, `prompt`, `status`
|
select `name`, `type`, `number`, `create_timestamp`, `last_interact_timestamp`, `prompt`, `status`
|
||||||
from `sessions` where `name` = '{}' and `last_interact_timestamp` < {} order by `last_interact_timestamp` desc
|
from `sessions` where `name` = '{}' and `last_interact_timestamp` < {} order by `last_interact_timestamp` desc
|
||||||
limit 1
|
limit 1
|
||||||
@@ -179,7 +187,7 @@ class DatabaseManager:
|
|||||||
# 获取此session_name后一个session的数据
|
# 获取此session_name后一个session的数据
|
||||||
def next_session(self, session_name: str, cursor_timestamp: int):
|
def next_session(self, session_name: str, cursor_timestamp: int):
|
||||||
|
|
||||||
self.execute("""
|
self.__execute__("""
|
||||||
select `name`, `type`, `number`, `create_timestamp`, `last_interact_timestamp`, `prompt`, `status`
|
select `name`, `type`, `number`, `create_timestamp`, `last_interact_timestamp`, `prompt`, `status`
|
||||||
from `sessions` where `name` = '{}' and `last_interact_timestamp` > {} order by `last_interact_timestamp` asc
|
from `sessions` where `name` = '{}' and `last_interact_timestamp` > {} order by `last_interact_timestamp` asc
|
||||||
limit 1
|
limit 1
|
||||||
@@ -207,7 +215,7 @@ class DatabaseManager:
|
|||||||
|
|
||||||
# 列出与某个对象的所有对话session
|
# 列出与某个对象的所有对话session
|
||||||
def list_history(self, session_name: str, capacity: int, page: int):
|
def list_history(self, session_name: str, capacity: int, page: int):
|
||||||
self.execute("""
|
self.__execute__("""
|
||||||
select `name`, `type`, `number`, `create_timestamp`, `last_interact_timestamp`, `prompt`, `status`
|
select `name`, `type`, `number`, `create_timestamp`, `last_interact_timestamp`, `prompt`, `status`
|
||||||
from `sessions` where `name` = '{}' order by `last_interact_timestamp` desc limit {} offset {}
|
from `sessions` where `name` = '{}' order by `last_interact_timestamp` desc limit {} offset {}
|
||||||
""".format(session_name, capacity, capacity * page))
|
""".format(session_name, capacity, capacity * page))
|
||||||
@@ -246,22 +254,22 @@ class DatabaseManager:
|
|||||||
usage_count = usage[key_md5]
|
usage_count = usage[key_md5]
|
||||||
# 将使用量存进数据库
|
# 将使用量存进数据库
|
||||||
# 先检查是否已存在
|
# 先检查是否已存在
|
||||||
self.execute("""
|
self.__execute__("""
|
||||||
select count(*) from `api_key_usage` where `key_md5` = '{}'""".format(key_md5))
|
select count(*) from `api_key_usage` where `key_md5` = '{}'""".format(key_md5))
|
||||||
result = self.cursor.fetchone()
|
result = self.cursor.fetchone()
|
||||||
if result[0] == 0:
|
if result[0] == 0:
|
||||||
# 不存在则插入
|
# 不存在则插入
|
||||||
self.execute("""
|
self.__execute__("""
|
||||||
insert into `api_key_usage` (`key_md5`, `usage`,`timestamp`) values ('{}', {}, {})
|
insert into `api_key_usage` (`key_md5`, `usage`,`timestamp`) values ('{}', {}, {})
|
||||||
""".format(key_md5, usage_count, int(time.time())))
|
""".format(key_md5, usage_count, int(time.time())))
|
||||||
else:
|
else:
|
||||||
# 存在则更新,timestamp设置为当前
|
# 存在则更新,timestamp设置为当前
|
||||||
self.execute("""
|
self.__execute__("""
|
||||||
update `api_key_usage` set `usage` = {}, `timestamp` = {} where `key_md5` = '{}'
|
update `api_key_usage` set `usage` = {}, `timestamp` = {} where `key_md5` = '{}'
|
||||||
""".format(usage_count, int(time.time()), key_md5))
|
""".format(usage_count, int(time.time()), key_md5))
|
||||||
|
|
||||||
def load_api_key_usage(self):
|
def load_api_key_usage(self):
|
||||||
self.execute("""
|
self.__execute__("""
|
||||||
select `key_md5`, `usage` from `api_key_usage`
|
select `key_md5`, `usage` from `api_key_usage`
|
||||||
""")
|
""")
|
||||||
results = self.cursor.fetchall()
|
results = self.cursor.fetchall()
|
||||||
@@ -273,23 +281,24 @@ class DatabaseManager:
|
|||||||
return usage
|
return usage
|
||||||
|
|
||||||
def dump_usage_json(self, usage: dict):
|
def dump_usage_json(self, usage: dict):
|
||||||
|
|
||||||
json_str = json.dumps(usage)
|
json_str = json.dumps(usage)
|
||||||
self.execute("""
|
self.__execute__("""
|
||||||
select count(*) from `account_usage`""")
|
select count(*) from `account_usage`""")
|
||||||
result = self.cursor.fetchone()
|
result = self.cursor.fetchone()
|
||||||
if result[0] == 0:
|
if result[0] == 0:
|
||||||
# 不存在则插入
|
# 不存在则插入
|
||||||
self.execute("""
|
self.__execute__("""
|
||||||
insert into `account_usage` (`json`) values ('{}')
|
insert into `account_usage` (`json`) values ('{}')
|
||||||
""".format(json_str))
|
""".format(json_str))
|
||||||
else:
|
else:
|
||||||
# 存在则更新
|
# 存在则更新
|
||||||
self.execute("""
|
self.__execute__("""
|
||||||
update `account_usage` set `json` = '{}' where `id` = 1
|
update `account_usage` set `json` = '{}' where `id` = 1
|
||||||
""".format(json_str))
|
""".format(json_str))
|
||||||
|
|
||||||
def load_usage_json(self):
|
def load_usage_json(self):
|
||||||
self.execute("""
|
self.__execute__("""
|
||||||
select `json` from `account_usage` order by id desc limit 1
|
select `json` from `account_usage` order by id desc limit 1
|
||||||
""")
|
""")
|
||||||
result = self.cursor.fetchone()
|
result = self.cursor.fetchone()
|
||||||
|
|||||||
@@ -0,0 +1,2 @@
|
|||||||
|
"""OpenAI 接口处理及会话管理相关
|
||||||
|
"""
|
||||||
|
|||||||
@@ -1,8 +1,13 @@
|
|||||||
# 多情景预设值管理
|
# 多情景预设值管理
|
||||||
|
|
||||||
__current__ = "default"
|
__current__ = "default"
|
||||||
|
"""当前默认使用的情景预设的名称
|
||||||
|
|
||||||
|
由管理员使用`!default <名称>`指令切换
|
||||||
|
"""
|
||||||
|
|
||||||
__prompts_from_files__ = {}
|
__prompts_from_files__ = {}
|
||||||
|
"""从文件中读取的情景预设值"""
|
||||||
|
|
||||||
|
|
||||||
def read_prompt_from_file() -> str:
|
def read_prompt_from_file() -> str:
|
||||||
|
|||||||
@@ -5,18 +5,26 @@ import logging
|
|||||||
import pkg.plugin.host as plugin_host
|
import pkg.plugin.host as plugin_host
|
||||||
import pkg.plugin.models as plugin_models
|
import pkg.plugin.models as plugin_models
|
||||||
|
|
||||||
|
|
||||||
class KeysManager:
|
class KeysManager:
|
||||||
api_key = {}
|
api_key = {}
|
||||||
|
"""所有api-key"""
|
||||||
|
|
||||||
# api-key的使用量
|
|
||||||
# 其中键为api-key的md5值,值为使用量
|
|
||||||
using_key = ""
|
using_key = ""
|
||||||
|
"""当前使用的api-key
|
||||||
|
"""
|
||||||
|
|
||||||
alerted = []
|
alerted = []
|
||||||
|
"""已提示过超额的key
|
||||||
|
|
||||||
|
记录在此以避免重复提示
|
||||||
|
"""
|
||||||
|
|
||||||
# 在此list中的都是经超额报错标记过的api-key
|
|
||||||
# 记录的是key值,仅在运行时有效
|
|
||||||
exceeded = []
|
exceeded = []
|
||||||
|
"""已超额的key
|
||||||
|
|
||||||
|
供自动切换功能识别
|
||||||
|
"""
|
||||||
|
|
||||||
def get_using_key(self):
|
def get_using_key(self):
|
||||||
return self.using_key
|
return self.using_key
|
||||||
@@ -25,8 +33,6 @@ class KeysManager:
|
|||||||
return hashlib.md5(self.using_key.encode('utf-8')).hexdigest()
|
return hashlib.md5(self.using_key.encode('utf-8')).hexdigest()
|
||||||
|
|
||||||
def __init__(self, api_key):
|
def __init__(self, api_key):
|
||||||
# if hasattr(config, 'api_key_usage_threshold'):
|
|
||||||
# self.api_key_usage_threshold = config.api_key_usage_threshold
|
|
||||||
|
|
||||||
if type(api_key) is dict:
|
if type(api_key) is dict:
|
||||||
self.api_key = api_key
|
self.api_key = api_key
|
||||||
@@ -42,9 +48,13 @@ class KeysManager:
|
|||||||
|
|
||||||
self.auto_switch()
|
self.auto_switch()
|
||||||
|
|
||||||
# 根据tested自动切换到可用的api-key
|
|
||||||
# 返回是否切换成功, 切换后的api-key的别名
|
|
||||||
def auto_switch(self) -> (bool, str):
|
def auto_switch(self) -> (bool, str):
|
||||||
|
"""尝试切换api-key
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
是否切换成功, 切换后的api-key的别名
|
||||||
|
"""
|
||||||
|
|
||||||
for key_name in self.api_key:
|
for key_name in self.api_key:
|
||||||
if self.api_key[key_name] not in self.exceeded:
|
if self.api_key[key_name] not in self.exceeded:
|
||||||
self.using_key = self.api_key[key_name]
|
self.using_key = self.api_key[key_name]
|
||||||
@@ -68,12 +78,9 @@ class KeysManager:
|
|||||||
def add(self, key_name, key):
|
def add(self, key_name, key):
|
||||||
self.api_key[key_name] = key
|
self.api_key[key_name] = key
|
||||||
|
|
||||||
# 设置当前使用的api-key使用量超限
|
|
||||||
# 这是在尝试调用api时发生超限异常时调用的
|
|
||||||
def set_current_exceeded(self):
|
def set_current_exceeded(self):
|
||||||
# md5 = hashlib.md5(self.using_key.encode('utf-8')).hexdigest()
|
"""设置当前使用的api-key使用量超限
|
||||||
# self.usage[md5] = self.api_key_usage_threshold
|
"""
|
||||||
# self.fee[md5] = self.api_key_fee_threshold
|
|
||||||
self.exceeded.append(self.using_key)
|
self.exceeded.append(self.using_key)
|
||||||
|
|
||||||
def get_key_name(self, api_key):
|
def get_key_name(self, api_key):
|
||||||
|
|||||||
@@ -7,9 +7,12 @@ import pkg.utils.context
|
|||||||
import pkg.audit.gatherer
|
import pkg.audit.gatherer
|
||||||
from pkg.openai.modelmgr import ModelRequest, create_openai_model_request
|
from pkg.openai.modelmgr import ModelRequest, create_openai_model_request
|
||||||
|
|
||||||
# 为其他模块提供与OpenAI交互的接口
|
|
||||||
class OpenAIInteract:
|
class OpenAIInteract:
|
||||||
api_params = {}
|
"""OpenAI 接口封装
|
||||||
|
|
||||||
|
将文字接口和图片接口封装供调用方使用
|
||||||
|
"""
|
||||||
|
|
||||||
key_mgr: pkg.openai.keymgr.KeysManager = None
|
key_mgr: pkg.openai.keymgr.KeysManager = None
|
||||||
|
|
||||||
@@ -20,7 +23,6 @@ class OpenAIInteract:
|
|||||||
}
|
}
|
||||||
|
|
||||||
def __init__(self, api_key: str):
|
def __init__(self, api_key: str):
|
||||||
# self.api_key = api_key
|
|
||||||
|
|
||||||
self.key_mgr = pkg.openai.keymgr.KeysManager(api_key)
|
self.key_mgr = pkg.openai.keymgr.KeysManager(api_key)
|
||||||
self.audit_mgr = pkg.audit.gatherer.DataGatherer()
|
self.audit_mgr = pkg.audit.gatherer.DataGatherer()
|
||||||
@@ -32,7 +34,16 @@ class OpenAIInteract:
|
|||||||
pkg.utils.context.set_openai_manager(self)
|
pkg.utils.context.set_openai_manager(self)
|
||||||
|
|
||||||
# 请求OpenAI Completion
|
# 请求OpenAI Completion
|
||||||
def request_completion(self, prompts):
|
def request_completion(self, prompts) -> str:
|
||||||
|
"""请求补全接口回复
|
||||||
|
|
||||||
|
Parameters:
|
||||||
|
prompts (str): 提示语
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
str: 回复
|
||||||
|
"""
|
||||||
|
|
||||||
config = pkg.utils.context.get_config()
|
config = pkg.utils.context.get_config()
|
||||||
|
|
||||||
# 根据模型选择使用的接口
|
# 根据模型选择使用的接口
|
||||||
@@ -58,8 +69,15 @@ class OpenAIInteract:
|
|||||||
|
|
||||||
return ai.get_message()
|
return ai.get_message()
|
||||||
|
|
||||||
def request_image(self, prompt):
|
def request_image(self, prompt) -> dict:
|
||||||
|
"""请求图片接口回复
|
||||||
|
|
||||||
|
Parameters:
|
||||||
|
prompt (str): 提示语
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
dict: 响应
|
||||||
|
"""
|
||||||
config = pkg.utils.context.get_config()
|
config = pkg.utils.context.get_config()
|
||||||
params = config.image_api_params if hasattr(config, "image_api_params") else self.default_image_api_params
|
params = config.image_api_params if hasattr(config, "image_api_params") else self.default_image_api_params
|
||||||
|
|
||||||
|
|||||||
@@ -1,5 +1,12 @@
|
|||||||
# 提供与模型交互的抽象接口
|
"""OpenAI 接口底层封装
|
||||||
|
|
||||||
|
目前使用的对话接口有:
|
||||||
|
ChatCompletion - gpt-3.5-turbo 等模型
|
||||||
|
Completion - text-davinci-003 等模型
|
||||||
|
此模块封装此两个接口的请求实现,为上层提供统一的调用方式
|
||||||
|
"""
|
||||||
import openai, logging, threading, asyncio
|
import openai, logging, threading, asyncio
|
||||||
|
import openai.error as aiE
|
||||||
|
|
||||||
COMPLETION_MODELS = {
|
COMPLETION_MODELS = {
|
||||||
'text-davinci-003',
|
'text-davinci-003',
|
||||||
@@ -25,26 +32,46 @@ IMAGE_MODELS = {
|
|||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
class ModelRequest():
|
class ModelRequest:
|
||||||
"""GPT父类"""
|
"""模型接口请求父类"""
|
||||||
can_chat = False
|
|
||||||
runtime:threading.Thread = None
|
|
||||||
ret = ""
|
|
||||||
proxy:str = None
|
|
||||||
|
|
||||||
def __init__(self, model_name, user_name, request_fun, http_proxy:str = None):
|
can_chat = False
|
||||||
|
runtime: threading.Thread = None
|
||||||
|
ret = {}
|
||||||
|
proxy: str = None
|
||||||
|
request_ready = True
|
||||||
|
error_info: str = "若在没有任何错误的情况下看到这句话,请带着配置文件上报Issues"
|
||||||
|
|
||||||
|
def __init__(self, model_name, user_name, request_fun, http_proxy:str = None, time_out = None):
|
||||||
self.model_name = model_name
|
self.model_name = model_name
|
||||||
self.user_name = user_name
|
self.user_name = user_name
|
||||||
self.request_fun = request_fun
|
self.request_fun = request_fun
|
||||||
|
self.time_out = time_out
|
||||||
if http_proxy != None:
|
if http_proxy != None:
|
||||||
self.proxy = http_proxy
|
self.proxy = http_proxy
|
||||||
openai.proxy = self.proxy
|
openai.proxy = self.proxy
|
||||||
|
self.request_ready = False
|
||||||
|
|
||||||
async def __a_request__(self, **kwargs):
|
async def __a_request__(self, **kwargs):
|
||||||
self.ret = await self.request_fun(**kwargs)
|
"""异步请求"""
|
||||||
|
|
||||||
|
try:
|
||||||
|
self.ret:dict = await self.request_fun(**kwargs)
|
||||||
|
self.request_ready = True
|
||||||
|
except aiE.APIConnectionError as e:
|
||||||
|
self.error_info = "{}\n请检查网络连接或代理是否正常".format(e)
|
||||||
|
raise ConnectionError(self.error_info)
|
||||||
|
except ValueError as e:
|
||||||
|
self.error_info = "{}\n该错误可能是由于http_proxy格式设置错误引起的"
|
||||||
|
except Exception as e:
|
||||||
|
self.error_info = "{}\n由于请求异常产生的未知错误,请查看日志".format(e)
|
||||||
|
raise Exception(self.error_info)
|
||||||
|
|
||||||
def request(self, **kwargs):
|
def request(self, **kwargs):
|
||||||
|
"""向接口发起请求"""
|
||||||
|
|
||||||
if self.proxy != None: #异步请求
|
if self.proxy != None: #异步请求
|
||||||
|
self.request_ready = False
|
||||||
loop = asyncio.new_event_loop()
|
loop = asyncio.new_event_loop()
|
||||||
self.runtime = threading.Thread(
|
self.runtime = threading.Thread(
|
||||||
target=loop.run_until_complete,
|
target=loop.run_until_complete,
|
||||||
@@ -64,13 +91,15 @@ class ModelRequest():
|
|||||||
若重写该方法,应检查异步线程状态,或在需要检查处super该方法
|
若重写该方法,应检查异步线程状态,或在需要检查处super该方法
|
||||||
'''
|
'''
|
||||||
if self.runtime != None and isinstance(self.runtime, threading.Thread):
|
if self.runtime != None and isinstance(self.runtime, threading.Thread):
|
||||||
self.runtime.join()
|
self.runtime.join(self.time_out)
|
||||||
return
|
if self.request_ready:
|
||||||
|
return
|
||||||
|
raise Exception(self.error_info)
|
||||||
|
|
||||||
def get_total_tokens(self):
|
def get_total_tokens(self):
|
||||||
try:
|
try:
|
||||||
return self.ret['usage']['total_tokens']
|
return self.ret['usage']['total_tokens']
|
||||||
except Exception:
|
except:
|
||||||
return 0
|
return 0
|
||||||
|
|
||||||
def get_message(self):
|
def get_message(self):
|
||||||
@@ -79,8 +108,10 @@ class ModelRequest():
|
|||||||
def get_response(self):
|
def get_response(self):
|
||||||
return self.ret
|
return self.ret
|
||||||
|
|
||||||
|
|
||||||
class ChatCompletionModel(ModelRequest):
|
class ChatCompletionModel(ModelRequest):
|
||||||
"""ChatCompletion类模型"""
|
"""ChatCompletion接口的请求实现"""
|
||||||
|
|
||||||
Chat_role = ['system', 'user', 'assistant']
|
Chat_role = ['system', 'user', 'assistant']
|
||||||
def __init__(self, model_name, user_name, http_proxy:str = None, **kwargs):
|
def __init__(self, model_name, user_name, http_proxy:str = None, **kwargs):
|
||||||
if http_proxy == None:
|
if http_proxy == None:
|
||||||
@@ -108,7 +139,8 @@ class ChatCompletionModel(ModelRequest):
|
|||||||
|
|
||||||
|
|
||||||
class CompletionModel(ModelRequest):
|
class CompletionModel(ModelRequest):
|
||||||
"""Completion类模型"""
|
"""Completion接口的请求实现"""
|
||||||
|
|
||||||
def __init__(self, model_name, user_name, http_proxy:str = None, **kwargs):
|
def __init__(self, model_name, user_name, http_proxy:str = None, **kwargs):
|
||||||
if http_proxy == None:
|
if http_proxy == None:
|
||||||
request_fun = openai.Completion.create
|
request_fun = openai.Completion.create
|
||||||
|
|||||||
@@ -1,3 +1,8 @@
|
|||||||
|
"""主线使用的会话管理模块
|
||||||
|
|
||||||
|
每个人、每个群单独一个session,session内部保留了对话的上下文,
|
||||||
|
"""
|
||||||
|
|
||||||
import logging
|
import logging
|
||||||
import threading
|
import threading
|
||||||
import time
|
import time
|
||||||
@@ -19,6 +24,7 @@ class SessionOfflineStatus:
|
|||||||
ON_GOING = 'on_going'
|
ON_GOING = 'on_going'
|
||||||
EXPLICITLY_CLOSED = 'explicitly_closed'
|
EXPLICITLY_CLOSED = 'explicitly_closed'
|
||||||
|
|
||||||
|
|
||||||
# 重置session.prompt
|
# 重置session.prompt
|
||||||
def reset_session_prompt(session_name, prompt):
|
def reset_session_prompt(session_name, prompt):
|
||||||
# 备份原始数据
|
# 备份原始数据
|
||||||
@@ -43,11 +49,14 @@ def reset_session_prompt(session_name, prompt):
|
|||||||
用户[{}]的数据已被重置,有可能是因为数据版本过旧或存储错误
|
用户[{}]的数据已被重置,有可能是因为数据版本过旧或存储错误
|
||||||
原始数据将备份在:
|
原始数据将备份在:
|
||||||
{}""".format(session_name, bak_path)
|
{}""".format(session_name, bak_path)
|
||||||
)
|
) # 为保证多行文本格式正确故无缩进
|
||||||
return prompt
|
return prompt
|
||||||
|
|
||||||
|
|
||||||
# 从数据加载session
|
# 从数据加载session
|
||||||
def load_sessions():
|
def load_sessions():
|
||||||
|
"""从数据库加载sessions"""
|
||||||
|
|
||||||
global sessions
|
global sessions
|
||||||
|
|
||||||
db_inst = pkg.utils.context.get_database_manager()
|
db_inst = pkg.utils.context.get_database_manager()
|
||||||
@@ -93,10 +102,13 @@ class Session:
|
|||||||
name = ''
|
name = ''
|
||||||
|
|
||||||
prompt = []
|
prompt = []
|
||||||
|
"""使用list来保存会话中的回合"""
|
||||||
|
|
||||||
create_timestamp = 0
|
create_timestamp = 0
|
||||||
|
"""会话创建时间"""
|
||||||
|
|
||||||
last_interact_timestamp = 0
|
last_interact_timestamp = 0
|
||||||
|
"""上次交互(产生回复)时间"""
|
||||||
|
|
||||||
just_switched_to_exist_session = False
|
just_switched_to_exist_session = False
|
||||||
|
|
||||||
@@ -116,7 +128,7 @@ class Session:
|
|||||||
logging.debug('{},lock release successfully,{}'.format(self.name, self.response_lock))
|
logging.debug('{},lock release successfully,{}'.format(self.name, self.response_lock))
|
||||||
|
|
||||||
# 从配置文件获取会话预设信息
|
# 从配置文件获取会话预设信息
|
||||||
def get_default_prompt(self, use_default: str=None):
|
def get_default_prompt(self, use_default: str = None):
|
||||||
config = pkg.utils.context.get_config()
|
config = pkg.utils.context.get_config()
|
||||||
|
|
||||||
import pkg.openai.dprompt as dprompt
|
import pkg.openai.dprompt as dprompt
|
||||||
@@ -130,7 +142,7 @@ class Session:
|
|||||||
{
|
{
|
||||||
'role': 'user',
|
'role': 'user',
|
||||||
'content': current_default_prompt
|
'content': current_default_prompt
|
||||||
},{
|
}, {
|
||||||
'role': 'assistant',
|
'role': 'assistant',
|
||||||
'content': 'ok'
|
'content': 'ok'
|
||||||
}
|
}
|
||||||
@@ -182,6 +194,8 @@ class Session:
|
|||||||
# 请求回复
|
# 请求回复
|
||||||
# 这个函数是阻塞的
|
# 这个函数是阻塞的
|
||||||
def append(self, text: str) -> str:
|
def append(self, text: str) -> str:
|
||||||
|
"""向session中添加一条消息,返回接口回复"""
|
||||||
|
|
||||||
self.last_interact_timestamp = int(time.time())
|
self.last_interact_timestamp = int(time.time())
|
||||||
|
|
||||||
# 触发插件事件
|
# 触发插件事件
|
||||||
@@ -215,14 +229,14 @@ class Session:
|
|||||||
res_ans = '\n\n'.join(res_ans_spt)
|
res_ans = '\n\n'.join(res_ans_spt)
|
||||||
|
|
||||||
# 将此次对话的双方内容加入到prompt中
|
# 将此次对话的双方内容加入到prompt中
|
||||||
self.prompt.append({'role':'user', 'content':text})
|
self.prompt.append({'role': 'user', 'content': text})
|
||||||
self.prompt.append({'role':'assistant', 'content':res_ans})
|
self.prompt.append({'role': 'assistant', 'content': res_ans})
|
||||||
|
|
||||||
if self.just_switched_to_exist_session:
|
if self.just_switched_to_exist_session:
|
||||||
self.just_switched_to_exist_session = False
|
self.just_switched_to_exist_session = False
|
||||||
self.set_ongoing()
|
self.set_ongoing()
|
||||||
|
|
||||||
return res_ans if res_ans[0]!='\n' else res_ans[1:]
|
return res_ans if res_ans[0] != '\n' else res_ans[1:]
|
||||||
|
|
||||||
# 删除上一回合并返回上一回合的问题
|
# 删除上一回合并返回上一回合的问题
|
||||||
def undo(self) -> str:
|
def undo(self) -> str:
|
||||||
@@ -231,10 +245,10 @@ class Session:
|
|||||||
# 删除最后两个消息
|
# 删除最后两个消息
|
||||||
if len(self.prompt) < 2:
|
if len(self.prompt) < 2:
|
||||||
raise Exception('之前无对话,无法撤销')
|
raise Exception('之前无对话,无法撤销')
|
||||||
|
|
||||||
question = self.prompt[-2]['content']
|
question = self.prompt[-2]['content']
|
||||||
self.prompt = self.prompt[:-2]
|
self.prompt = self.prompt[:-2]
|
||||||
|
|
||||||
# 返回上一回合的问题
|
# 返回上一回合的问题
|
||||||
return question
|
return question
|
||||||
|
|
||||||
@@ -242,13 +256,13 @@ class Session:
|
|||||||
def cut_out(self, msg: str, max_tokens: int) -> list:
|
def cut_out(self, msg: str, max_tokens: int) -> list:
|
||||||
"""将现有prompt进行切割处理,使得新的prompt长度不超过max_tokens"""
|
"""将现有prompt进行切割处理,使得新的prompt长度不超过max_tokens"""
|
||||||
# 如果用户消息长度超过max_tokens,直接返回
|
# 如果用户消息长度超过max_tokens,直接返回
|
||||||
|
|
||||||
temp_prompt = [
|
temp_prompt = [
|
||||||
{
|
{
|
||||||
'role': 'user',
|
'role': 'user',
|
||||||
'content': msg
|
'content': msg
|
||||||
}
|
}
|
||||||
]
|
]
|
||||||
|
|
||||||
token_count = len(msg)
|
token_count = len(msg)
|
||||||
# 倒序遍历prompt
|
# 倒序遍历prompt
|
||||||
|
|||||||
@@ -0,0 +1,4 @@
|
|||||||
|
"""插件支持包
|
||||||
|
|
||||||
|
包含插件基类、插件宿主以及部分API接口
|
||||||
|
"""
|
||||||
@@ -116,7 +116,9 @@ def initialize_plugins():
|
|||||||
|
|
||||||
|
|
||||||
def unload_plugins():
|
def unload_plugins():
|
||||||
""" 卸载插件 """
|
""" 卸载插件
|
||||||
|
"""
|
||||||
|
# 不再显式卸载插件,因为当程序结束时,插件的析构函数会被系统执行
|
||||||
# for plugin in __plugins__.values():
|
# for plugin in __plugins__.values():
|
||||||
# if plugin['enabled'] and plugin['instance'] is not None:
|
# if plugin['enabled'] and plugin['instance'] is not None:
|
||||||
# if not hasattr(plugin['instance'], '__del__'):
|
# if not hasattr(plugin['instance'], '__del__'):
|
||||||
|
|||||||
@@ -145,6 +145,7 @@ __current_registering_plugin__ = ""
|
|||||||
|
|
||||||
|
|
||||||
class Plugin:
|
class Plugin:
|
||||||
|
"""插件基类"""
|
||||||
|
|
||||||
host: host.PluginHost
|
host: host.PluginHost
|
||||||
"""插件宿主,提供插件的一些基础功能"""
|
"""插件宿主,提供插件的一些基础功能"""
|
||||||
|
|||||||
@@ -2,6 +2,7 @@ import asyncio
|
|||||||
import json
|
import json
|
||||||
import os
|
import os
|
||||||
import threading
|
import threading
|
||||||
|
from concurrent.futures import ThreadPoolExecutor
|
||||||
|
|
||||||
import mirai.models.bus
|
import mirai.models.bus
|
||||||
from mirai import At, GroupMessage, MessageEvent, Mirai, StrangerMessage, WebSocketAdapter, HTTPAdapter, \
|
from mirai import At, GroupMessage, MessageEvent, Mirai, StrangerMessage, WebSocketAdapter, HTTPAdapter, \
|
||||||
@@ -20,13 +21,6 @@ import pkg.utils.context
|
|||||||
import pkg.plugin.host as plugin_host
|
import pkg.plugin.host as plugin_host
|
||||||
import pkg.plugin.models as plugin_models
|
import pkg.plugin.models as plugin_models
|
||||||
|
|
||||||
|
|
||||||
# 并行运行
|
|
||||||
def go(func, args=()):
|
|
||||||
thread = threading.Thread(target=func, args=args, daemon=True)
|
|
||||||
thread.start()
|
|
||||||
|
|
||||||
|
|
||||||
# 检查消息是否符合泛响应匹配机制
|
# 检查消息是否符合泛响应匹配机制
|
||||||
def check_response_rule(text: str):
|
def check_response_rule(text: str):
|
||||||
config = pkg.utils.context.get_config()
|
config = pkg.utils.context.get_config()
|
||||||
@@ -55,6 +49,9 @@ def check_response_rule(text: str):
|
|||||||
class QQBotManager:
|
class QQBotManager:
|
||||||
retry = 3
|
retry = 3
|
||||||
|
|
||||||
|
#线程池控制
|
||||||
|
pool = None
|
||||||
|
|
||||||
bot: Mirai = None
|
bot: Mirai = None
|
||||||
|
|
||||||
reply_filter = None
|
reply_filter = None
|
||||||
@@ -64,11 +61,14 @@ class QQBotManager:
|
|||||||
ban_person = []
|
ban_person = []
|
||||||
ban_group = []
|
ban_group = []
|
||||||
|
|
||||||
def __init__(self, mirai_http_api_config: dict, timeout: int = 60, retry: int = 3, first_time_init=True):
|
def __init__(self, mirai_http_api_config: dict, timeout: int = 60, retry: int = 3, pool_num: int = 10, first_time_init=True):
|
||||||
|
|
||||||
self.timeout = timeout
|
self.timeout = timeout
|
||||||
self.retry = retry
|
self.retry = retry
|
||||||
|
|
||||||
|
self.pool_num = pool_num
|
||||||
|
self.pool = ThreadPoolExecutor(max_workers=self.pool_num)
|
||||||
|
logging.debug("Registered thread pool Size:{}".format(pool_num))
|
||||||
|
|
||||||
# 加载禁用列表
|
# 加载禁用列表
|
||||||
if os.path.exists("banlist.py"):
|
if os.path.exists("banlist.py"):
|
||||||
import banlist
|
import banlist
|
||||||
@@ -116,7 +116,7 @@ class QQBotManager:
|
|||||||
|
|
||||||
self.on_person_message(event)
|
self.on_person_message(event)
|
||||||
|
|
||||||
go(friend_message_handler, (event,))
|
self.go(friend_message_handler, event)
|
||||||
|
|
||||||
@self.bot.on(StrangerMessage)
|
@self.bot.on(StrangerMessage)
|
||||||
async def on_stranger_message(event: StrangerMessage):
|
async def on_stranger_message(event: StrangerMessage):
|
||||||
@@ -136,7 +136,7 @@ class QQBotManager:
|
|||||||
|
|
||||||
self.on_person_message(event)
|
self.on_person_message(event)
|
||||||
|
|
||||||
go(stranger_message_handler, (event,))
|
self.go(stranger_message_handler, event)
|
||||||
|
|
||||||
@self.bot.on(GroupMessage)
|
@self.bot.on(GroupMessage)
|
||||||
async def on_group_message(event: GroupMessage):
|
async def on_group_message(event: GroupMessage):
|
||||||
@@ -156,7 +156,7 @@ class QQBotManager:
|
|||||||
|
|
||||||
self.on_group_message(event)
|
self.on_group_message(event)
|
||||||
|
|
||||||
go(group_message_handler, (event,))
|
self.go(group_message_handler, event)
|
||||||
|
|
||||||
def unsubscribe_all():
|
def unsubscribe_all():
|
||||||
"""取消所有订阅
|
"""取消所有订阅
|
||||||
@@ -173,6 +173,9 @@ class QQBotManager:
|
|||||||
|
|
||||||
self.unsubscribe_all = unsubscribe_all
|
self.unsubscribe_all = unsubscribe_all
|
||||||
|
|
||||||
|
def go(self, func, *args, **kwargs):
|
||||||
|
self.pool.submit(func, *args, **kwargs)
|
||||||
|
|
||||||
def first_time_init(self, mirai_http_api_config: dict):
|
def first_time_init(self, mirai_http_api_config: dict):
|
||||||
"""热重载后不再运行此函数"""
|
"""热重载后不再运行此函数"""
|
||||||
|
|
||||||
|
|||||||
File diff suppressed because one or more lines are too long
@@ -7,13 +7,15 @@ import pkg.utils.context
|
|||||||
import pkg.plugin.host
|
import pkg.plugin.host
|
||||||
|
|
||||||
|
|
||||||
def walk(module, prefix=''):
|
def walk(module, prefix='', path_prefix=''):
|
||||||
"""遍历并重载所有模块"""
|
"""遍历并重载所有模块"""
|
||||||
for item in pkgutil.iter_modules(module.__path__):
|
for item in pkgutil.iter_modules(module.__path__):
|
||||||
if item.ispkg:
|
if item.ispkg:
|
||||||
walk(__import__(module.__name__ + '.' + item.name, fromlist=['']), prefix + item.name + '.')
|
|
||||||
|
walk(__import__(module.__name__ + '.' + item.name, fromlist=['']), prefix + item.name + '.', path_prefix + item.name + '/')
|
||||||
else:
|
else:
|
||||||
logging.info('reload module: {}'.format(prefix + item.name))
|
logging.info('reload module: {}, path: {}'.format(prefix + item.name, path_prefix + item.name + '.py'))
|
||||||
|
pkg.plugin.host.__current_module_path__ = "plugins/" + path_prefix + item.name + '.py'
|
||||||
importlib.reload(__import__(module.__name__ + '.' + item.name, fromlist=['']))
|
importlib.reload(__import__(module.__name__ + '.' + item.name, fromlist=['']))
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -1,8 +1,37 @@
|
|||||||
|
import logging
|
||||||
|
|
||||||
from PIL import Image, ImageDraw, ImageFont
|
from PIL import Image, ImageDraw, ImageFont
|
||||||
import re
|
import re
|
||||||
import os
|
import os
|
||||||
|
import config
|
||||||
|
import traceback
|
||||||
|
|
||||||
text_render_font = ImageFont.truetype("res/simhei.ttf", 32, encoding="utf-8")
|
text_render_font: ImageFont = None
|
||||||
|
|
||||||
|
if hasattr(config, "blob_message_strategy") and config.blob_message_strategy == "image": # 仅在启用了image时才加载字体
|
||||||
|
use_font = config.font_path if hasattr(config, "font_path") else ""
|
||||||
|
try:
|
||||||
|
|
||||||
|
# 检查是否存在
|
||||||
|
if not os.path.exists(use_font):
|
||||||
|
# 若是windows系统,使用微软雅黑
|
||||||
|
if os.name == "nt":
|
||||||
|
use_font = "C:/Windows/Fonts/msyh.ttc"
|
||||||
|
if not os.path.exists(use_font):
|
||||||
|
logging.warn("未找到字体文件,且无法使用Windows自带字体,更换为转发消息组件以发送长消息,您可以在config.py中调整相关设置。")
|
||||||
|
config.blob_message_strategy = "forward"
|
||||||
|
else:
|
||||||
|
logging.info("使用Windows自带字体:" + use_font)
|
||||||
|
text_render_font = ImageFont.truetype(use_font, 32, encoding="utf-8")
|
||||||
|
else:
|
||||||
|
logging.warn("未找到字体文件,且无法使用Windows自带字体,更换为转发消息组件以发送长消息,您可以在config.py中调整相关设置。")
|
||||||
|
config.blob_message_strategy = "forward"
|
||||||
|
else:
|
||||||
|
text_render_font = ImageFont.truetype(use_font, 32, encoding="utf-8")
|
||||||
|
except:
|
||||||
|
traceback.print_exc()
|
||||||
|
logging.error("加载字体文件失败({}),更换为转发消息组件以发送长消息,您可以在config.py中调整相关设置。".format(use_font))
|
||||||
|
config.blob_message_strategy = "forward"
|
||||||
|
|
||||||
|
|
||||||
def indexNumber(path=''):
|
def indexNumber(path=''):
|
||||||
@@ -123,7 +152,7 @@ def text_to_image(text_str: str, save_as="temp.png", width=800):
|
|||||||
else:
|
else:
|
||||||
continue
|
continue
|
||||||
# 准备画布
|
# 准备画布
|
||||||
img = Image.new('RGBA', (width, max(280, len(final_lines) * 35 + 45)), (255, 255, 255, 255))
|
img = Image.new('RGBA', (width, max(280, len(final_lines) * 35 + 65)), (255, 255, 255, 255))
|
||||||
draw = ImageDraw.Draw(img, mode='RGBA')
|
draw = ImageDraw.Draw(img, mode='RGBA')
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -5,7 +5,7 @@ import os.path
|
|||||||
import requests
|
import requests
|
||||||
import json
|
import json
|
||||||
|
|
||||||
import pkg.utils.context
|
import pkg.utils.constants
|
||||||
|
|
||||||
|
|
||||||
def check_dulwich_closure():
|
def check_dulwich_closure():
|
||||||
@@ -46,7 +46,7 @@ def get_release_list() -> list:
|
|||||||
|
|
||||||
def get_current_tag() -> str:
|
def get_current_tag() -> str:
|
||||||
"""获取当前tag"""
|
"""获取当前tag"""
|
||||||
current_tag = "v0.1.0"
|
current_tag = pkg.utils.constants.semantic_version
|
||||||
if os.path.exists("current_tag"):
|
if os.path.exists("current_tag"):
|
||||||
with open("current_tag", "r") as f:
|
with open("current_tag", "r") as f:
|
||||||
current_tag = f.read()
|
current_tag = f.read()
|
||||||
@@ -54,7 +54,7 @@ def get_current_tag() -> str:
|
|||||||
return current_tag
|
return current_tag
|
||||||
|
|
||||||
|
|
||||||
def update_all() -> bool:
|
def update_all(cli: bool = False) -> bool:
|
||||||
"""检查更新并下载源码"""
|
"""检查更新并下载源码"""
|
||||||
current_tag = get_current_tag()
|
current_tag = get_current_tag()
|
||||||
|
|
||||||
@@ -69,12 +69,19 @@ def update_all() -> bool:
|
|||||||
|
|
||||||
if latest_rls == {}:
|
if latest_rls == {}:
|
||||||
latest_rls = rls
|
latest_rls = rls
|
||||||
logging.info("更新日志: {}".format(rls_notes))
|
if not cli:
|
||||||
|
logging.info("更新日志: {}".format(rls_notes))
|
||||||
|
else:
|
||||||
|
print("更新日志: {}".format(rls_notes))
|
||||||
|
|
||||||
if latest_rls == {}: # 没有新版本
|
if latest_rls == {}: # 没有新版本
|
||||||
return False
|
return False
|
||||||
|
|
||||||
# 下载最新版本的zip到temp目录
|
# 下载最新版本的zip到temp目录
|
||||||
logging.info("开始下载最新版本: {}".format(latest_rls['zipball_url']))
|
if not cli:
|
||||||
|
logging.info("开始下载最新版本: {}".format(latest_rls['zipball_url']))
|
||||||
|
else:
|
||||||
|
print("开始下载最新版本: {}".format(latest_rls['zipball_url']))
|
||||||
zip_url = latest_rls['zipball_url']
|
zip_url = latest_rls['zipball_url']
|
||||||
zip_resp = requests.get(url=zip_url)
|
zip_resp = requests.get(url=zip_url)
|
||||||
zip_data = zip_resp.content
|
zip_data = zip_resp.content
|
||||||
@@ -87,7 +94,10 @@ def update_all() -> bool:
|
|||||||
with open("temp/updater/{}.zip".format(latest_rls['tag_name']), "wb") as f:
|
with open("temp/updater/{}.zip".format(latest_rls['tag_name']), "wb") as f:
|
||||||
f.write(zip_data)
|
f.write(zip_data)
|
||||||
|
|
||||||
logging.info("下载最新版本完成: {}".format("temp/updater/{}.zip".format(latest_rls['tag_name'])))
|
if not cli:
|
||||||
|
logging.info("下载最新版本完成: {}".format("temp/updater/{}.zip".format(latest_rls['tag_name'])))
|
||||||
|
else:
|
||||||
|
print("下载最新版本完成: {}".format("temp/updater/{}.zip".format(latest_rls['tag_name'])))
|
||||||
|
|
||||||
# 解压zip到temp/updater/<tag_name>/
|
# 解压zip到temp/updater/<tag_name>/
|
||||||
import zipfile
|
import zipfile
|
||||||
@@ -124,8 +134,11 @@ def update_all() -> bool:
|
|||||||
f.write(current_tag)
|
f.write(current_tag)
|
||||||
|
|
||||||
# 通知管理员
|
# 通知管理员
|
||||||
import pkg.utils.context
|
if not cli:
|
||||||
pkg.utils.context.get_qqbot_manager().notify_admin("已更新到最新版本: {}\n更新日志:\n{}\n新功能通常可以在config-template.py中看到,完整的更新日志请前往 https://github.com/RockChinQ/QChatGPT/releases 查看".format(current_tag, "\n".join(rls_notes)))
|
import pkg.utils.context
|
||||||
|
pkg.utils.context.get_qqbot_manager().notify_admin("已更新到最新版本: {}\n更新日志:\n{}\n新功能通常可以在config-template.py中看到,完整的更新日志请前往 https://github.com/RockChinQ/QChatGPT/releases 查看".format(current_tag, "\n".join(rls_notes)))
|
||||||
|
else:
|
||||||
|
print("已更新到最新版本: {}\n更新日志:\n{}\n新功能通常可以在config-template.py中看到,完整的更新日志请前往 https://github.com/RockChinQ/QChatGPT/releases 查看".format(current_tag, "\n".join(rls_notes)))
|
||||||
return True
|
return True
|
||||||
|
|
||||||
|
|
||||||
@@ -152,24 +165,12 @@ def get_remote_url(repo_path: str) -> str:
|
|||||||
|
|
||||||
def get_current_version_info() -> str:
|
def get_current_version_info() -> str:
|
||||||
"""获取当前版本信息"""
|
"""获取当前版本信息"""
|
||||||
check_dulwich_closure()
|
rls_list = get_release_list()
|
||||||
|
current_tag = get_current_tag()
|
||||||
from dulwich import porcelain
|
for rls in rls_list:
|
||||||
|
if rls['tag_name'] == current_tag:
|
||||||
repo = porcelain.open_repo('.')
|
return rls['name'] + "\n" + rls['body']
|
||||||
|
return "未知版本"
|
||||||
version_str = ""
|
|
||||||
|
|
||||||
for entry in repo.get_walker():
|
|
||||||
version_str += "提交编号: "+str(entry.commit.id)[2:9] + "\n"
|
|
||||||
tz = datetime.timezone(datetime.timedelta(hours=entry.commit.commit_timezone // 3600))
|
|
||||||
dt = datetime.datetime.fromtimestamp(entry.commit.commit_time, tz)
|
|
||||||
version_str += "时间: "+dt.strftime('%m-%d %H:%M:%S') + "\n"
|
|
||||||
version_str += "说明: "+str(entry.commit.message, encoding="utf-8").strip() + "\n"
|
|
||||||
version_str += "提交作者: '" + str(entry.commit.author)[2:-1] + "'"
|
|
||||||
break
|
|
||||||
|
|
||||||
return version_str
|
|
||||||
|
|
||||||
|
|
||||||
def get_commit_id_and_time_and_msg() -> str:
|
def get_commit_id_and_time_and_msg() -> str:
|
||||||
|
|||||||
BIN
res/simhei.ttf
BIN
res/simhei.ttf
Binary file not shown.
Reference in New Issue
Block a user