mirror of
https://github.com/langbot-app/LangBot.git
synced 2026-06-02 03:55:55 +00:00
@@ -72,6 +72,12 @@
|
||||
- 现已支持OpenAI的对话`Completion API`和绘图`Image API`
|
||||
- 向机器人发送指令`!draw <prompt>`即可使用绘图模型
|
||||
</details>
|
||||
<details>
|
||||
<summary>✅支持指令控制热重载、热更新</summary>
|
||||
|
||||
- 允许在运行期间修改`config.py`或其他代码后,以管理员账号向机器人发送指令`!reload`进行热重载,无需重启
|
||||
- 运行期间允许以管理员账号向机器人发送指令`!update`进行热更新,拉取远程最新代码并执行热重载
|
||||
</details>
|
||||
|
||||
## 💻技术栈
|
||||
|
||||
|
||||
@@ -8,6 +8,8 @@ import logging
|
||||
# port: 运行mirai的主机端口
|
||||
# verifyKey: mirai-api-http的verifyKey
|
||||
# qq: 机器人的QQ号
|
||||
#
|
||||
# 注意: QQ机器人配置不支持热重载及热更新
|
||||
mirai_http_api_config = {
|
||||
"adapter": "WebSocketAdapter",
|
||||
"host": "localhost",
|
||||
@@ -30,6 +32,9 @@ openai_config = {
|
||||
},
|
||||
}
|
||||
|
||||
# 管理员QQ号,用于接收报错等通知及执行管理员级别指令,为0时关闭此功能
|
||||
admin_qq = 0
|
||||
|
||||
# 情景预设(机器人人格)
|
||||
# 每个会话的预设信息,影响所有会话,无视指令重置
|
||||
# 可以通过这个字段指定某些情况的回复,可直接用自然语言描述指令
|
||||
@@ -38,9 +43,6 @@ openai_config = {
|
||||
# 可参考 https://github.com/PlexPt/awesome-chatgpt-prompts-zh
|
||||
default_prompt = "如果我之后想获取帮助,请你说“输入!help获取帮助”"
|
||||
|
||||
# 管理员QQ号,用于接收报错等通知,为0时不发送通知
|
||||
admin_qq = 0
|
||||
|
||||
# 群内响应规则
|
||||
# 符合此消息的群内消息即使不包含at机器人也会响应
|
||||
# 支持消息前缀匹配及正则表达式匹配
|
||||
|
||||
59
main.py
59
main.py
@@ -1,3 +1,4 @@
|
||||
import asyncio
|
||||
import os
|
||||
import shutil
|
||||
import sys
|
||||
@@ -7,6 +8,8 @@ import time
|
||||
import logging
|
||||
import colorlog
|
||||
|
||||
from mirai.bot import MiraiRunner
|
||||
|
||||
import sys
|
||||
|
||||
sys.path.append(".")
|
||||
@@ -27,10 +30,18 @@ def init_db():
|
||||
database.initialize_database()
|
||||
|
||||
|
||||
def main():
|
||||
def main(first_time_init=False):
|
||||
# 导入config.py
|
||||
assert os.path.exists('config.py')
|
||||
|
||||
# 检查是否设置了管理员
|
||||
import config
|
||||
if not (hasattr(config, 'admin_qq') and config.admin_qq != 0):
|
||||
logging.warning("未设置管理员QQ,管理员权限指令及运行告警将无法使用,如需设置请修改config.py中的admin_qq字段")
|
||||
|
||||
import pkg.utils.context
|
||||
if pkg.utils.context.context['logger_handler'] is not None:
|
||||
logging.getLogger().removeHandler(pkg.utils.context.context['logger_handler'])
|
||||
|
||||
logging.basicConfig(level=config.logging_level, # 设置日志输出格式
|
||||
filename='qchatgpt.log', # log日志输出的文件位置和文件名
|
||||
@@ -54,6 +65,7 @@ def main():
|
||||
import pkg.openai.session
|
||||
import pkg.qqbot.manager
|
||||
|
||||
pkg.utils.context.context['logger_handler'] = sh
|
||||
# 主启动流程
|
||||
database = pkg.database.manager.DatabaseManager()
|
||||
|
||||
@@ -66,29 +78,46 @@ def main():
|
||||
|
||||
# 初始化qq机器人
|
||||
qqbot = pkg.qqbot.manager.QQBotManager(mirai_http_api_config=config.mirai_http_api_config,
|
||||
timeout=config.process_message_timeout, retry=config.retry_times)
|
||||
timeout=config.process_message_timeout, retry=config.retry_times,
|
||||
first_time_init=first_time_init)
|
||||
|
||||
qq_bot_thread = threading.Thread(target=qqbot.bot.run, args=(), daemon=True)
|
||||
qq_bot_thread.start()
|
||||
if first_time_init: # 不是热重载之后的启动,则不启动新的bot线程
|
||||
qq_bot_thread = threading.Thread(target=qqbot.bot.run, args=(), daemon=True)
|
||||
qq_bot_thread.start()
|
||||
|
||||
logging.info('程序启动完成')
|
||||
time.sleep(2)
|
||||
logging.info('程序启动完成,如长时间未显示 ”成功登录到账号xxxxx“ ,并且不回复消息,请查看 https://github.com/RockChinQ/QChatGPT/issues/37')
|
||||
|
||||
while True:
|
||||
try:
|
||||
time.sleep(86400)
|
||||
time.sleep(10000)
|
||||
if qqbot != pkg.utils.context.get_qqbot_manager(): # 已经reload了
|
||||
logging.info("以前的main流程由于reload退出")
|
||||
break
|
||||
except KeyboardInterrupt:
|
||||
try:
|
||||
pkg.openai.manager.get_inst().key_mgr.dump_fee()
|
||||
for session in pkg.openai.session.sessions:
|
||||
logging.info('持久化session: %s', session)
|
||||
pkg.openai.session.sessions[session].persistence()
|
||||
except Exception as e:
|
||||
if not isinstance(e, KeyboardInterrupt):
|
||||
raise e
|
||||
stop()
|
||||
|
||||
print("程序退出")
|
||||
sys.exit(0)
|
||||
|
||||
|
||||
def stop():
|
||||
import pkg.utils.context
|
||||
import pkg.qqbot.manager
|
||||
import pkg.openai.session
|
||||
try:
|
||||
qqbot_inst = pkg.utils.context.get_qqbot_manager()
|
||||
assert isinstance(qqbot_inst, pkg.qqbot.manager.QQBotManager)
|
||||
|
||||
pkg.utils.context.get_openai_manager().key_mgr.dump_fee()
|
||||
for session in pkg.openai.session.sessions:
|
||||
logging.info('持久化session: %s', session)
|
||||
pkg.openai.session.sessions[session].persistence()
|
||||
except Exception as e:
|
||||
if not isinstance(e, KeyboardInterrupt):
|
||||
raise e
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
# 检查是否有config.py,如果没有就把config-template.py复制一份,并退出程序
|
||||
if not os.path.exists('config.py'):
|
||||
@@ -109,4 +138,4 @@ if __name__ == '__main__':
|
||||
print("dulwich模块未安装,请查看 https://github.com/RockChinQ/QChatGPT/issues/77")
|
||||
sys.exit(0)
|
||||
|
||||
main()
|
||||
main(True)
|
||||
|
||||
@@ -6,8 +6,7 @@ from sqlite3 import Cursor
|
||||
import sqlite3
|
||||
|
||||
import config
|
||||
|
||||
inst = None
|
||||
import pkg.utils.context
|
||||
|
||||
|
||||
# 数据库管理
|
||||
@@ -20,8 +19,7 @@ class DatabaseManager:
|
||||
|
||||
self.reconnect()
|
||||
|
||||
global inst
|
||||
inst = self
|
||||
pkg.utils.context.set_database_manager(self)
|
||||
|
||||
# 连接到数据库文件
|
||||
def reconnect(self):
|
||||
@@ -312,6 +310,3 @@ class DatabaseManager:
|
||||
fee[key_md5] = fee_count
|
||||
return fee
|
||||
|
||||
def get_inst() -> DatabaseManager:
|
||||
global inst
|
||||
return inst
|
||||
|
||||
@@ -4,6 +4,7 @@ import logging
|
||||
|
||||
import pkg.database.manager
|
||||
import pkg.qqbot.manager
|
||||
import pkg.utils.context
|
||||
import config
|
||||
|
||||
|
||||
@@ -62,43 +63,6 @@ class KeysManager:
|
||||
def add(self, key_name, key):
|
||||
self.api_key[key_name] = key
|
||||
|
||||
# def get_usage(self, api_key):
|
||||
# md5 = hashlib.md5(api_key.encode('utf-8')).hexdigest()
|
||||
# if md5 not in self.usage:
|
||||
# self.usage[md5] = 0
|
||||
# return self.usage[md5]
|
||||
|
||||
# 报告使用
|
||||
# 返回是否需要将openai的api-key切换
|
||||
# def report_usage(self, new_content: str) -> bool:
|
||||
# md5 = hashlib.md5(self.using_key.encode('utf-8')).hexdigest()
|
||||
# if md5 not in self.usage:
|
||||
# self.usage[md5] = 0
|
||||
#
|
||||
# # 经测算得出的理论与实际的偏差比例
|
||||
# salt_rate = 0.91
|
||||
#
|
||||
# self.usage[md5] += ( (len(new_content.encode('utf-8')) - len(new_content)) / 2 + len(new_content) )*salt_rate
|
||||
#
|
||||
# self.usage[md5] = int(self.usage[md5])
|
||||
#
|
||||
# if self.usage[md5] >= self.api_key_usage_threshold:
|
||||
# switch_result, key_name = self.auto_switch()
|
||||
#
|
||||
# # 检查是否切换到新的
|
||||
# if switch_result:
|
||||
# if key_name not in self.alerted:
|
||||
# # 通知管理员
|
||||
# pkg.qqbot.manager.get_inst().notify_admin("api-key已切换到:" + key_name)
|
||||
# self.alerted.append(key_name)
|
||||
# return True
|
||||
# else:
|
||||
# if key_name not in self.alerted:
|
||||
# # 通知管理员
|
||||
# pkg.qqbot.manager.get_inst().notify_admin("api-key已用完,无未使用的api-key可供切换")
|
||||
# self.alerted.append(key_name)
|
||||
# return False
|
||||
|
||||
# 设置当前使用的api-key使用量超限
|
||||
# 这是在尝试调用api时发生超限异常时调用的
|
||||
def set_current_exceeded(self):
|
||||
@@ -107,14 +71,6 @@ class KeysManager:
|
||||
self.fee[md5] = self.api_key_fee_threshold
|
||||
self.dump_fee()
|
||||
|
||||
# def dump_usage(self):
|
||||
# pkg.database.manager.get_inst().dump_api_key_usage(api_keys=self.api_key, usage=self.usage)
|
||||
|
||||
# def load_usage(self):
|
||||
# self.usage = pkg.database.manager.get_inst().load_api_key_usage()
|
||||
# logging.debug("load usage:" + str(self.usage))
|
||||
# print("load usage:" + str(self.usage))
|
||||
|
||||
def get_fee(self, api_key):
|
||||
md5 = hashlib.md5(api_key.encode('utf-8')).hexdigest()
|
||||
if md5 not in self.fee:
|
||||
@@ -135,19 +91,19 @@ class KeysManager:
|
||||
if switch_result:
|
||||
if key_name not in self.alerted:
|
||||
# 通知管理员
|
||||
pkg.qqbot.manager.get_inst().notify_admin("api-key已切换到:" + key_name)
|
||||
pkg.utils.context.get_qqbot_manager().notify_admin("api-key已切换到:" + key_name)
|
||||
self.alerted.append(key_name)
|
||||
return True
|
||||
else:
|
||||
if key_name not in self.alerted:
|
||||
# 通知管理员
|
||||
pkg.qqbot.manager.get_inst().notify_admin("api-key已用完,无未使用的api-key可供切换")
|
||||
pkg.utils.context.get_qqbot_manager().notify_admin("api-key已用完,无未使用的api-key可供切换")
|
||||
self.alerted.append(key_name)
|
||||
return False
|
||||
|
||||
def dump_fee(self):
|
||||
pkg.database.manager.get_inst().dump_api_key_fee(api_keys=self.api_key, fee=self.fee)
|
||||
pkg.utils.context.get_database_manager().dump_api_key_fee(api_keys=self.api_key, fee=self.fee)
|
||||
|
||||
def load_fee(self):
|
||||
self.fee = pkg.database.manager.get_inst().load_api_key_fee()
|
||||
logging.info("load fee:" + str(self.fee))
|
||||
self.fee = pkg.utils.context.get_database_manager().load_api_key_fee()
|
||||
logging.info("load fee:" + str(self.fee))
|
||||
|
||||
@@ -6,8 +6,7 @@ import config
|
||||
|
||||
import pkg.openai.keymgr
|
||||
import pkg.openai.pricing as pricing
|
||||
|
||||
inst = None
|
||||
import pkg.utils.context
|
||||
|
||||
|
||||
# 为其他模块提供与OpenAI交互的接口
|
||||
@@ -27,11 +26,11 @@ class OpenAIInteract:
|
||||
|
||||
openai.api_key = self.key_mgr.get_using_key()
|
||||
|
||||
global inst
|
||||
inst = self
|
||||
pkg.utils.context.set_openai_manager(self)
|
||||
|
||||
# 请求OpenAI Completion
|
||||
def request_completion(self, prompt, stop):
|
||||
# print("request")
|
||||
response = openai.Completion.create(
|
||||
prompt=prompt,
|
||||
stop=stop,
|
||||
@@ -41,7 +40,6 @@ class OpenAIInteract:
|
||||
|
||||
switched = self.key_mgr.report_fee(pricing.language_base_price(config.completion_api_params['model'],
|
||||
prompt + response['choices'][0]['text']))
|
||||
|
||||
if switched:
|
||||
openai.api_key = self.key_mgr.get_using_key()
|
||||
|
||||
@@ -64,7 +62,3 @@ class OpenAIInteract:
|
||||
|
||||
return response
|
||||
|
||||
|
||||
def get_inst() -> OpenAIInteract:
|
||||
global inst
|
||||
return inst
|
||||
|
||||
34
pkg/openai/modelmgr.py
Normal file
34
pkg/openai/modelmgr.py
Normal file
@@ -0,0 +1,34 @@
|
||||
# 提供与模型交互的抽象接口
|
||||
|
||||
COMPLETION_MODELS = {
|
||||
'text-davinci-003'
|
||||
}
|
||||
|
||||
EDIT_MODELS = {
|
||||
|
||||
}
|
||||
|
||||
IMAGE_MODELS = {
|
||||
|
||||
}
|
||||
|
||||
|
||||
# ModelManager
|
||||
# 由session包含
|
||||
class ModelMgr(object):
|
||||
|
||||
using_completion_model = ""
|
||||
using_edit_model = ""
|
||||
using_image_model = ""
|
||||
|
||||
def __init__(self):
|
||||
pass
|
||||
|
||||
def get_using_completion_model(self):
|
||||
return self.using_completion_model
|
||||
|
||||
def get_using_edit_model(self):
|
||||
return self.using_edit_model
|
||||
|
||||
def get_using_image_model(self):
|
||||
return self.using_image_model
|
||||
@@ -5,6 +5,7 @@ import time
|
||||
import config
|
||||
import pkg.openai.manager
|
||||
import pkg.database.manager
|
||||
import pkg.utils.context
|
||||
|
||||
# 运行时保存的所有session
|
||||
sessions = {}
|
||||
@@ -19,7 +20,7 @@ class SessionOfflineStatus:
|
||||
def load_sessions():
|
||||
global sessions
|
||||
|
||||
db_inst = pkg.database.manager.get_inst()
|
||||
db_inst = pkg.utils.context.get_database_manager()
|
||||
|
||||
session_data = db_inst.load_valid_sessions()
|
||||
|
||||
@@ -147,10 +148,11 @@ class Session:
|
||||
max_length = config.prompt_submit_length if hasattr(config, "prompt_submit_length") else 1024
|
||||
|
||||
# 向API请求补全
|
||||
response = pkg.openai.manager.get_inst().request_completion(self.cut_out(self.prompt + self.user_name + ':' +
|
||||
text + '\n' + self.bot_name + ':',
|
||||
max_rounds, max_length),
|
||||
self.user_name + ':')
|
||||
response = pkg.utils.context.get_openai_manager().request_completion(
|
||||
self.cut_out(self.prompt + self.user_name + ':' +
|
||||
text + '\n' + self.bot_name + ':',
|
||||
max_rounds, max_length),
|
||||
self.user_name + ':')
|
||||
|
||||
self.prompt += self.user_name + ':' + text + '\n' + self.bot_name + ':'
|
||||
# print(response)
|
||||
@@ -202,7 +204,7 @@ class Session:
|
||||
if self.prompt == get_default_prompt():
|
||||
return
|
||||
|
||||
db_inst = pkg.database.manager.get_inst()
|
||||
db_inst = pkg.utils.context.get_database_manager()
|
||||
|
||||
name_spt = self.name.split('_')
|
||||
|
||||
@@ -217,10 +219,10 @@ class Session:
|
||||
if self.prompt != get_default_prompt():
|
||||
self.persistence()
|
||||
if explicit:
|
||||
pkg.database.manager.get_inst().explicit_close_session(self.name, self.create_timestamp)
|
||||
pkg.utils.context.get_database_manager().explicit_close_session(self.name, self.create_timestamp)
|
||||
|
||||
if expired:
|
||||
pkg.database.manager.get_inst().set_session_expired(self.name, self.create_timestamp)
|
||||
pkg.utils.context.get_database_manager().set_session_expired(self.name, self.create_timestamp)
|
||||
self.prompt = get_default_prompt()
|
||||
self.create_timestamp = int(time.time())
|
||||
self.last_interact_timestamp = int(time.time())
|
||||
@@ -233,11 +235,11 @@ class Session:
|
||||
|
||||
# 将本session的数据库状态设置为on_going
|
||||
def set_ongoing(self):
|
||||
pkg.database.manager.get_inst().set_session_ongoing(self.name, self.create_timestamp)
|
||||
pkg.utils.context.get_database_manager().set_session_ongoing(self.name, self.create_timestamp)
|
||||
|
||||
# 切换到上一个session
|
||||
def last_session(self):
|
||||
last_one = pkg.database.manager.get_inst().last_session(self.name, self.last_interact_timestamp)
|
||||
last_one = pkg.utils.context.get_database_manager().last_session(self.name, self.last_interact_timestamp)
|
||||
if last_one is None:
|
||||
return None
|
||||
else:
|
||||
@@ -252,7 +254,7 @@ class Session:
|
||||
|
||||
# 切换到下一个session
|
||||
def next_session(self):
|
||||
next_one = pkg.database.manager.get_inst().next_session(self.name, self.last_interact_timestamp)
|
||||
next_one = pkg.utils.context.get_database_manager().next_session(self.name, self.last_interact_timestamp)
|
||||
if next_one is None:
|
||||
return None
|
||||
else:
|
||||
@@ -266,8 +268,8 @@ class Session:
|
||||
return self
|
||||
|
||||
def list_history(self, capacity: int = 10, page: int = 0):
|
||||
return pkg.database.manager.get_inst().list_history(self.name, capacity, page,
|
||||
get_default_prompt())
|
||||
return pkg.utils.context.get_database_manager().list_history(self.name, capacity, page,
|
||||
get_default_prompt())
|
||||
|
||||
def draw_image(self, prompt: str):
|
||||
return pkg.openai.manager.get_inst().request_image(prompt)
|
||||
return pkg.utils.context.get_openai_manager().request_image(prompt)
|
||||
|
||||
@@ -3,10 +3,13 @@ import json
|
||||
import os
|
||||
import threading
|
||||
|
||||
import mirai.models.bus
|
||||
import openai.error
|
||||
from mirai import At, GroupMessage, MessageEvent, Mirai, Plain, StrangerMessage, WebSocketAdapter, HTTPAdapter, \
|
||||
FriendMessage, Image
|
||||
|
||||
from mirai.models.bus import ModelEventBus
|
||||
|
||||
from mirai.models.message import Quote
|
||||
|
||||
import config
|
||||
@@ -17,8 +20,7 @@ import logging
|
||||
|
||||
import pkg.qqbot.filter
|
||||
import pkg.qqbot.process as processor
|
||||
|
||||
inst = None
|
||||
import pkg.utils.context
|
||||
|
||||
|
||||
# 并行运行
|
||||
@@ -58,7 +60,7 @@ class QQBotManager:
|
||||
|
||||
reply_filter = None
|
||||
|
||||
def __init__(self, mirai_http_api_config: dict, timeout: int = 60, retry: int = 3):
|
||||
def __init__(self, mirai_http_api_config: dict, timeout: int = 60, retry: int = 3, first_time_init=True):
|
||||
|
||||
self.timeout = timeout
|
||||
self.retry = retry
|
||||
@@ -71,6 +73,47 @@ class QQBotManager:
|
||||
else:
|
||||
self.reply_filter = pkg.qqbot.filter.ReplyFilter([])
|
||||
|
||||
# 由于YiriMirai的bot对象是单例的,且shutdown方法暂时无法使用
|
||||
# 故只在第一次初始化时创建bot对象,重载之后使用原bot对象
|
||||
# 因此,bot的配置不支持热重载
|
||||
if first_time_init:
|
||||
self.first_time_init(mirai_http_api_config)
|
||||
else:
|
||||
self.bot = pkg.utils.context.get_qqbot_manager().bot
|
||||
|
||||
pkg.utils.context.set_qqbot_manager(self)
|
||||
|
||||
# Caution: 注册新的事件处理器之后,请务必在unsubscribe_all中编写相应的取消订阅代码
|
||||
@self.bot.on(FriendMessage)
|
||||
async def on_friend_message(event: FriendMessage):
|
||||
go(self.on_person_message, (event,))
|
||||
|
||||
@self.bot.on(StrangerMessage)
|
||||
async def on_stranger_message(event: StrangerMessage):
|
||||
go(self.on_person_message, (event,))
|
||||
|
||||
@self.bot.on(GroupMessage)
|
||||
async def on_group_message(event: GroupMessage):
|
||||
go(self.on_group_message, (event,))
|
||||
|
||||
def unsubscribe_all():
|
||||
"""取消所有订阅
|
||||
|
||||
用于在热重载流程中卸载所有事件处理器
|
||||
"""
|
||||
assert isinstance(self.bot, Mirai)
|
||||
bus = self.bot.bus
|
||||
assert isinstance(bus, mirai.models.bus.ModelEventBus)
|
||||
|
||||
bus.unsubscribe(FriendMessage, on_friend_message)
|
||||
bus.unsubscribe(StrangerMessage, on_stranger_message)
|
||||
bus.unsubscribe(GroupMessage, on_group_message)
|
||||
|
||||
self.unsubscribe_all = unsubscribe_all
|
||||
|
||||
def first_time_init(self, mirai_http_api_config: dict):
|
||||
"""热重载后不再运行此函数"""
|
||||
|
||||
if 'adapter' not in mirai_http_api_config or mirai_http_api_config['adapter'] == "WebSocketAdapter":
|
||||
bot = Mirai(
|
||||
qq=mirai_http_api_config['qq'],
|
||||
@@ -93,23 +136,8 @@ class QQBotManager:
|
||||
else:
|
||||
raise Exception("未知的适配器类型")
|
||||
|
||||
@bot.on(FriendMessage)
|
||||
async def on_friend_message(event: FriendMessage):
|
||||
go(self.on_person_message, (event,))
|
||||
|
||||
@bot.on(StrangerMessage)
|
||||
async def on_stranger_message(event: StrangerMessage):
|
||||
go(self.on_person_message, (event,))
|
||||
|
||||
@bot.on(GroupMessage)
|
||||
async def on_group_message(event: GroupMessage):
|
||||
go(self.on_group_message, (event,))
|
||||
|
||||
self.bot = bot
|
||||
|
||||
global inst
|
||||
inst = self
|
||||
|
||||
def send(self, event, msg, check_quote=True):
|
||||
asyncio.run(
|
||||
self.bot.send(event, msg, quote=True if hasattr(config,
|
||||
@@ -117,7 +145,6 @@ class QQBotManager:
|
||||
|
||||
# 私聊消息处理
|
||||
def on_person_message(self, event: MessageEvent):
|
||||
|
||||
reply = ''
|
||||
|
||||
if event.sender.id == self.bot.qq:
|
||||
@@ -167,11 +194,13 @@ class QQBotManager:
|
||||
event.sender.id)
|
||||
break
|
||||
except FunctionTimedOut:
|
||||
pkg.openai.session.get_session('group_{}'.format(event.group.id)).release_response_lock()
|
||||
failed += 1
|
||||
continue
|
||||
|
||||
if failed == self.retry:
|
||||
self.notify_admin("{} 请求超时".format("group_{}".format(event.sender.id)))
|
||||
pkg.openai.session.get_session('group_{}'.format(event.group.id)).release_response_lock()
|
||||
self.notify_admin("{} 请求超时".format("group_{}".format(event.group.id)))
|
||||
replys = ["[bot]err:请求超时"]
|
||||
|
||||
return replys
|
||||
@@ -196,8 +225,3 @@ class QQBotManager:
|
||||
logging.info("通知管理员:{}".format(message))
|
||||
send_task = self.bot.send_friend_message(config.admin_qq, "[bot]{}".format(message))
|
||||
threading.Thread(target=asyncio.run, args=(send_task,)).start()
|
||||
|
||||
|
||||
def get_inst() -> QQBotManager:
|
||||
global inst
|
||||
return inst
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
# 此模块提供了消息处理的具体逻辑的接口
|
||||
import asyncio
|
||||
import datetime
|
||||
import threading
|
||||
|
||||
import pkg.qqbot.manager as manager
|
||||
from func_timeout import func_set_timeout
|
||||
@@ -8,12 +9,14 @@ import logging
|
||||
import openai
|
||||
|
||||
from mirai import Image, MessageChain
|
||||
from mirai.models.message import Quote
|
||||
|
||||
import config
|
||||
|
||||
import pkg.openai.session
|
||||
import pkg.openai.manager
|
||||
import pkg.utils.reloader
|
||||
import pkg.utils.updater
|
||||
import pkg.utils.context
|
||||
|
||||
processing = []
|
||||
|
||||
@@ -23,7 +26,7 @@ def process_message(launcher_type: str, launcher_id: int, text_message: str, mes
|
||||
sender_id: int) -> MessageChain:
|
||||
global processing
|
||||
|
||||
mgr = pkg.qqbot.manager.get_inst()
|
||||
mgr = pkg.utils.context.get_qqbot_manager()
|
||||
|
||||
reply = []
|
||||
session_name = "{}_{}".format(launcher_type, launcher_id)
|
||||
@@ -123,22 +126,22 @@ def process_message(launcher_type: str, launcher_id: int, text_message: str, mes
|
||||
|
||||
reply = [reply_str]
|
||||
elif cmd == 'usage':
|
||||
api_keys = pkg.openai.manager.get_inst().key_mgr.api_key
|
||||
api_keys = pkg.utils.context.get_openai_manager().key_mgr.api_key
|
||||
reply_str = "[bot]api-key使用情况:(阈值:{})\n\n".format(
|
||||
pkg.openai.manager.get_inst().key_mgr.api_key_fee_threshold)
|
||||
pkg.utils.context.get_openai_manager().key_mgr.api_key_fee_threshold)
|
||||
|
||||
using_key_name = ""
|
||||
for api_key in api_keys:
|
||||
reply_str += "{}:\n - {}美元 {}%\n".format(api_key,
|
||||
round(
|
||||
pkg.openai.manager.get_inst().key_mgr.get_fee(
|
||||
pkg.utils.context.get_openai_manager().key_mgr.get_fee(
|
||||
api_keys[api_key]), 6),
|
||||
round(
|
||||
pkg.openai.manager.get_inst().key_mgr.get_fee(
|
||||
pkg.utils.context.get_openai_manager().key_mgr.get_fee(
|
||||
api_keys[
|
||||
api_key]) / pkg.openai.manager.get_inst().key_mgr.api_key_fee_threshold * 100,
|
||||
api_key]) / pkg.utils.context.get_openai_manager().key_mgr.api_key_fee_threshold * 100,
|
||||
3))
|
||||
if api_keys[api_key] == pkg.openai.manager.get_inst().key_mgr.using_key:
|
||||
if api_keys[api_key] == pkg.utils.context.get_openai_manager().key_mgr.using_key:
|
||||
using_key_name = api_key
|
||||
reply_str += "\n当前使用:{}".format(using_key_name)
|
||||
|
||||
@@ -157,6 +160,23 @@ def process_message(launcher_type: str, launcher_id: int, text_message: str, mes
|
||||
if not (hasattr(config, 'include_image_description')
|
||||
and not config.include_image_description):
|
||||
reply.append(" ".join(params))
|
||||
elif cmd == 'reload' and launcher_type == 'person' and launcher_id == config.admin_qq:
|
||||
def reload_task():
|
||||
pkg.utils.reloader.reload_all()
|
||||
|
||||
threading.Thread(target=reload_task, daemon=True).start()
|
||||
elif cmd == 'update' and launcher_type == 'person' and launcher_id == config.admin_qq:
|
||||
def update_task():
|
||||
try:
|
||||
pkg.utils.updater.update_all()
|
||||
except Exception as e0:
|
||||
pkg.utils.context.get_qqbot_manager().notify_admin("更新失败:{}".format(e0))
|
||||
return
|
||||
pkg.utils.reloader.reload_all()
|
||||
|
||||
threading.Thread(target=update_task, daemon=True).start()
|
||||
else:
|
||||
reply = ["[bot]err:未知的指令或权限不足: "+cmd]
|
||||
except Exception as e:
|
||||
mgr.notify_admin("{}指令执行失败:{}".format(session_name, e))
|
||||
logging.exception(e)
|
||||
@@ -174,17 +194,17 @@ def process_message(launcher_type: str, launcher_id: int, text_message: str, mes
|
||||
reply = ["[bot]err:调用API失败,请重试或联系作者,或等待修复"]
|
||||
except openai.error.RateLimitError as e:
|
||||
# 尝试切换api-key
|
||||
current_tokens_amt = pkg.openai.manager.get_inst().key_mgr.get_fee(
|
||||
pkg.openai.manager.get_inst().key_mgr.get_using_key())
|
||||
pkg.openai.manager.get_inst().key_mgr.set_current_exceeded()
|
||||
switched, name = pkg.openai.manager.get_inst().key_mgr.auto_switch()
|
||||
current_tokens_amt = pkg.utils.context.get_openai_manager().key_mgr.get_fee(
|
||||
pkg.utils.context.get_openai_manager().key_mgr.get_using_key())
|
||||
pkg.utils.context.get_openai_manager().key_mgr.set_current_exceeded()
|
||||
switched, name = pkg.utils.context.get_openai_manager().key_mgr.auto_switch()
|
||||
|
||||
if not switched:
|
||||
mgr.notify_admin("API调用额度超限({}),请向OpenAI账户充值或在config.py中更换api_key".format(
|
||||
current_tokens_amt))
|
||||
reply = ["[bot]err:API调用额度超额,请联系作者,或等待修复"]
|
||||
else:
|
||||
openai.api_key = pkg.openai.manager.get_inst().key_mgr.get_using_key()
|
||||
openai.api_key = pkg.utils.context.get_openai_manager().key_mgr.get_using_key()
|
||||
mgr.notify_admin("API调用额度超限({}),已切换到{}".format(current_tokens_amt, name))
|
||||
reply = ["[bot]err:API调用额度超额,已自动切换,请重新发送消息"]
|
||||
except openai.error.InvalidRequestError as e:
|
||||
|
||||
0
pkg/utils/__init__.py
Normal file
0
pkg/utils/__init__.py
Normal file
32
pkg/utils/context.py
Normal file
32
pkg/utils/context.py
Normal file
@@ -0,0 +1,32 @@
|
||||
context = {
|
||||
'inst': {
|
||||
'database.manager.DatabaseManager': None,
|
||||
'openai.manager.OpenAIInteract': None,
|
||||
'qqbot.manager.QQBotManager': None,
|
||||
},
|
||||
'logger_handler': None,
|
||||
}
|
||||
|
||||
|
||||
def set_database_manager(inst):
|
||||
context['inst']['database.manager.DatabaseManager'] = inst
|
||||
|
||||
|
||||
def get_database_manager():
|
||||
return context['inst']['database.manager.DatabaseManager']
|
||||
|
||||
|
||||
def set_openai_manager(inst):
|
||||
context['inst']['openai.manager.OpenAIInteract'] = inst
|
||||
|
||||
|
||||
def get_openai_manager():
|
||||
return context['inst']['openai.manager.OpenAIInteract']
|
||||
|
||||
|
||||
def set_qqbot_manager(inst):
|
||||
context['inst']['qqbot.manager.QQBotManager'] = inst
|
||||
|
||||
|
||||
def get_qqbot_manager():
|
||||
return context['inst']['qqbot.manager.QQBotManager']
|
||||
45
pkg/utils/reloader.py
Normal file
45
pkg/utils/reloader.py
Normal file
@@ -0,0 +1,45 @@
|
||||
import logging
|
||||
import os
|
||||
import threading
|
||||
|
||||
import colorlog
|
||||
|
||||
import pkg
|
||||
import importlib
|
||||
import pkgutil
|
||||
import pkg.utils.context
|
||||
from main import log_colors_config
|
||||
|
||||
|
||||
def walk(module, prefix=''):
|
||||
"""遍历并重载所有模块"""
|
||||
for item in pkgutil.iter_modules(module.__path__):
|
||||
if item.ispkg:
|
||||
walk(__import__(module.__name__ + '.' + item.name, fromlist=['']), prefix + item.name + '.')
|
||||
else:
|
||||
logging.info('reload module: {}'.format(prefix + item.name))
|
||||
importlib.reload(__import__(module.__name__ + '.' + item.name, fromlist=['']))
|
||||
|
||||
|
||||
def reload_all():
|
||||
# 解除bot的事件注册
|
||||
import pkg
|
||||
pkg.utils.context.get_qqbot_manager().unsubscribe_all()
|
||||
# 执行关闭流程
|
||||
logging.info("执行程序关闭流程")
|
||||
import main
|
||||
main.stop()
|
||||
|
||||
# 重载所有模块
|
||||
context = pkg.utils.context.context
|
||||
walk(pkg)
|
||||
importlib.reload(__import__('config'))
|
||||
importlib.reload(__import__('main'))
|
||||
pkg.utils.context.context = context
|
||||
|
||||
# 执行启动流程
|
||||
logging.info("执行程序启动流程")
|
||||
threading.Thread(target=main.main, args=(False,), daemon=False).start()
|
||||
|
||||
logging.info('程序启动完成')
|
||||
pkg.utils.context.get_qqbot_manager().notify_admin("重载完成")
|
||||
13
pkg/utils/updater.py
Normal file
13
pkg/utils/updater.py
Normal file
@@ -0,0 +1,13 @@
|
||||
import dulwich.porcelain
|
||||
|
||||
|
||||
def update_all():
|
||||
"""使用dulwich更新源码"""
|
||||
try:
|
||||
from dulwich import porcelain
|
||||
repo = porcelain.open_repo('.')
|
||||
porcelain.pull(repo)
|
||||
except ModuleNotFoundError:
|
||||
raise Exception("dulwich模块未安装,请查看 https://github.com/RockChinQ/QChatGPT/issues/77")
|
||||
except dulwich.porcelain.DivergedBranches:
|
||||
raise Exception("分支不一致,自动更新仅支持master分支,请手动更新(https://github.com/RockChinQ/QChatGPT/issues/76)")
|
||||
Reference in New Issue
Block a user