mirror of
https://github.com/langbot-app/LangBot.git
synced 2026-06-11 08:16:03 +00:00
适配线程版本
This commit is contained in:
Binary file not shown.
@@ -4,7 +4,6 @@
|
||||
"""
|
||||
|
||||
import logging
|
||||
import os
|
||||
import threading
|
||||
import time
|
||||
import json
|
||||
@@ -21,8 +20,6 @@ import pkg.plugin.models as plugin_models
|
||||
sessions = {}
|
||||
|
||||
|
||||
|
||||
|
||||
class SessionOfflineStatus:
|
||||
ON_GOING = 'on_going'
|
||||
EXPLICITLY_CLOSED = 'explicitly_closed'
|
||||
@@ -189,8 +186,6 @@ class Session:
|
||||
self.schedule()
|
||||
|
||||
self.response_lock = threading.Lock()
|
||||
self.bot_name = 'ai'
|
||||
self.bot_filter = None
|
||||
self.prompt = self.get_default_prompt()
|
||||
|
||||
# 设定检查session最后一次对话是否超过过期时间的计时器
|
||||
@@ -235,7 +230,7 @@ class Session:
|
||||
self.last_interact_timestamp = int(time.time())
|
||||
|
||||
# 触发插件事件
|
||||
if self.prompt == self.get_default_prompt(get_only=True):
|
||||
if self.prompt == self.get_default_prompt(get_only = True):
|
||||
args = {
|
||||
'session_name': self.name,
|
||||
'session': self,
|
||||
@@ -264,23 +259,10 @@ class Session:
|
||||
del (res_ans_spt[0])
|
||||
res_ans = '\n\n'.join(res_ans_spt)
|
||||
|
||||
#检测是否包含ai人格否定
|
||||
logging.debug('bot_filter: {}'.format(self.bot_filter))
|
||||
if config.filter_ai_warning and self.bot_filter:
|
||||
import re
|
||||
match = re.search(self.bot_filter['reg'], res_ans)
|
||||
logging.debug(self.bot_filter)
|
||||
logging.debug(res_ans)
|
||||
if match:
|
||||
logging.debug('回复:{}, 检测到人格否定,替换中。。'.format(res_ans))
|
||||
res_ans = self.bot_filter['replace']
|
||||
logging.debug('替换为: {}'.format(res_ans))
|
||||
|
||||
# 将此次对话的双方内容加入到prompt中
|
||||
self.prompt.append({'role': 'user', 'content': text})
|
||||
self.prompt.append({'role': 'assistant', 'content': res_ans})
|
||||
|
||||
|
||||
if self.just_switched_to_exist_session:
|
||||
self.just_switched_to_exist_session = False
|
||||
self.set_ongoing()
|
||||
@@ -329,7 +311,7 @@ class Session:
|
||||
|
||||
# 持久化session
|
||||
def persistence(self):
|
||||
if self.prompt == self.get_default_prompt(get_only=True):
|
||||
if self.prompt == self.get_default_prompt(get_only = True):
|
||||
return
|
||||
|
||||
db_inst = pkg.utils.context.get_database_manager()
|
||||
|
||||
@@ -7,6 +7,8 @@ import logging
|
||||
|
||||
class ReplyFilter:
|
||||
sensitive_words = []
|
||||
mask = "*"
|
||||
mask_word = ""
|
||||
|
||||
# 默认值( 兼容性考虑 )
|
||||
baidu_check = False
|
||||
@@ -14,8 +16,10 @@ class ReplyFilter:
|
||||
baidu_secret_key = ""
|
||||
inappropriate_message_tips = "[百度云]请珍惜机器人,当前返回内容不合规"
|
||||
|
||||
def __init__(self, sensitive_words: list):
|
||||
def __init__(self, sensitive_words: list, mask: str = "*", mask_word: str = ""):
|
||||
self.sensitive_words = sensitive_words
|
||||
self.mask = mask
|
||||
self.mask_word = mask_word
|
||||
import config
|
||||
if hasattr(config, 'baidu_check') and hasattr(config, 'baidu_api_key') and hasattr(config, 'baidu_secret_key'):
|
||||
self.baidu_check = config.baidu_check
|
||||
@@ -36,7 +40,10 @@ class ReplyFilter:
|
||||
match = re.findall(word, message)
|
||||
if len(match) > 0:
|
||||
for i in range(len(match)):
|
||||
message = message.replace(match[i], "*" * len(match[i]))
|
||||
if self.mask_word == "":
|
||||
message = message.replace(match[i], self.mask * len(match[i]))
|
||||
else:
|
||||
message = message.replace(match[i], self.mask_word)
|
||||
|
||||
# 百度云审核
|
||||
if self.baidu_check:
|
||||
|
||||
@@ -2,6 +2,7 @@ import asyncio
|
||||
import json
|
||||
import os
|
||||
import threading
|
||||
from concurrent.futures import ThreadPoolExecutor
|
||||
|
||||
import mirai.models.bus
|
||||
from mirai import At, GroupMessage, MessageEvent, Mirai, StrangerMessage, WebSocketAdapter, HTTPAdapter, \
|
||||
@@ -21,12 +22,6 @@ import pkg.plugin.host as plugin_host
|
||||
import pkg.plugin.models as plugin_models
|
||||
|
||||
|
||||
# 并行运行
|
||||
def go(func, args=()):
|
||||
thread = threading.Thread(target=func, args=args, daemon=True)
|
||||
thread.start()
|
||||
|
||||
|
||||
# 检查消息是否符合泛响应匹配机制
|
||||
def check_response_rule(text: str, event):
|
||||
config = pkg.utils.context.get_config()
|
||||
@@ -41,7 +36,6 @@ def check_response_rule(text: str, event):
|
||||
import re
|
||||
if re.search(bot_name, text):
|
||||
return True, text
|
||||
|
||||
rules = config.response_rules
|
||||
# 检查前缀匹配
|
||||
if 'prefix' in rules:
|
||||
@@ -56,14 +50,33 @@ def check_response_rule(text: str, event):
|
||||
match = re.match(rule, text)
|
||||
if match:
|
||||
return True, text
|
||||
|
||||
|
||||
return False, ""
|
||||
|
||||
|
||||
def response_at():
|
||||
config = pkg.utils.context.get_config()
|
||||
if 'at' not in config.response_rules:
|
||||
return True
|
||||
|
||||
return config.response_rules['at']
|
||||
|
||||
|
||||
def random_responding():
|
||||
config = pkg.utils.context.get_config()
|
||||
if 'random_rate' in config.response_rules:
|
||||
import random
|
||||
return random.random() < config.response_rules['random_rate']
|
||||
return False
|
||||
|
||||
|
||||
# 控制QQ消息输入输出的类
|
||||
class QQBotManager:
|
||||
retry = 3
|
||||
|
||||
#线程池控制
|
||||
pool = None
|
||||
|
||||
bot: Mirai = None
|
||||
|
||||
reply_filter = None
|
||||
@@ -73,11 +86,14 @@ class QQBotManager:
|
||||
ban_person = []
|
||||
ban_group = []
|
||||
|
||||
def __init__(self, mirai_http_api_config: dict, timeout: int = 60, retry: int = 3, first_time_init=True):
|
||||
|
||||
def __init__(self, mirai_http_api_config: dict, timeout: int = 60, retry: int = 3, pool_num: int = 10, first_time_init=True):
|
||||
self.timeout = timeout
|
||||
self.retry = retry
|
||||
|
||||
self.pool_num = pool_num
|
||||
self.pool = ThreadPoolExecutor(max_workers=self.pool_num)
|
||||
logging.debug("Registered thread pool Size:{}".format(pool_num))
|
||||
|
||||
# 加载禁用列表
|
||||
if os.path.exists("banlist.py"):
|
||||
import banlist
|
||||
@@ -91,7 +107,12 @@ class QQBotManager:
|
||||
and config.sensitive_word_filter is not None \
|
||||
and config.sensitive_word_filter:
|
||||
with open("sensitive.json", "r", encoding="utf-8") as f:
|
||||
self.reply_filter = pkg.qqbot.filter.ReplyFilter(json.load(f)['words'])
|
||||
sensitive_json = json.load(f)
|
||||
self.reply_filter = pkg.qqbot.filter.ReplyFilter(
|
||||
sensitive_words=sensitive_json['words'],
|
||||
mask=sensitive_json['mask'] if 'mask' in sensitive_json else '*',
|
||||
mask_word=sensitive_json['mask_word'] if 'mask_word' in sensitive_json else ''
|
||||
)
|
||||
else:
|
||||
self.reply_filter = pkg.qqbot.filter.ReplyFilter([])
|
||||
|
||||
@@ -125,7 +146,7 @@ class QQBotManager:
|
||||
|
||||
self.on_person_message(event)
|
||||
|
||||
go(friend_message_handler, (event,))
|
||||
self.go(friend_message_handler, event)
|
||||
|
||||
@self.bot.on(StrangerMessage)
|
||||
async def on_stranger_message(event: StrangerMessage):
|
||||
@@ -145,7 +166,7 @@ class QQBotManager:
|
||||
|
||||
self.on_person_message(event)
|
||||
|
||||
go(stranger_message_handler, (event,))
|
||||
self.go(stranger_message_handler, event)
|
||||
|
||||
@self.bot.on(GroupMessage)
|
||||
async def on_group_message(event: GroupMessage):
|
||||
@@ -165,7 +186,7 @@ class QQBotManager:
|
||||
|
||||
self.on_group_message(event)
|
||||
|
||||
go(group_message_handler, (event,))
|
||||
self.go(group_message_handler, event)
|
||||
|
||||
def unsubscribe_all():
|
||||
"""取消所有订阅
|
||||
@@ -182,6 +203,9 @@ class QQBotManager:
|
||||
|
||||
self.unsubscribe_all = unsubscribe_all
|
||||
|
||||
def go(self, func, *args, **kwargs):
|
||||
self.pool.submit(func, *args, **kwargs)
|
||||
|
||||
def first_time_init(self, mirai_http_api_config: dict):
|
||||
"""热重载后不再运行此函数"""
|
||||
|
||||
@@ -297,14 +321,19 @@ class QQBotManager:
|
||||
|
||||
if Image in event.message_chain:
|
||||
pass
|
||||
elif At(self.bot.qq) not in event.message_chain:
|
||||
check, result = check_response_rule(str(event.message_chain).strip(), event)
|
||||
|
||||
if check:
|
||||
reply = process(result.strip())
|
||||
else:
|
||||
# 直接调用
|
||||
reply = process()
|
||||
if At(self.bot.qq) in event.message_chain and response_at():
|
||||
# 直接调用
|
||||
reply = process()
|
||||
else:
|
||||
check, result = check_response_rule(str(event.message_chain).strip(), event)
|
||||
|
||||
if check:
|
||||
reply = process(result.strip())
|
||||
# 检查是否随机响应
|
||||
elif random_responding():
|
||||
logging.info("随机响应group_{}消息".format(event.group.id))
|
||||
reply = process()
|
||||
|
||||
if reply:
|
||||
return self.send(event, reply)
|
||||
|
||||
File diff suppressed because one or more lines are too long
@@ -54,7 +54,7 @@ def get_current_tag() -> str:
|
||||
return current_tag
|
||||
|
||||
|
||||
def update_all() -> bool:
|
||||
def update_all(cli: bool = False) -> bool:
|
||||
"""检查更新并下载源码"""
|
||||
current_tag = get_current_tag()
|
||||
|
||||
@@ -69,12 +69,19 @@ def update_all() -> bool:
|
||||
|
||||
if latest_rls == {}:
|
||||
latest_rls = rls
|
||||
logging.info("更新日志: {}".format(rls_notes))
|
||||
if not cli:
|
||||
logging.info("更新日志: {}".format(rls_notes))
|
||||
else:
|
||||
print("更新日志: {}".format(rls_notes))
|
||||
|
||||
if latest_rls == {}: # 没有新版本
|
||||
return False
|
||||
|
||||
# 下载最新版本的zip到temp目录
|
||||
logging.info("开始下载最新版本: {}".format(latest_rls['zipball_url']))
|
||||
if not cli:
|
||||
logging.info("开始下载最新版本: {}".format(latest_rls['zipball_url']))
|
||||
else:
|
||||
print("开始下载最新版本: {}".format(latest_rls['zipball_url']))
|
||||
zip_url = latest_rls['zipball_url']
|
||||
zip_resp = requests.get(url=zip_url)
|
||||
zip_data = zip_resp.content
|
||||
@@ -87,7 +94,10 @@ def update_all() -> bool:
|
||||
with open("temp/updater/{}.zip".format(latest_rls['tag_name']), "wb") as f:
|
||||
f.write(zip_data)
|
||||
|
||||
logging.info("下载最新版本完成: {}".format("temp/updater/{}.zip".format(latest_rls['tag_name'])))
|
||||
if not cli:
|
||||
logging.info("下载最新版本完成: {}".format("temp/updater/{}.zip".format(latest_rls['tag_name'])))
|
||||
else:
|
||||
print("下载最新版本完成: {}".format("temp/updater/{}.zip".format(latest_rls['tag_name'])))
|
||||
|
||||
# 解压zip到temp/updater/<tag_name>/
|
||||
import zipfile
|
||||
@@ -124,8 +134,11 @@ def update_all() -> bool:
|
||||
f.write(current_tag)
|
||||
|
||||
# 通知管理员
|
||||
import pkg.utils.context
|
||||
pkg.utils.context.get_qqbot_manager().notify_admin("已更新到最新版本: {}\n更新日志:\n{}\n新功能通常可以在config-template.py中看到,完整的更新日志请前往 https://github.com/RockChinQ/QChatGPT/releases 查看".format(current_tag, "\n".join(rls_notes)))
|
||||
if not cli:
|
||||
import pkg.utils.context
|
||||
pkg.utils.context.get_qqbot_manager().notify_admin("已更新到最新版本: {}\n更新日志:\n{}\n新功能通常可以在config-template.py中看到,完整的更新日志请前往 https://github.com/RockChinQ/QChatGPT/releases 查看".format(current_tag, "\n".join(rls_notes)))
|
||||
else:
|
||||
print("已更新到最新版本: {}\n更新日志:\n{}\n新功能通常可以在config-template.py中看到,完整的更新日志请前往 https://github.com/RockChinQ/QChatGPT/releases 查看".format(current_tag, "\n".join(rls_notes)))
|
||||
return True
|
||||
|
||||
|
||||
|
||||
Reference in New Issue
Block a user