perf: 使用异步流程提高消息处理效率 #18

This commit is contained in:
Rock Chin
2022-12-12 22:04:38 +08:00
parent d6b9994c3b
commit e0ea46e893
2 changed files with 175 additions and 132 deletions
+30 -2
View File
@@ -57,6 +57,22 @@ def get_default_prompt():
config.default_prompt != "" else '' config.default_prompt != "" else ''
# def blocked_func(lock: threading.Lock):
#
# def decorator(func):
# def wrapper(*args, **kwargs):
# print('lock acquire,{}'.format(lock))
# lock.acquire()
# try:
# return func(*args, **kwargs)
# finally:
# lock.release()
#
# return wrapper
#
# return decorator
# 通用的OpenAI API交互session # 通用的OpenAI API交互session
# session内部保留了对话的上下文, # session内部保留了对话的上下文,
# 收到用户消息后,将上下文提交给OpenAI API生成回复 # 收到用户消息后,将上下文提交给OpenAI API生成回复
@@ -74,6 +90,16 @@ class Session:
just_switched_to_exist_session = False just_switched_to_exist_session = False
response_lock = threading.Lock()
# 加锁
def acquire_response_lock(self):
self.response_lock.acquire()
# 释放锁
def release_response_lock(self):
self.response_lock.release()
def __init__(self, name: str): def __init__(self, name: str):
self.name = name self.name = name
self.create_timestamp = int(time.time()) self.create_timestamp = int(time.time())
@@ -188,6 +214,8 @@ class Session:
self.last_interact_timestamp = int(time.time()) self.last_interact_timestamp = int(time.time())
self.just_switched_to_exist_session = False self.just_switched_to_exist_session = False
self.response_lock = threading.Lock()
if schedule_new: if schedule_new:
self.schedule() self.schedule()
@@ -207,7 +235,7 @@ class Session:
self.last_interact_timestamp = last_one['last_interact_timestamp'] self.last_interact_timestamp = last_one['last_interact_timestamp']
self.prompt = last_one['prompt'] self.prompt = last_one['prompt']
just_switched = True self.just_switched_to_exist_session = True
return self return self
# 切换到下一个session # 切换到下一个session
@@ -222,7 +250,7 @@ class Session:
self.last_interact_timestamp = next_one['last_interact_timestamp'] self.last_interact_timestamp = next_one['last_interact_timestamp']
self.prompt = next_one['prompt'] self.prompt = next_one['prompt']
just_switched = True self.just_switched_to_exist_session = True
return self return self
def list_history(self, capacity: int = 10, page: int = 0): def list_history(self, capacity: int = 10, page: int = 0):
+35 -20
View File
@@ -21,6 +21,12 @@ inst = None
processing = [] processing = []
# 并行运行
def go(func, args=()):
thread = threading.Thread(target=func, args=args, daemon=True)
thread.start()
# 控制QQ消息输入输出的类 # 控制QQ消息输入输出的类
class QQBotManager: class QQBotManager:
timeout = 60 timeout = 60
@@ -54,15 +60,15 @@ class QQBotManager:
@bot.on(FriendMessage) @bot.on(FriendMessage)
async def on_friend_message(event: FriendMessage): async def on_friend_message(event: FriendMessage):
return await self.on_person_message(event) go(self.on_person_message, (event,))
@bot.on(StrangerMessage) @bot.on(StrangerMessage)
async def on_stranger_message(event: StrangerMessage): async def on_stranger_message(event: StrangerMessage):
return await self.on_person_message(event) go(self.on_person_message, (event,))
@bot.on(GroupMessage) @bot.on(GroupMessage)
async def on_group_message(event: GroupMessage): async def on_group_message(event: GroupMessage):
return await self.on_group_message(event) go(self.on_group_message, (event,))
self.bot = bot self.bot = bot
@@ -72,9 +78,20 @@ class QQBotManager:
# 统一的消息处理函数 # 统一的消息处理函数
@func_set_timeout(timeout) @func_set_timeout(timeout)
def process_message(self, launcher_type: str, launcher_id: int, text_message: str) -> str: def process_message(self, launcher_type: str, launcher_id: int, text_message: str) -> str:
global processing
reply = '' reply = ''
session_name = "{}_{}".format(launcher_type, launcher_id) session_name = "{}_{}".format(launcher_type, launcher_id)
pkg.openai.session.get_session(session_name).acquire_response_lock()
try:
if session_name in processing:
return "[bot]err:正在处理中,请稍后再试"
processing.append(session_name)
try:
if text_message.startswith('!') or text_message.startswith(""): # 指令 if text_message.startswith('!') or text_message.startswith(""): # 指令
try: try:
logging.info("[{}]发起指令:{}".format(session_name, text_message[:min(20, len(text_message))] + ( logging.info("[{}]发起指令:{}".format(session_name, text_message[:min(20, len(text_message))] + (
@@ -163,7 +180,8 @@ class QQBotManager:
reply = "[bot]err:API调用额度超额,请联系作者,或等待修复" reply = "[bot]err:API调用额度超额,请联系作者,或等待修复"
except openai.error.InvalidRequestError as e: except openai.error.InvalidRequestError as e:
self.notify_admin("{}API调用参数错误:{}\n\n这可能是由于config.py中的prompt_submit_length参数或" self.notify_admin("{}API调用参数错误:{}\n\n这可能是由于config.py中的prompt_submit_length参数或"
"completion_api_params中的max_tokens参数数值过大导致的,请尝试将其降低".format(session_name, e)) "completion_api_params中的max_tokens参数数值过大导致的,请尝试将其降低".format(
session_name, e))
reply = "[bot]err:API调用参数错误,请联系作者,或等待修复" reply = "[bot]err:API调用参数错误,请联系作者,或等待修复"
except Exception as e: except Exception as e:
logging.exception(e) logging.exception(e)
@@ -172,13 +190,20 @@ class QQBotManager:
logging.info( logging.info(
"回复[{}]消息:{}".format(session_name, reply[:min(100, len(reply))] + ("..." if len(reply) > 100 else ""))) "回复[{}]消息:{}".format(session_name, reply[:min(100, len(reply))] + ("..." if len(reply) > 100 else "")))
reply = self.reply_filter.process(reply) reply = self.reply_filter.process(reply)
finally:
processing.remove(session_name)
finally:
pkg.openai.session.get_session(session_name).release_response_lock()
return reply return reply
def send(self, event, msg):
asyncio.run(self.bot.send(event, msg))
# 私聊消息处理 # 私聊消息处理
async def on_person_message(self, event: MessageEvent): def on_person_message(self, event: MessageEvent):
global processing global processing
if "person_{}".format(event.sender.id) in processing:
return await self.bot.send(event, "err:正在处理中,请稍后再试")
reply = '' reply = ''
@@ -188,9 +213,6 @@ class QQBotManager:
if Image in event.message_chain: if Image in event.message_chain:
pass pass
else: else:
processing.append("person_{}".format(event.sender.id))
try:
# 超时则重试,重试超过次数则放弃 # 超时则重试,重试超过次数则放弃
failed = 0 failed = 0
for i in range(self.retry): for i in range(self.retry):
@@ -203,17 +225,13 @@ class QQBotManager:
if failed == self.retry: if failed == self.retry:
reply = "[bot]err:请求超时" reply = "[bot]err:请求超时"
finally:
processing.remove("person_{}".format(event.sender.id))
if reply != '': if reply != '':
return await self.bot.send(event, reply) return self.send(event, reply)
# 群消息处理 # 群消息处理
async def on_group_message(self, event: GroupMessage): def on_group_message(self, event: GroupMessage):
global processing global processing
if "group_{}".format(event.group.id) in processing:
return await self.bot.send(event, "err:正在处理中,请稍后再试")
reply = '' reply = ''
@@ -226,7 +244,6 @@ class QQBotManager:
processing.append("group_{}".format(event.sender.id)) processing.append("group_{}".format(event.sender.id))
try:
# 超时则重试,重试超过次数则放弃 # 超时则重试,重试超过次数则放弃
failed = 0 failed = 0
for i in range(self.retry): for i in range(self.retry):
@@ -239,11 +256,9 @@ class QQBotManager:
if failed == self.retry: if failed == self.retry:
reply = "err:请求超时" reply = "err:请求超时"
finally:
processing.remove("group_{}".format(event.sender.id))
if reply != '': if reply != '':
return await self.bot.send(event, reply) return self.send(event, reply)
# 通知系统管理员 # 通知系统管理员
def notify_admin(self, message: str): def notify_admin(self, message: str):