Merge branch 'master' into multi-threads-control

This commit is contained in:
Rock Chin
2023-03-23 21:43:41 +08:00
committed by GitHub
39 changed files with 1361 additions and 436 deletions

View File

@@ -46,7 +46,7 @@ class DataGatherer:
config = pkg.utils.context.get_config()
if hasattr(config, "report_usage") and not config.report_usage:
return
res = requests.get("http://rockchin.top:18989/usage?service_name=qchatgpt.{}&version={}&count={}".format(subservice_name, self.version_str, count))
res = requests.get("http://reports.rockchin.top:18989/usage?service_name=qchatgpt.{}&version={}&count={}".format(subservice_name, self.version_str, count))
if res.status_code != 200 or res.text != "ok":
logging.warning("report to server failed, status_code: {}, text: {}".format(res.status_code, res.text))
except:

View File

@@ -35,6 +35,7 @@ class DatabaseManager:
def __execute__(self, *args, **kwargs) -> Cursor:
# logging.debug('SQL: {}'.format(sql))
logging.debug('SQL: {}'.format(args))
c = self.cursor.execute(*args, **kwargs)
self.conn.commit()
return c
@@ -52,10 +53,30 @@ class DatabaseManager:
`create_timestamp` bigint not null,
`last_interact_timestamp` bigint not null,
`status` varchar(255) not null default 'on_going',
`prompt` text not null
`default_prompt` text not null default '',
`prompt` text not null,
`token_counts` text not null default '[]'
)
""")
# 检查sessions表是否存在`default_prompt`字段, 检查是否存在`token_counts`字段
self.__execute__("PRAGMA table_info('sessions')")
columns = self.cursor.fetchall()
has_default_prompt = False
has_token_counts = False
for field in columns:
if field[1] == 'default_prompt':
has_default_prompt = True
if field[1] == 'token_counts':
has_token_counts = True
if has_default_prompt and has_token_counts:
break
if not has_default_prompt:
self.__execute__("alter table `sessions` add column `default_prompt` text not null default ''")
if not has_token_counts:
self.__execute__("alter table `sessions` add column `token_counts` text not null default '[]'")
self.__execute__("""
create table if not exists `account_fee`(
`id` INTEGER PRIMARY KEY AUTOINCREMENT,
@@ -75,7 +96,7 @@ class DatabaseManager:
# session持久化
def persistence_session(self, subject_type: str, subject_number: int, create_timestamp: int,
last_interact_timestamp: int, prompt: str):
last_interact_timestamp: int, prompt: str, default_prompt: str = '', token_counts: str = ''):
"""持久化指定session"""
# 检查是否已经有了此name和create_timestamp的session
@@ -88,20 +109,20 @@ class DatabaseManager:
if count == 0:
sql = """
insert into `sessions` (`name`, `type`, `number`, `create_timestamp`, `last_interact_timestamp`, `prompt`)
values (?, ?, ?, ?, ?, ?)
insert into `sessions` (`name`, `type`, `number`, `create_timestamp`, `last_interact_timestamp`, `prompt`, `default_prompt`, `token_counts`)
values (?, ?, ?, ?, ?, ?, ?, ?)
"""
self.__execute__(sql,
("{}_{}".format(subject_type, subject_number), subject_type, subject_number, create_timestamp,
last_interact_timestamp, prompt))
last_interact_timestamp, prompt, default_prompt, token_counts))
else:
sql = """
update `sessions` set `last_interact_timestamp` = ?, `prompt` = ?
update `sessions` set `last_interact_timestamp` = ?, `prompt` = ?, `token_counts` = ?
where `type` = ? and `number` = ? and `create_timestamp` = ?
"""
self.__execute__(sql, (last_interact_timestamp, prompt, subject_type,
self.__execute__(sql, (last_interact_timestamp, prompt, token_counts, subject_type,
subject_number, create_timestamp))
# 显式关闭一个session
@@ -126,7 +147,7 @@ class DatabaseManager:
# 从数据库中加载所有还没过期的session
config = pkg.utils.context.get_config()
self.__execute__("""
select `name`, `type`, `number`, `create_timestamp`, `last_interact_timestamp`, `prompt`, `status`
select `name`, `type`, `number`, `create_timestamp`, `last_interact_timestamp`, `prompt`, `status`, `default_prompt`, `token_counts`
from `sessions` where `last_interact_timestamp` > {}
""".format(int(time.time()) - config.session_expire_time))
results = self.cursor.fetchall()
@@ -139,6 +160,8 @@ class DatabaseManager:
last_interact_timestamp = result[4]
prompt = result[5]
status = result[6]
default_prompt = result[7]
token_counts = result[8]
# 当且仅当最后一个该对象的会话是on_going状态时才会被加载
if status == 'on_going':
@@ -147,7 +170,9 @@ class DatabaseManager:
'subject_number': subject_number,
'create_timestamp': create_timestamp,
'last_interact_timestamp': last_interact_timestamp,
'prompt': prompt
'prompt': prompt,
'default_prompt': default_prompt,
'token_counts': token_counts
}
else:
if session_name in sessions:
@@ -159,7 +184,7 @@ class DatabaseManager:
def last_session(self, session_name: str, cursor_timestamp: int):
self.__execute__("""
select `name`, `type`, `number`, `create_timestamp`, `last_interact_timestamp`, `prompt`, `status`
select `name`, `type`, `number`, `create_timestamp`, `last_interact_timestamp`, `prompt`, `status`, `default_prompt`, `token_counts`
from `sessions` where `name` = '{}' and `last_interact_timestamp` < {} order by `last_interact_timestamp` desc
limit 1
""".format(session_name, cursor_timestamp))
@@ -175,20 +200,24 @@ class DatabaseManager:
last_interact_timestamp = result[4]
prompt = result[5]
status = result[6]
default_prompt = result[7]
token_counts = result[8]
return {
'subject_type': subject_type,
'subject_number': subject_number,
'create_timestamp': create_timestamp,
'last_interact_timestamp': last_interact_timestamp,
'prompt': prompt
'prompt': prompt,
'default_prompt': default_prompt,
'token_counts': token_counts
}
# 获取此session_name后一个session的数据
def next_session(self, session_name: str, cursor_timestamp: int):
self.__execute__("""
select `name`, `type`, `number`, `create_timestamp`, `last_interact_timestamp`, `prompt`, `status`
select `name`, `type`, `number`, `create_timestamp`, `last_interact_timestamp`, `prompt`, `status`, `default_prompt`, `token_counts`
from `sessions` where `name` = '{}' and `last_interact_timestamp` > {} order by `last_interact_timestamp` asc
limit 1
""".format(session_name, cursor_timestamp))
@@ -204,19 +233,23 @@ class DatabaseManager:
last_interact_timestamp = result[4]
prompt = result[5]
status = result[6]
default_prompt = result[7]
token_counts = result[8]
return {
'subject_type': subject_type,
'subject_number': subject_number,
'create_timestamp': create_timestamp,
'last_interact_timestamp': last_interact_timestamp,
'prompt': prompt
'prompt': prompt,
'default_prompt': default_prompt,
'token_counts': token_counts
}
# 列出与某个对象的所有对话session
def list_history(self, session_name: str, capacity: int, page: int):
self.__execute__("""
select `name`, `type`, `number`, `create_timestamp`, `last_interact_timestamp`, `prompt`, `status`
select `name`, `type`, `number`, `create_timestamp`, `last_interact_timestamp`, `prompt`, `status`, `default_prompt`, `token_counts`
from `sessions` where `name` = '{}' order by `last_interact_timestamp` desc limit {} offset {}
""".format(session_name, capacity, capacity * page))
results = self.cursor.fetchall()
@@ -229,17 +262,42 @@ class DatabaseManager:
last_interact_timestamp = result[4]
prompt = result[5]
status = result[6]
default_prompt = result[7]
token_counts = result[8]
sessions.append({
'subject_type': subject_type,
'subject_number': subject_number,
'create_timestamp': create_timestamp,
'last_interact_timestamp': last_interact_timestamp,
'prompt': prompt
'prompt': prompt,
'default_prompt': default_prompt,
'token_counts': token_counts
})
return sessions
def delete_history(self, session_name: str, index: int) -> bool:
# 删除倒序第index个session
# 查找其id再删除
self.__execute__("""
delete from `sessions` where `id` in (select `id` from `sessions` where `name` = '{}' order by `last_interact_timestamp` desc limit 1 offset {})
""".format(session_name, index))
return self.cursor.rowcount == 1
def delete_all_history(self, session_name: str) -> bool:
self.__execute__("""
delete from `sessions` where `name` = '{}'
""".format(session_name))
return self.cursor.rowcount > 0
def delete_all_session_history(self) -> bool:
self.__execute__("""
delete from `sessions`
""")
return self.cursor.rowcount > 0
# 将apikey的使用量存进数据库
def dump_api_key_usage(self, api_keys: dict, usage: dict):
logging.debug('dumping api key usage...')

View File

@@ -1,4 +1,6 @@
# 多情景预设值管理
import json
import logging
__current__ = "default"
"""当前默认使用的情景预设的名称
@@ -9,8 +11,10 @@ __current__ = "default"
__prompts_from_files__ = {}
"""从文件中读取的情景预设值"""
__scenario_from_files__ = {}
def read_prompt_from_file() -> str:
def read_prompt_from_file():
"""从文件读取预设值"""
# 读取prompts/目录下的所有文件,以文件名为键,文件内容为值
# 保存在__prompts_from_files__中
@@ -23,6 +27,19 @@ def read_prompt_from_file() -> str:
__prompts_from_files__[file] = f.read()
def read_scenario_from_file():
"""从JSON文件读取情景预设"""
global __scenario_from_files__
import os
__scenario_from_files__ = {}
for file in os.listdir("scenario"):
if file == "default-template.json":
continue
with open(os.path.join("scenario", file), encoding="utf-8") as f:
__scenario_from_files__[file] = json.load(f)
def get_prompt_dict() -> dict:
"""获取预设值字典"""
import config
@@ -65,15 +82,40 @@ def set_to_default():
__current__ = list(default_dict.keys())[0]
def get_prompt(name: str = None) -> str:
def get_prompt(name: str = None) -> list:
global __scenario_from_files__
import config
preset_mode = config.preset_mode
"""获取预设值"""
if name is None:
name = get_current()
default_dict = get_prompt_dict()
# JSON预设方式
if preset_mode == 'full_scenario':
import os
for key in default_dict:
if key.lower().startswith(name.lower()):
return default_dict[key]
for key in __scenario_from_files__:
if key.lower().startswith(name.lower()):
logging.debug('成功加载情景预设从JSON文件: {}'.format(key))
return __scenario_from_files__[key]['prompt']
# 默认预设方式
elif preset_mode == 'default':
raise KeyError("未找到情景预设: " + name)
default_dict = get_prompt_dict()
for key in default_dict:
if key.lower().startswith(name.lower()):
return [
{
"role": "user",
"content": default_dict[key]
},
{
"role": "assistant",
"content": "好的。"
}
]
raise KeyError("未找到默认情景预设: " + name)

View File

@@ -88,4 +88,4 @@ class KeysManager:
for key_name in self.api_key:
if self.api_key[key_name] == api_key:
return key_name
return ""
return ""

View File

@@ -34,7 +34,7 @@ class OpenAIInteract:
pkg.utils.context.set_openai_manager(self)
# 请求OpenAI Completion
def request_completion(self, prompts) -> str:
def request_completion(self, prompts) -> tuple[str, int]:
"""请求补全接口回复
Parameters:
@@ -60,14 +60,18 @@ class OpenAIInteract:
logging.debug("OpenAI response: %s", response)
# 记录使用量
current_round_token = 0
if 'model' in config.completion_api_params:
self.audit_mgr.report_text_model_usage(config.completion_api_params['model'],
ai.get_total_tokens())
current_round_token = ai.get_total_tokens()
elif 'engine' in config.completion_api_params:
self.audit_mgr.report_text_model_usage(config.completion_api_params['engine'],
response['usage']['total_tokens'])
current_round_token = response['usage']['total_tokens']
return ai.get_message()
return ai.get_message(), current_round_token
def request_image(self, prompt) -> dict:
"""请求图片接口回复

View File

@@ -21,6 +21,10 @@ COMPLETION_MODELS = {
CHAT_COMPLETION_MODELS = {
'gpt-3.5-turbo',
'gpt-3.5-turbo-0301',
'gpt-4',
'gpt-4-0314',
'gpt-4-32k',
'gpt-4-32k-0314'
}
EDIT_MODELS = {

View File

@@ -40,7 +40,7 @@ def reset_session_prompt(session_name, prompt):
prompt = [
{
'role': 'system',
'content': config.default_prompt['default']
'content': config.default_prompt['default'] if type(config.default_prompt) == dict else config.default_prompt
}
]
# 警告
@@ -72,9 +72,12 @@ def load_sessions():
temp_session.last_interact_timestamp = session_data[session_name]['last_interact_timestamp']
try:
temp_session.prompt = json.loads(session_data[session_name]['prompt'])
temp_session.token_counts = json.loads(session_data[session_name]['token_counts'])
except Exception:
temp_session.prompt = reset_session_prompt(session_name, session_data[session_name]['prompt'])
temp_session.persistence()
temp_session.default_prompt = json.loads(session_data[session_name]['default_prompt']) if \
session_data[session_name]['default_prompt'] else []
sessions[session_name] = temp_session
@@ -104,6 +107,12 @@ class Session:
prompt = []
"""使用list来保存会话中的回合"""
token_counts = []
"""每个回合的token数量"""
default_prompt = []
"""本session的默认prompt"""
create_timestamp = 0
"""会话创建时间"""
@@ -129,33 +138,26 @@ class Session:
# 从配置文件获取会话预设信息
def get_default_prompt(self, use_default: str = None):
config = pkg.utils.context.get_config()
import pkg.openai.dprompt as dprompt
if use_default is None:
current_default_prompt = dprompt.get_prompt(dprompt.get_current())
else:
current_default_prompt = dprompt.get_prompt(use_default)
use_default = dprompt.get_current()
return [
{
'role': 'user',
'content': current_default_prompt
}, {
'role': 'assistant',
'content': 'ok'
}
]
current_default_prompt = dprompt.get_prompt(use_default)
return current_default_prompt
def __init__(self, name: str):
self.name = name
self.create_timestamp = int(time.time())
self.last_interact_timestamp = int(time.time())
self.prompt = []
self.token_counts = []
self.schedule()
self.response_lock = threading.Lock()
self.prompt = self.get_default_prompt()
self.default_prompt = self.get_default_prompt()
logging.debug("prompt is: {}".format(self.default_prompt))
# 设定检查session最后一次对话是否超过过期时间的计时器
def schedule(self):
@@ -199,11 +201,11 @@ class Session:
self.last_interact_timestamp = int(time.time())
# 触发插件事件
if self.prompt == self.get_default_prompt():
if not self.prompt:
args = {
'session_name': self.name,
'session': self,
'default_prompt': self.prompt,
'default_prompt': self.default_prompt,
}
event = pkg.plugin.host.emit(plugin_models.SessionFirstMessageReceived, **args)
@@ -213,9 +215,16 @@ class Session:
config = pkg.utils.context.get_config()
max_length = config.prompt_submit_length if hasattr(config, "prompt_submit_length") else 1024
prompts, counts = self.cut_out(text, max_length)
# 计算请求前的prompt数量
total_token_before_query = 0
for token_count in counts:
total_token_before_query += token_count
# 向API请求补全
message = pkg.utils.context.get_openai_manager().request_completion(
self.cut_out(text, max_length),
message, total_token = pkg.utils.context.get_openai_manager().request_completion(
prompts,
)
# 成功获取,处理回复
@@ -232,6 +241,10 @@ class Session:
self.prompt.append({'role': 'user', 'content': text})
self.prompt.append({'role': 'assistant', 'content': res_ans})
# 向token_counts中添加本回合的token数量
self.token_counts.append(total_token-total_token_before_query)
logging.debug("本回合使用token: {}, session counts: {}".format(total_token-total_token_before_query, self.token_counts))
if self.just_switched_to_exist_session:
self.just_switched_to_exist_session = False
self.set_ongoing()
@@ -248,35 +261,61 @@ class Session:
question = self.prompt[-2]['content']
self.prompt = self.prompt[:-2]
self.token_counts = self.token_counts[:-1]
# 返回上一回合的问题
return question
# 构建对话体
def cut_out(self, msg: str, max_tokens: int) -> list:
"""将现有prompt进行切割处理使得新的prompt长度不超过max_tokens"""
# 如果用户消息长度超过max_tokens直接返回
def cut_out(self, msg: str, max_tokens: int) -> tuple[list, list]:
"""将现有prompt进行切割处理使得新的prompt长度不超过max_tokens
temp_prompt = [
:return: (新的prompt, 新的token_counts)
"""
# 最终由三个部分组成
# - default_prompt 情景预设固定值
# - changable_prompts 可变部分, 此会话中的历史对话回合
# - current_question 当前问题
# 包装目前的对话回合内容
changable_prompts = []
changable_counts = []
# 倒着来, 遍历prompt的步长为2, 遍历tokens_counts的步长为1
changable_index = len(self.prompt) - 1
token_count_index = len(self.token_counts) - 1
packed_tokens = 0
while changable_index >= 0 and token_count_index >= 0:
if packed_tokens + self.token_counts[token_count_index] > max_tokens:
break
changable_prompts.insert(0, self.prompt[changable_index])
changable_prompts.insert(0, self.prompt[changable_index - 1])
changable_counts.insert(0, self.token_counts[token_count_index])
packed_tokens += self.token_counts[token_count_index]
changable_index -= 2
token_count_index -= 1
# 将default_prompt和changable_prompts合并
result_prompt = self.default_prompt + changable_prompts
# 添加当前问题
result_prompt.append(
{
'role': 'user',
'content': msg
}
]
)
token_count = len(msg)
# 倒序遍历prompt
for i in range(len(self.prompt) - 1, -1, -1):
if token_count >= max_tokens:
break
logging.debug('cut_out: {}\nchangable section tokens: {}\npacked counts: {}\nsession counts: {}'.format(json.dumps(result_prompt, ensure_ascii=False, indent=4),
packed_tokens,
changable_counts,
self.token_counts))
# 将prompt加到temp_prompt头部
temp_prompt.insert(0, self.prompt[i])
token_count += len(self.prompt[i]['content'])
logging.debug('cut_out: {}'.format(str(temp_prompt)))
return temp_prompt
return result_prompt, changable_counts
# 持久化session
def persistence(self):
@@ -291,11 +330,11 @@ class Session:
subject_number = int(name_spt[1])
db_inst.persistence_session(subject_type, subject_number, self.create_timestamp, self.last_interact_timestamp,
json.dumps(self.prompt))
json.dumps(self.prompt), json.dumps(self.default_prompt), json.dumps(self.token_counts))
# 重置session
def reset(self, explicit: bool = False, expired: bool = False, schedule_new: bool = True, use_prompt: str = None):
if self.prompt[-1]['role'] != "system":
if self.prompt:
self.persistence()
if explicit:
# 触发插件事件
@@ -311,7 +350,10 @@ class Session:
if expired:
pkg.utils.context.get_database_manager().set_session_expired(self.name, self.create_timestamp)
self.prompt = self.get_default_prompt(use_prompt)
self.default_prompt = self.get_default_prompt(use_prompt)
self.prompt = []
self.token_counts = []
self.create_timestamp = int(time.time())
self.last_interact_timestamp = int(time.time())
self.just_switched_to_exist_session = False
@@ -337,9 +379,11 @@ class Session:
self.last_interact_timestamp = last_one['last_interact_timestamp']
try:
self.prompt = json.loads(last_one['prompt'])
self.token_counts = json.loads(last_one['token_counts'])
except json.decoder.JSONDecodeError:
self.prompt = reset_session_prompt(self.name, last_one['prompt'])
self.persistence()
self.default_prompt = json.loads(last_one['default_prompt']) if last_one['default_prompt'] else []
self.just_switched_to_exist_session = True
return self
@@ -356,9 +400,11 @@ class Session:
self.last_interact_timestamp = next_one['last_interact_timestamp']
try:
self.prompt = json.loads(next_one['prompt'])
self.token_counts = json.loads(next_one['token_counts'])
except json.decoder.JSONDecodeError:
self.prompt = reset_session_prompt(self.name, next_one['prompt'])
self.persistence()
self.default_prompt = json.loads(next_one['default_prompt']) if next_one['default_prompt'] else []
self.just_switched_to_exist_session = True
return self
@@ -366,5 +412,11 @@ class Session:
def list_history(self, capacity: int = 10, page: int = 0):
return pkg.utils.context.get_database_manager().list_history(self.name, capacity, page)
def delete_history(self, index: int) -> bool:
return pkg.utils.context.get_database_manager().delete_history(self.name, index)
def delete_all_history(self) -> bool:
return pkg.utils.context.get_database_manager().delete_all_history(self.name)
def draw_image(self, prompt: str):
return pkg.utils.context.get_openai_manager().request_image(prompt)

View File

@@ -5,6 +5,7 @@ import importlib
import os
import pkgutil
import sys
import shutil
import traceback
import pkg.utils.context as context
@@ -160,6 +161,22 @@ def install_plugin(repo_url: str):
main.reset_logging()
def uninstall_plugin(plugin_name: str) -> str:
""" 卸载插件 """
if plugin_name not in __plugins__:
raise Exception("插件不存在")
# 获取文件夹路径
plugin_path = __plugins__[plugin_name]['path'].replace("\\", "/")
# 剪切路径为plugins/插件名
plugin_path = plugin_path.split("plugins/")[1].split("/")[0]
# 删除文件夹
shutil.rmtree("plugins/"+plugin_path)
return "plugins/"+plugin_path
class EventContext:
""" 事件上下文 """
eid = 0

View File

@@ -1,5 +1,4 @@
# 长消息处理相关
import logging
import os
import time
import base64
@@ -67,7 +66,7 @@ def check_text(text: str) -> list:
"""检查文本是否为长消息,并转换成该使用的消息链组件"""
if not hasattr(config, 'blob_message_threshold'):
return [text]
if len(text) > config.blob_message_threshold:
if not hasattr(config, 'blob_message_strategy'):
raise AttributeError('未定义长消息处理策略')
@@ -77,8 +76,6 @@ def check_text(text: str) -> list:
# 转换成图片
return [text_to_image(text)]
elif config.blob_message_strategy == 'forward':
# 敏感词屏蔽
text = context.get_qqbot_manager().reply_filter.process(text)
# 包装转发消息
display = ForwardMessageDiaplay(

View File

36
pkg/qqbot/cmds/func.py Normal file
View File

@@ -0,0 +1,36 @@
from pkg.qqbot.cmds.model import command
import logging
from mirai import Image
import config
import pkg.openai.session
@command(
"draw",
"使用DALL·E模型作画",
"!draw <图片提示语>",
[],
False
)
def cmd_draw(cmd: str, params: list, session_name: str,
text_message: str, launcher_type: str, launcher_id: int,
sender_id: int, is_admin: bool) -> list:
"""使用DALL·E模型作画"""
reply = []
if len(params) == 0:
reply = ["[bot]err:请输入图片描述文字"]
else:
session = pkg.openai.session.get_session(session_name)
res = session.draw_image(" ".join(params))
logging.debug("draw_image result:{}".format(res))
reply = [Image(url=res['data'][0]['url'])]
if not (hasattr(config, 'include_image_description')
and not config.include_image_description):
reply.append(" ".join(params))
return reply

45
pkg/qqbot/cmds/model.py Normal file
View File

@@ -0,0 +1,45 @@
# 指令模型
import logging
commands = []
"""已注册的指令类
{
"name": "指令名",
"description": "指令描述",
"usage": "指令用法",
"aliases": ["别名1", "别名2"],
"admin_only": "是否仅管理员可用",
"func": "指令执行函数"
}
"""
def command(name: str, description: str, usage: str, aliases: list = None, admin_only: bool = False):
"""指令装饰器"""
def wrapper(fun):
commands.append({
"name": name,
"description": description,
"usage": usage,
"aliases": aliases,
"admin_only": admin_only,
"func": fun
})
return fun
return wrapper
def search(cmd: str) -> dict:
"""查找指令"""
for command in commands:
if (command["name"] == cmd) or (cmd in command["aliases"]):
return command
return None
import pkg.qqbot.cmds.func
import pkg.qqbot.cmds.system
import pkg.qqbot.cmds.session
import pkg.qqbot.cmds.plugin

129
pkg/qqbot/cmds/plugin.py Normal file
View File

@@ -0,0 +1,129 @@
from pkg.qqbot.cmds.model import command
import pkg.utils.context
import pkg.plugin.switch as plugin_switch
import os
import threading
import logging
def plugin_operation(cmd, params, is_admin):
reply = []
import pkg.plugin.host as plugin_host
import pkg.utils.updater as updater
plugin_list = plugin_host.__plugins__
if len(params) == 0:
reply_str = "[bot]所有插件({}):\n".format(len(plugin_host.__plugins__))
idx = 0
for key in plugin_host.iter_plugins_name():
plugin = plugin_list[key]
reply_str += "\n#{} {} {}\n{}\nv{}\n作者: {}\n"\
.format((idx+1), plugin['name'],
"[已禁用]" if not plugin['enabled'] else "",
plugin['description'],
plugin['version'], plugin['author'])
if updater.is_repo("/".join(plugin['path'].split('/')[:-1])):
remote_url = updater.get_remote_url("/".join(plugin['path'].split('/')[:-1]))
if remote_url != "https://github.com/RockChinQ/QChatGPT" and remote_url != "https://gitee.com/RockChin/QChatGPT":
reply_str += "源码: "+remote_url+"\n"
idx += 1
reply = [reply_str]
elif params[0] == 'update':
# 更新所有插件
if is_admin:
def closure():
import pkg.utils.context
updated = []
for key in plugin_list:
plugin = plugin_list[key]
if updater.is_repo("/".join(plugin['path'].split('/')[:-1])):
success = updater.pull_latest("/".join(plugin['path'].split('/')[:-1]))
if success:
updated.append(plugin['name'])
# 检查是否有requirements.txt
pkg.utils.context.get_qqbot_manager().notify_admin("正在安装依赖...")
for key in plugin_list:
plugin = plugin_list[key]
if os.path.exists("/".join(plugin['path'].split('/')[:-1])+"/requirements.txt"):
logging.info("{}检测到requirements.txt安装依赖".format(plugin['name']))
import pkg.utils.pkgmgr
pkg.utils.pkgmgr.install_requirements("/".join(plugin['path'].split('/')[:-1])+"/requirements.txt")
import main
main.reset_logging()
pkg.utils.context.get_qqbot_manager().notify_admin("已更新插件: {}".format(", ".join(updated)))
threading.Thread(target=closure).start()
reply = ["[bot]正在更新所有插件,请勿重复发起..."]
else:
reply = ["[bot]err:权限不足"]
elif params[0] == 'del' or params[0] == 'delete':
if is_admin:
if len(params) < 2:
reply = ["[bot]err:未指定插件名"]
else:
plugin_name = params[1]
if plugin_name in plugin_list:
unin_path = plugin_host.uninstall_plugin(plugin_name)
reply = ["[bot]已删除插件: {} ({}), 请发送 !reload 重载插件".format(plugin_name, unin_path)]
else:
reply = ["[bot]err:未找到插件: {}, 请使用!plugin指令查看插件列表".format(plugin_name)]
else:
reply = ["[bot]err:权限不足,请使用管理员账号私聊发起"]
elif params[0] == 'on' or params[0] == 'off' :
new_status = params[0] == 'on'
if is_admin:
if len(params) < 2:
reply = ["[bot]err:未指定插件名"]
else:
plugin_name = params[1]
if plugin_name in plugin_list:
plugin_list[plugin_name]['enabled'] = new_status
plugin_switch.dump_switch()
reply = ["[bot]已{}插件: {}".format("启用" if new_status else "禁用", plugin_name)]
else:
reply = ["[bot]err:未找到插件: {}, 请使用!plugin指令查看插件列表".format(plugin_name)]
else:
reply = ["[bot]err:权限不足,请使用管理员账号私聊发起"]
elif params[0].startswith("http"):
if is_admin:
def closure():
try:
plugin_host.install_plugin(params[0])
pkg.utils.context.get_qqbot_manager().notify_admin("插件安装成功,请发送 !reload 指令重载插件")
except Exception as e:
logging.error("插件安装失败:{}".format(e))
pkg.utils.context.get_qqbot_manager().notify_admin("插件安装失败:{}".format(e))
threading.Thread(target=closure, args=()).start()
reply = ["[bot]正在安装插件..."]
else:
reply = ["[bot]err:权限不足,请使用管理员账号私聊发起"]
else:
reply = ["[bot]err:未知参数: {}".format(params)]
return reply
@command(
"plugin",
"插件相关操作",
"!plugin\n!plugin <插件仓库地址>\!plugin update\n!plugin del <插件名>\n!plugin on <插件名>\n!plugin off <插件名>",
[],
False
)
def cmd_plugin(cmd: str, params: list, session_name: str,
text_message: str, launcher_type: str, launcher_id: int,
sender_id: int, is_admin: bool) -> list:
"""插件相关操作"""
reply = plugin_operation(cmd, params, is_admin)
return reply

282
pkg/qqbot/cmds/session.py Normal file
View File

@@ -0,0 +1,282 @@
# 会话管理相关指令
import datetime
import json
from pkg.qqbot.cmds.model import command
import pkg.openai.session
import pkg.utils.context
import config
@command(
"reset",
"重置当前会话",
"!reset\n!reset [使用情景预设名称]",
[],
False
)
def cmd_reset(cmd: str, params: list, session_name: str,
text_message: str, launcher_type: str, launcher_id: int,
sender_id: int, is_admin: bool) -> list:
"""重置会话"""
reply = []
if len(params) == 0:
pkg.openai.session.get_session(session_name).reset(explicit=True)
reply = ["[bot]会话已重置"]
else:
pkg.openai.session.get_session(session_name).reset(explicit=True, use_prompt=params[0])
reply = ["[bot]会话已重置,使用场景预设:{}".format(params[0])]
return reply
@command(
"last",
"切换到前一次会话",
"!last",
[],
False
)
def cmd_last(cmd: str, params: list, session_name: str,
text_message: str, launcher_type: str, launcher_id: int,
sender_id: int, is_admin: bool) -> list:
"""切换到前一次会话"""
reply = []
result = pkg.openai.session.get_session(session_name).last_session()
if result is None:
reply = ["[bot]没有前一次的对话"]
else:
datetime_str = datetime.datetime.fromtimestamp(result.create_timestamp).strftime(
'%Y-%m-%d %H:%M:%S')
reply = ["[bot]已切换到前一次的对话:\n创建时间:{}\n".format(datetime_str)]
return reply
@command(
"next",
"切换到后一次会话",
"!next",
[],
False
)
def cmd_next(cmd: str, params: list, session_name: str,
text_message: str, launcher_type: int, launcher_id: int,
sender_id: int, is_admin: bool) -> list:
"""切换到后一次会话"""
reply = []
result = pkg.openai.session.get_session(session_name).next_session()
if result is None:
reply = ["[bot]没有后一次的对话"]
else:
datetime_str = datetime.datetime.fromtimestamp(result.create_timestamp).strftime(
'%Y-%m-%d %H:%M:%S')
reply = ["[bot]已切换到后一次的对话:\n创建时间:{}\n".format(datetime_str)]
return reply
@command(
"prompt",
"获取当前会话的前文",
"!prompt",
[],
False
)
def cmd_prompt(cmd: str, params: list, session_name: str,
text_message: str, launcher_type: str, launcher_id: int,
sender_id: int, is_admin: bool) -> list:
"""获取当前会话的前文"""
reply = []
msgs = ""
session:list = pkg.openai.session.get_session(session_name).prompt
for msg in session:
if len(params) != 0 and params[0] in ['-all', '-a']:
msgs = msgs + "{}: {}\n\n".format(msg['role'], msg['content'])
elif len(msg['content']) > 30:
msgs = msgs + "[{}]: {}...\n\n".format(msg['role'], msg['content'][:30])
else:
msgs = msgs + "[{}]: {}\n\n".format(msg['role'], msg['content'])
reply = ["[bot]当前对话所有内容:\n{}".format(msgs)]
return reply
@command(
"list",
"列出当前会话的所有历史记录",
"!list\n!list [页数]",
[],
False
)
def cmd_list(cmd: str, params: list, session_name: str,
text_message: str, launcher_type: str, launcher_id: int,
sender_id: int, is_admin: bool) -> list:
"""列出当前会话的所有历史记录"""
reply = []
pkg.openai.session.get_session(session_name).persistence()
page = 0
if len(params) > 0:
try:
page = int(params[0])
except ValueError:
pass
results = pkg.openai.session.get_session(session_name).list_history(page=page)
if len(results) == 0:
reply = ["[bot]第{}页没有历史会话".format(page)]
else:
reply_str = "[bot]历史会话 第{}页:\n".format(page)
current = -1
for i in range(len(results)):
# 时间(使用create_timestamp转换) 序号 部分内容
datetime_obj = datetime.datetime.fromtimestamp(results[i]['create_timestamp'])
msg = ""
try:
msg = json.loads(results[i]['prompt'])
except json.decoder.JSONDecodeError:
msg = pkg.openai.session.reset_session_prompt(session_name, results[i]['prompt'])
# 持久化
pkg.openai.session.get_session(session_name).persistence()
if len(msg) >= 2:
reply_str += "#{} 创建:{} {}\n".format(i + page * 10,
datetime_obj.strftime("%Y-%m-%d %H:%M:%S"),
msg[0]['content'])
else:
reply_str += "#{} 创建:{} {}\n".format(i + page * 10,
datetime_obj.strftime("%Y-%m-%d %H:%M:%S"),
"无内容")
if results[i]['create_timestamp'] == pkg.openai.session.get_session(
session_name).create_timestamp:
current = i + page * 10
reply_str += "\n以上信息倒序排列"
if current != -1:
reply_str += ",当前会话是 #{}\n".format(current)
else:
reply_str += ",当前处于全新会话或不在此页"
reply = [reply_str]
return reply
@command(
"resend",
"重新获取上一次问题的回复",
"!resend",
[],
False
)
def cmd_resend(cmd: str, params: list, session_name: str,
text_message: str, launcher_type: str, launcher_id: int,
sender_id: int, is_admin: bool) -> list:
"""重新获取上一次问题的回复"""
reply = []
session = pkg.openai.session.get_session(session_name)
to_send = session.undo()
mgr = pkg.utils.context.get_qqbot_manager()
reply = pkg.qqbot.message.process_normal_message(to_send, mgr, config,
launcher_type, launcher_id, sender_id)
return reply
@command(
"del",
"删除当前会话的历史记录",
"!del <序号>\n!del all",
[],
False
)
def cmd_del(cmd: str, params: list, session_name: str,
text_message: str, launcher_type: str, launcher_id: int,
sender_id: int, is_admin: bool) -> list:
"""删除当前会话的历史记录"""
reply = []
if len(params) == 0:
reply = ["[bot]参数不足, 格式: !del <序号>\n可以通过!list查看序号"]
else:
if params[0] == 'all':
pkg.openai.session.get_session(session_name).delete_all_history()
reply = ["[bot]已删除所有历史会话"]
elif params[0].isdigit():
if pkg.openai.session.get_session(session_name).delete_history(int(params[0])):
reply = ["[bot]已删除历史会话 #{}".format(params[0])]
else:
reply = ["[bot]没有历史会话 #{}".format(params[0])]
else:
reply = ["[bot]参数错误, 格式: !del <序号>\n可以通过!list查看序号"]
return reply
@command(
"default",
"操作情景预设",
"!default\n!default [指定情景预设为默认]",
[],
False
)
def cmd_default(cmd: str, params: list, session_name: str,
text_message: str, launcher_type: str, launcher_id: int,
sender_id: int, is_admin: bool) -> list:
"""操作情景预设"""
reply = []
if len(params) == 0:
# 输出目前所有情景预设
import pkg.openai.dprompt as dprompt
reply_str = "[bot]当前所有情景预设:\n\n"
for key,value in dprompt.get_prompt_dict().items():
reply_str += " - {}: {}\n".format(key,value)
reply_str += "\n当前默认情景预设:{}\n".format(dprompt.get_current())
reply_str += "请使用!default <情景预设>来设置默认情景预设"
reply = [reply_str]
elif len(params) >0 and is_admin:
# 设置默认情景
import pkg.openai.dprompt as dprompt
try:
dprompt.set_current(params[0])
reply = ["[bot]已设置默认情景预设为:{}".format(dprompt.get_current())]
except KeyError:
reply = ["[bot]err: 未找到情景预设:{}".format(params[0])]
else:
reply = ["[bot]err: 仅管理员可设置默认情景预设"]
return reply
@command(
"delhst",
"删除指定会话的所有历史记录",
"!delhst <会话名称>\n!delhst all",
[],
True
)
def cmd_delhst(cmd: str, params: list, session_name: str,
text_message: str, launcher_type: str, launcher_id: int,
sender_id: int, is_admin: bool) -> list:
"""删除指定会话的所有历史记录"""
reply = []
if len(params) == 0:
reply = ["[bot]err:请输入要删除的会话名: group_<群号> 或者 person_<QQ号>, 或使用 !delhst all 删除所有会话的历史记录"]
else:
if params[0] == "all":
pkg.utils.context.get_database_manager().delete_all_session_history()
reply = ["[bot]已删除所有会话的历史记录"]
else:
if pkg.utils.context.get_database_manager().delete_all_history(params[0]):
reply = ["[bot]已删除会话 {} 的所有历史记录".format(params[0])]
else:
reply = ["[bot]未找到会话 {} 的历史记录".format(params[0])]
return reply

216
pkg/qqbot/cmds/system.py Normal file
View File

@@ -0,0 +1,216 @@
from pkg.qqbot.cmds.model import command
import pkg.utils.context
import pkg.utils.updater
import pkg.utils.credit as credit
import config
import logging
import os
import threading
import traceback
import json
@command(
"help",
"获取帮助信息",
"!help",
[],
False
)
def cmd_help(cmd: str, params: list, session_name: str,
text_message: str, launcher_type: str, launcher_id: int,
sender_id: int, is_admin: bool) -> list:
"""获取帮助信息"""
return ["[bot]" + config.help_message]
@command(
"usage",
"获取使用情况",
"!usage",
[],
False
)
def cmd_usage(cmd: str, params: list, session_name: str,
text_message: str, launcher_type: str, launcher_id: int,
sender_id: int, is_admin: bool) -> list:
"""获取使用情况"""
reply = []
reply_str = "[bot]各api-key使用情况:\n\n"
api_keys = pkg.utils.context.get_openai_manager().key_mgr.api_key
for key_name in api_keys:
text_length = pkg.utils.context.get_openai_manager().audit_mgr \
.get_text_length_of_key(api_keys[key_name])
image_count = pkg.utils.context.get_openai_manager().audit_mgr \
.get_image_count_of_key(api_keys[key_name])
reply_str += "{}:\n - 文本长度:{}\n - 图片数量:{}\n".format(key_name, int(text_length),
int(image_count))
# 获取此key的额度
try:
http_proxy = config.openai_config["http_proxy"] if "http_proxy" in config.openai_config else None
credit_data = credit.fetch_credit_data(api_keys[key_name], http_proxy)
reply_str += " - 使用额度:{:.2f}/{:.2f}\n".format(credit_data['total_used'],credit_data['total_granted'])
except Exception as e:
logging.warning("获取额度失败:{}".format(e))
reply = [reply_str]
return reply
@command(
"version",
"查看版本信息",
"!version",
[],
False
)
def cmd_version(cmd: str, params: list, session_name: str,
text_message: str, launcher_type: str, launcher_id: int,
sender_id: int, is_admin: bool) -> list:
"""查看版本信息"""
reply = []
reply_str = "[bot]当前版本:\n{}\n".format(pkg.utils.updater.get_current_version_info())
try:
if pkg.utils.updater.is_new_version_available():
reply_str += "\n有新版本可用,请使用命令 !update 进行更新"
except:
pass
reply = [reply_str]
return reply
@command(
"reload",
"执行热重载",
"!reload",
[],
True
)
def cmd_reload(cmd: str, params: list, session_name: str,
text_message: str, launcher_type: str, launcher_id: int,
sender_id: int, is_admin: bool) -> list:
"""执行热重载"""
import pkg.utils.reloader
def reload_task():
pkg.utils.reloader.reload_all()
threading.Thread(target=reload_task, daemon=True).start()
@command(
"update",
"更新程序",
"!update",
[],
True
)
def cmd_update(cmd: str, params: list, session_name: str,
text_message: str, launcher_type: str, launcher_id: int,
sender_id: int, is_admin: bool) -> list:
"""更新程序"""
reply = []
import pkg.utils.updater
import pkg.utils.reloader
import pkg.utils.context
def update_task():
try:
if pkg.utils.updater.update_all():
pkg.utils.reloader.reload_all(notify=False)
pkg.utils.context.get_qqbot_manager().notify_admin("更新完成")
else:
pkg.utils.context.get_qqbot_manager().notify_admin("无新版本")
except Exception as e0:
traceback.print_exc()
pkg.utils.context.get_qqbot_manager().notify_admin("更新失败:{}".format(e0))
return
threading.Thread(target=update_task, daemon=True).start()
reply = ["[bot]正在更新,请耐心等待,请勿重复发起更新..."]
def config_operation(cmd, params):
reply = []
config = pkg.utils.context.get_config()
reply_str = ""
if len(params) == 0:
reply = ["[bot]err:请输入配置项"]
else:
cfg_name = params[0]
if cfg_name == 'all':
reply_str = "[bot]所有配置项:\n\n"
for cfg in dir(config):
if not cfg.startswith('__') and not cfg == 'logging':
# 根据配置项类型进行格式化如果是字典则转换为json并格式化
if isinstance(getattr(config, cfg), str):
reply_str += "{}: \"{}\"\n".format(cfg, getattr(config, cfg))
elif isinstance(getattr(config, cfg), dict):
# 不进行unicode转义并格式化
reply_str += "{}: {}\n".format(cfg,
json.dumps(getattr(config, cfg),
ensure_ascii=False, indent=4))
else:
reply_str += "{}: {}\n".format(cfg, getattr(config, cfg))
reply = [reply_str]
elif cfg_name in dir(config):
if len(params) == 1:
# 按照配置项类型进行格式化
if isinstance(getattr(config, cfg_name), str):
reply_str = "[bot]配置项{}: \"{}\"\n".format(cfg_name, getattr(config, cfg_name))
elif isinstance(getattr(config, cfg_name), dict):
reply_str = "[bot]配置项{}: {}\n".format(cfg_name,
json.dumps(getattr(config, cfg_name),
ensure_ascii=False, indent=4))
else:
reply_str = "[bot]配置项{}: {}\n".format(cfg_name, getattr(config, cfg_name))
reply = [reply_str]
else:
cfg_value = " ".join(params[1:])
# 类型转换如果是json则转换为字典
if cfg_value == 'true':
cfg_value = True
elif cfg_value == 'false':
cfg_value = False
elif cfg_value.isdigit():
cfg_value = int(cfg_value)
elif cfg_value.startswith('{') and cfg_value.endswith('}'):
cfg_value = json.loads(cfg_value)
else:
try:
cfg_value = float(cfg_value)
except ValueError:
pass
# 检查类型是否匹配
if isinstance(getattr(config, cfg_name), type(cfg_value)):
setattr(config, cfg_name, cfg_value)
pkg.utils.context.set_config(config)
reply = ["[bot]配置项{}修改成功".format(cfg_name)]
else:
reply = ["[bot]err:配置项{}类型不匹配".format(cfg_name)]
else:
reply = ["[bot]err:未找到配置项 {}".format(cfg_name)]
return reply
@command(
"cfg",
"配置文件相关操作",
"!cfg all\n!cfg <配置项名称>\n!cfg <配置项名称> <配置项新值>",
[],
True
)
def cmd_cfg(cmd: str, params: list, session_name: str,
text_message: str, launcher_type: str, launcher_id: int,
sender_id: int, is_admin: bool) -> list:
"""配置文件相关操作"""
reply = config_operation(cmd, params)
return reply

View File

@@ -4,6 +4,7 @@ import json
import datetime
import os
import threading
import traceback
import pkg.openai.session
import pkg.openai.manager
@@ -12,151 +13,11 @@ import pkg.utils.updater
import pkg.utils.context
import pkg.qqbot.message
import pkg.utils.credit as credit
import pkg.qqbot.cmds.model as cmdmodel
from mirai import Image
def config_operation(cmd, params):
reply = []
config = pkg.utils.context.get_config()
reply_str = ""
if len(params) == 0:
reply = ["[bot]err:请输入配置项"]
else:
cfg_name = params[0]
if cfg_name == 'all':
reply_str = "[bot]所有配置项:\n\n"
for cfg in dir(config):
if not cfg.startswith('__') and not cfg == 'logging':
# 根据配置项类型进行格式化如果是字典则转换为json并格式化
if isinstance(getattr(config, cfg), str):
reply_str += "{}: \"{}\"\n".format(cfg, getattr(config, cfg))
elif isinstance(getattr(config, cfg), dict):
# 不进行unicode转义并格式化
reply_str += "{}: {}\n".format(cfg,
json.dumps(getattr(config, cfg),
ensure_ascii=False, indent=4))
else:
reply_str += "{}: {}\n".format(cfg, getattr(config, cfg))
reply = [reply_str]
elif cfg_name in dir(config):
if len(params) == 1:
# 按照配置项类型进行格式化
if isinstance(getattr(config, cfg_name), str):
reply_str = "[bot]配置项{}: \"{}\"\n".format(cfg_name, getattr(config, cfg_name))
elif isinstance(getattr(config, cfg_name), dict):
reply_str = "[bot]配置项{}: {}\n".format(cfg_name,
json.dumps(getattr(config, cfg_name),
ensure_ascii=False, indent=4))
else:
reply_str = "[bot]配置项{}: {}\n".format(cfg_name, getattr(config, cfg_name))
reply = [reply_str]
else:
cfg_value = " ".join(params[1:])
# 类型转换如果是json则转换为字典
if cfg_value == 'true':
cfg_value = True
elif cfg_value == 'false':
cfg_value = False
elif cfg_value.isdigit():
cfg_value = int(cfg_value)
elif cfg_value.startswith('{') and cfg_value.endswith('}'):
cfg_value = json.loads(cfg_value)
else:
try:
cfg_value = float(cfg_value)
except ValueError:
pass
# 检查类型是否匹配
if isinstance(getattr(config, cfg_name), type(cfg_value)):
setattr(config, cfg_name, cfg_value)
pkg.utils.context.set_config(config)
reply = ["[bot]配置项{}修改成功".format(cfg_name)]
else:
reply = ["[bot]err:配置项{}类型不匹配".format(cfg_name)]
else:
reply = ["[bot]err:未找到配置项 {}".format(cfg_name)]
return reply
def plugin_operation(cmd, params, is_admin):
reply = []
import pkg.plugin.host as plugin_host
import pkg.utils.updater as updater
plugin_list = plugin_host.__plugins__
if len(params) == 0:
reply_str = "[bot]所有插件({}):\n".format(len(plugin_host.__plugins__))
idx = 0
for key in plugin_host.iter_plugins_name():
plugin = plugin_list[key]
reply_str += "\n#{} {} {}\n{}\nv{}\n作者: {}\n"\
.format((idx+1), plugin['name'],
"[已禁用]" if not plugin['enabled'] else "",
plugin['description'],
plugin['version'], plugin['author'])
if updater.is_repo("/".join(plugin['path'].split('/')[:-1])):
remote_url = updater.get_remote_url("/".join(plugin['path'].split('/')[:-1]))
if remote_url != "https://github.com/RockChinQ/QChatGPT" and remote_url != "https://gitee.com/RockChin/QChatGPT":
reply_str += "源码: "+remote_url+"\n"
idx += 1
reply = [reply_str]
elif params[0] == 'update':
# 更新所有插件
if is_admin:
def closure():
import pkg.utils.context
updated = []
for key in plugin_list:
plugin = plugin_list[key]
if updater.is_repo("/".join(plugin['path'].split('/')[:-1])):
success = updater.pull_latest("/".join(plugin['path'].split('/')[:-1]))
if success:
updated.append(plugin['name'])
# 检查是否有requirements.txt
pkg.utils.context.get_qqbot_manager().notify_admin("正在安装依赖...")
for key in plugin_list:
plugin = plugin_list[key]
if os.path.exists("/".join(plugin['path'].split('/')[:-1])+"/requirements.txt"):
logging.info("{}检测到requirements.txt安装依赖".format(plugin['name']))
import pkg.utils.pkgmgr
pkg.utils.pkgmgr.install_requirements("/".join(plugin['path'].split('/')[:-1])+"/requirements.txt")
import main
main.reset_logging()
pkg.utils.context.get_qqbot_manager().notify_admin("已更新插件: {}".format(", ".join(updated)))
threading.Thread(target=closure).start()
reply = ["[bot]正在更新所有插件,请勿重复发起..."]
else:
reply = ["[bot]err:权限不足"]
elif params[0].startswith("http"):
if is_admin:
def closure():
try:
plugin_host.install_plugin(params[0])
pkg.utils.context.get_qqbot_manager().notify_admin("插件安装成功,请发送 !reload 指令重载插件")
except Exception as e:
logging.error("插件安装失败:{}".format(e))
pkg.utils.context.get_qqbot_manager().notify_admin("插件安装失败:{}".format(e))
threading.Thread(target=closure, args=()).start()
reply = ["[bot]正在安装插件..."]
else:
reply = ["[bot]err:权限不足,请使用管理员账号私聊发起"]
return reply
def process_command(session_name: str, text_message: str, mgr, config,
launcher_type: str, launcher_id: int, sender_id: int, is_admin: bool) -> list:
@@ -169,188 +30,30 @@ def process_command(session_name: str, text_message: str, mgr, config,
cmd = text_message[1:].strip().split(' ')[0]
params = text_message[1:].strip().split(' ')[1:]
if cmd == 'help':
reply = ["[bot]" + config.help_message]
elif cmd == 'reset':
if len(params) == 0:
pkg.openai.session.get_session(session_name).reset(explicit=True)
reply = ["[bot]会话已重置"]
else:
pkg.openai.session.get_session(session_name).reset(explicit=True, use_prompt=params[0])
reply = ["[bot]会话已重置,使用场景预设:{}".format(params[0])]
elif cmd == 'last':
result = pkg.openai.session.get_session(session_name).last_session()
if result is None:
reply = ["[bot]没有前一次的对话"]
else:
datetime_str = datetime.datetime.fromtimestamp(result.create_timestamp).strftime(
'%Y-%m-%d %H:%M:%S')
reply = ["[bot]已切换到前一次的对话:\n创建时间:{}\n".format(datetime_str)]
elif cmd == 'next':
result = pkg.openai.session.get_session(session_name).next_session()
if result is None:
reply = ["[bot]没有后一次的对话"]
else:
datetime_str = datetime.datetime.fromtimestamp(result.create_timestamp).strftime(
'%Y-%m-%d %H:%M:%S')
reply = ["[bot]已切换到后一次的对话:\n创建时间:{}\n".format(datetime_str)]
elif cmd == 'prompt':
msgs = ""
session:list = pkg.openai.session.get_session(session_name).prompt
for msg in session:
if len(params) != 0 and params[0] in ['-all', '-a']:
msgs = msgs + "{}: {}\n\n".format(msg['role'], msg['content'])
elif len(msg['content']) > 30:
msgs = msgs + "[{}]: {}...\n\n".format(msg['role'], msg['content'][:30])
else:
msgs = msgs + "[{}]: {}\n\n".format(msg['role'], msg['content'])
reply = ["[bot]当前对话所有内容:\n{}".format(msgs)]
elif cmd == 'list':
pkg.openai.session.get_session(session_name).persistence()
page = 0
if len(params) > 0:
try:
page = int(params[0])
except ValueError:
pass
# 把!~开头的转换成!cfg
if cmd.startswith('~'):
params = [cmd[1:]] + params
cmd = 'cfg'
results = pkg.openai.session.get_session(session_name).list_history(page=page)
if len(results) == 0:
reply = ["[bot]第{}页没有历史会话".format(page)]
else:
reply_str = "[bot]历史会话 第{}页:\n".format(page)
current = -1
for i in range(len(results)):
# 时间(使用create_timestamp转换) 序号 部分内容
datetime_obj = datetime.datetime.fromtimestamp(results[i]['create_timestamp'])
msg = ""
try:
msg = json.loads(results[i]['prompt'])
except json.decoder.JSONDecodeError:
msg = pkg.openai.session.reset_session_prompt(session_name, results[i]['prompt'])
# 持久化
pkg.openai.session.get_session(session_name).persistence()
if len(msg) >= 2:
reply_str += "#{} 创建:{} {}\n".format(i + page * 10,
datetime_obj.strftime("%Y-%m-%d %H:%M:%S"),
msg[1]['content'])
else:
reply_str += "#{} 创建:{} {}\n".format(i + page * 10,
datetime_obj.strftime("%Y-%m-%d %H:%M:%S"),
"无内容")
if results[i]['create_timestamp'] == pkg.openai.session.get_session(
session_name).create_timestamp:
current = i + page * 10
reply_str += "\n以上信息倒序排列"
if current != -1:
reply_str += ",当前会话是 #{}\n".format(current)
else:
reply_str += ",当前处于全新会话或不在此页"
reply = [reply_str]
elif cmd == 'resend':
session = pkg.openai.session.get_session(session_name)
to_send = session.undo()
reply = pkg.qqbot.message.process_normal_message(to_send, mgr, config,
launcher_type, launcher_id, sender_id)
elif cmd == 'usage':
reply_str = "[bot]各api-key使用情况:\n\n"
api_keys = pkg.utils.context.get_openai_manager().key_mgr.api_key
for key_name in api_keys:
text_length = pkg.utils.context.get_openai_manager().audit_mgr \
.get_text_length_of_key(api_keys[key_name])
image_count = pkg.utils.context.get_openai_manager().audit_mgr \
.get_image_count_of_key(api_keys[key_name])
reply_str += "{}:\n - 文本长度:{}\n - 图片数量:{}\n".format(key_name, int(text_length),
int(image_count))
# 获取此key的额度
try:
credit_data = credit.fetch_credit_data(api_keys[key_name])
reply_str += " - 使用额度:{:.2f}/{:.2f}\n".format(credit_data['total_used'],credit_data['total_granted'])
except Exception as e:
logging.warning("获取额度失败:{}".format(e))
reply = [reply_str]
elif cmd == 'draw':
if len(params) == 0:
reply = ["[bot]err:请输入图片描述文字"]
else:
session = pkg.openai.session.get_session(session_name)
res = session.draw_image(" ".join(params))
logging.debug("draw_image result:{}".format(res))
reply = [Image(url=res['data'][0]['url'])]
if not (hasattr(config, 'include_image_description')
and not config.include_image_description):
reply.append(" ".join(params))
elif cmd == 'version':
reply_str = "[bot]当前版本:\n{}\n".format(pkg.utils.updater.get_current_version_info())
try:
if pkg.utils.updater.is_new_version_available():
reply_str += "\n有新版本可用,请使用命令 !update 进行更新"
except:
pass
reply = [reply_str]
elif cmd == 'plugin':
reply = plugin_operation(cmd, params, is_admin)
elif cmd == 'default':
if len(params) == 0:
# 输出目前所有情景预设
import pkg.openai.dprompt as dprompt
reply_str = "[bot]当前所有情景预设:\n\n"
for key,value in dprompt.get_prompt_dict().items():
reply_str += " - {}: {}\n".format(key,value)
reply_str += "\n当前默认情景预设:{}\n".format(dprompt.get_current())
reply_str += "请使用!default <情景预设>来设置默认情景预设"
reply = [reply_str]
elif len(params) >0 and is_admin:
# 设置默认情景
import pkg.openai.dprompt as dprompt
try:
dprompt.set_current(params[0])
reply = ["[bot]已设置默认情景预设为:{}".format(dprompt.get_current())]
except KeyError:
reply = ["[bot]err: 未找到情景预设:{}".format(params[0])]
else:
reply = ["[bot]err: 仅管理员可设置默认情景预设"]
elif cmd == 'reload' and is_admin:
def reload_task():
pkg.utils.reloader.reload_all()
threading.Thread(target=reload_task, daemon=True).start()
elif cmd == 'update' and is_admin:
def update_task():
try:
if pkg.utils.updater.update_all():
pkg.utils.reloader.reload_all(notify=False)
pkg.utils.context.get_qqbot_manager().notify_admin("更新完成")
else:
pkg.utils.context.get_qqbot_manager().notify_admin("无新版本")
except Exception as e0:
pkg.utils.context.get_qqbot_manager().notify_admin("更新失败:{}".format(e0))
return
threading.Thread(target=update_task, daemon=True).start()
reply = ["[bot]正在更新,请耐心等待,请勿重复发起更新..."]
elif cmd == 'cfg' and is_admin:
reply = config_operation(cmd, params)
# 选择指令处理函数
cmd_obj = cmdmodel.search(cmd)
if cmd_obj is not None and (cmd_obj['admin_only'] is False or is_admin):
cmd_func = cmd_obj['func']
reply = cmd_func(
cmd=cmd,
params=params,
session_name=session_name,
text_message=text_message,
launcher_type=launcher_type,
launcher_id=launcher_id,
sender_id=sender_id,
is_admin=is_admin,
)
else:
if cmd.startswith("~") and is_admin:
config_item = cmd[1:]
params = [config_item] + params
reply = config_operation("cfg", params)
else:
reply = ["[bot]err:未知的指令或权限不足: " + cmd]
reply = ["[bot]err:未知的指令或权限不足: " + cmd]
return reply
except Exception as e:
mgr.notify_admin("{}指令执行失败:{}".format(session_name, e))
logging.exception(e)

View File

@@ -1,6 +1,5 @@
# 普通消息处理模块
import logging
import time
import openai
import pkg.utils.context
import pkg.openai.session
@@ -64,7 +63,7 @@ def process_normal_message(text_message: str, mgr, config, launcher_type: str,
reply = event.get_return_value("reply")
if not event.is_prevented_default():
reply = blob.check_text(prefix + text)
reply = [prefix + text]
except openai.error.APIConnectionError as e:
err_msg = str(e)
if err_msg.__contains__('Error communicating with OpenAI'):
@@ -117,8 +116,7 @@ def process_normal_message(text_message: str, mgr, config, launcher_type: str,
reply = handle_exception("{}会话调用API失败:{}".format(session_name, e),
"[bot]err:RateLimitError,请重试或联系作者,或等待修复")
except openai.error.InvalidRequestError as e:
reply = handle_exception("{}API调用参数错误:{}\n\n这可能是由于config.py中的prompt_submit_length参数或"
"completion_api_params中的max_tokens参数数值过大导致的请尝试将其降低".format(
reply = handle_exception("{}API调用参数错误:{}\n".format(
session_name, e), "[bot]err:API调用参数错误请联系管理员或等待修复")
except openai.error.ServiceUnavailableError as e:
reply = handle_exception("{}API调用服务不可用:{}".format(session_name, e), "[bot]err:API调用服务不可用请重试或联系管理员或等待修复")

View File

@@ -26,6 +26,7 @@ import pkg.plugin.host as plugin_host
import pkg.plugin.models as plugin_models
import pkg.qqbot.ignore as ignore
import pkg.qqbot.banlist as banlist
import pkg.qqbot.blob as blob
processing = []
@@ -157,6 +158,7 @@ def process_message(launcher_type: str, launcher_id: int, text_message: str, mes
reply[0][:min(100, len(reply[0]))] + (
"..." if len(reply[0]) > 100 else "")))
reply = [mgr.reply_filter.process(reply[0])]
reply = blob.check_text(reply[0])
else:
logging.info("回复[{}]消息".format(session_name))

47
pkg/utils/announcement.py Normal file
View File

@@ -0,0 +1,47 @@
import base64
import os
import requests
import pkg.utils.network as network
def read_latest() -> str:
resp = requests.get(
url="https://api.github.com/repos/RockChinQ/QChatGPT/contents/res/announcement",
proxies=network.wrapper_proxies()
)
obj_json = resp.json()
b64_content = obj_json["content"]
# 解码
content = base64.b64decode(b64_content).decode("utf-8")
return content
def read_saved() -> str:
# 已保存的在res/announcement_saved
# 检查是否存在
if not os.path.exists("res/announcement_saved"):
with open("res/announcement_saved", "w") as f:
f.write("")
with open("res/announcement_saved", "r") as f:
content = f.read()
return content
def write_saved(content: str):
# 已保存的在res/announcement_saved
with open("res/announcement_saved", "w") as f:
f.write(content)
def fetch_new() -> str:
latest = read_latest()
saved = read_saved()
if latest.replace(saved, "").strip() == "":
return ""
else:
write_saved(latest)
return latest.replace(saved, "").strip()

File diff suppressed because one or more lines are too long

View File

@@ -1,13 +1,19 @@
# OpenAI账号免费额度剩余查询
import requests
def fetch_credit_data(api_key: str) -> dict:
def fetch_credit_data(api_key: str, http_proxy: str) -> dict:
"""OpenAI账号免费额度剩余查询"""
proxies = {
"http":http_proxy,
"https":http_proxy
} if http_proxy is not None else None
resp = requests.get(
url="https://api.openai.com/dashboard/billing/credit_grants",
headers={
"Authorization": "Bearer {}".format(api_key),
}
},
proxies=proxies
)
return resp.json()

9
pkg/utils/network.py Normal file
View File

@@ -0,0 +1,9 @@
def wrapper_proxies() -> dict:
"""获取代理"""
import config
return {
"http": config.openai_config['proxy'],
"https": config.openai_config['proxy']
} if 'proxy' in config.openai_config and (config.openai_config['proxy'] is not None) else None

View File

@@ -6,6 +6,7 @@ import requests
import json
import pkg.utils.constants
import pkg.utils.network as network
def check_dulwich_closure():
@@ -36,7 +37,8 @@ def pull_latest(repo_path: str) -> bool:
def get_release_list() -> list:
"""获取发行列表"""
rls_list_resp = requests.get(
url="https://api.github.com/repos/RockChinQ/QChatGPT/releases"
url="https://api.github.com/repos/RockChinQ/QChatGPT/releases",
proxies=network.wrapper_proxies()
)
rls_list = rls_list_resp.json()
@@ -83,7 +85,10 @@ def update_all(cli: bool = False) -> bool:
else:
print("开始下载最新版本: {}".format(latest_rls['zipball_url']))
zip_url = latest_rls['zipball_url']
zip_resp = requests.get(url=zip_url)
zip_resp = requests.get(
url=zip_url,
proxies=network.wrapper_proxies()
)
zip_data = zip_resp.content
# 检查temp/updater目录
@@ -126,6 +131,15 @@ def update_all(cli: bool = False) -> bool:
dst = src.replace(source_root, ".")
if os.path.exists(dst):
os.remove(dst)
# 检查目标文件夹是否存在
if not os.path.exists(os.path.dirname(dst)):
os.makedirs(os.path.dirname(dst))
# 检查目标文件是否存在
if not os.path.exists(dst):
# 创建目标文件
open(dst, "w").close()
shutil.copy(src, dst)
# 把current_tag写入文件