Merge pull request #673 from RockChinQ/refactor/asyncio/control-flow

Refactor: 请求处理控制流
This commit is contained in:
Junyan Qin
2024-01-28 18:41:59 +08:00
committed by GitHub
134 changed files with 2828 additions and 3268 deletions

View File

@@ -1,58 +0,0 @@
name: Update cmdpriv-template
on:
push:
paths:
- 'pkg/qqbot/cmds/**'
pull_request:
types: [closed]
paths:
- 'pkg/qqbot/cmds/**'
jobs:
update-cmdpriv-template:
if: github.event.pull_request.merged == true || github.event_name == 'push'
runs-on: ubuntu-latest
steps:
- name: Checkout repository
uses: actions/checkout@v2
- name: Set up Python
uses: actions/setup-python@v2
with:
python-version: 3.10.13
- name: Install dependencies
run: |
python -m pip install --upgrade yiri-mirai-rc openai>=1.0.0 colorlog func_timeout dulwich Pillow CallingGPT tiktoken
python -m pip install -U openai>=1.0.0
- name: Copy Scripts
run: |
cp res/scripts/generate_cmdpriv_template.py .
- name: Generate Files
run: |
python main.py
- name: Run generate_cmdpriv_template.py
run: python3 generate_cmdpriv_template.py
- name: Check for changes in cmdpriv-template.json
id: check_changes
run: |
if git diff --name-only | grep -q "res/templates/cmdpriv-template.json"; then
echo "::set-output name=changes_detected::true"
else
echo "::set-output name=changes_detected::false"
fi
- name: Commit changes to cmdpriv-template.json
if: steps.check_changes.outputs.changes_detected == 'true'
run: |
git config --global user.name "GitHub Actions Bot"
git config --global user.email "<github-actions@github.com>"
git add res/templates/cmdpriv-template.json
git commit -m "Update cmdpriv-template.json"
git push

View File

@@ -1,38 +0,0 @@
from __future__ import annotations
import logging
from ..qqbot import manager as qqbot_mgr
from ..openai import manager as openai_mgr
from ..config import manager as config_mgr
from ..database import manager as database_mgr
from ..utils.center import v2 as center_mgr
from ..plugin import host as plugin_host
class Application:
im_mgr: qqbot_mgr.QQBotManager = None
llm_mgr: openai_mgr.OpenAIInteract = None
cfg_mgr: config_mgr.ConfigManager = None
tips_mgr: config_mgr.ConfigManager = None
db_mgr: database_mgr.DatabaseManager = None
ctr_mgr: center_mgr.V2CenterAPI = None
logger: logging.Logger = None
def __init__(self):
pass
async def initialize(self):
await self.im_mgr.initialize()
async def run(self):
# TODO make it async
plugin_host.initialize_plugins()
await self.im_mgr.run()

View File

@@ -1,54 +0,0 @@
import logging
import os
import sys
import time
import colorlog
log_colors_config = {
'DEBUG': 'green', # cyan white
'INFO': 'white',
'WARNING': 'yellow',
'ERROR': 'red',
'CRITICAL': 'cyan',
}
async def init_logging() -> logging.Logger:
level = logging.INFO
if 'DEBUG' in os.environ and os.environ['DEBUG'] in ['true', '1']:
level = logging.DEBUG
log_file_name = "logs/qcg-%s.log" % time.strftime("%Y-%m-%d-%H-%M-%S", time.localtime())
qcg_logger = logging.getLogger("qcg")
qcg_logger.setLevel(level)
log_handlers: logging.Handler = [
logging.StreamHandler(sys.stdout),
logging.FileHandler(log_file_name)
]
for handler in log_handlers:
handler.setLevel(level)
handler.setFormatter(
colorlog.ColoredFormatter(
fmt="[%(asctime)s.%(msecs)03d] %(pathname)s (%(lineno)d) - [%(levelname)s] :\n%(message)s",
datefmt="%Y-%m-%d %H:%M:%S",
log_colors=log_colors_config
)
)
qcg_logger.addHandler(handler)
logging.basicConfig(level=level, # 设置日志输出格式
format="[DEPR][%(asctime)s.%(msecs)03d] %(pathname)s (%(lineno)d) - [%(levelname)s] :\n%(message)s",
# 日志输出的格式
# -8表示占位符让输出左对齐输出长度都为8位
datefmt="%Y-%m-%d %H:%M:%S" # 时间输出的格式
)
return qcg_logger

108
pkg/command/cmdmgr.py Normal file
View File

@@ -0,0 +1,108 @@
from __future__ import annotations
import typing
from ..core import app, entities as core_entities
from ..openai import entities as llm_entities
from ..openai.session import entities as session_entities
from . import entities, operator, errors
from .operators import func, plugin, default, reset, list as list_cmd, last, next, delc, resend, prompt, cfg, cmd, help, version, update
class CommandManager:
"""命令管理器
"""
ap: app.Application
cmd_list: list[operator.CommandOperator]
def __init__(self, ap: app.Application):
self.ap = ap
async def initialize(self):
# 实例化所有类
self.cmd_list = [cls(self.ap) for cls in operator.preregistered_operators]
# 设置所有类的子节点
for cmd in self.cmd_list:
cmd.children = [child for child in self.cmd_list if child.parent_class == cmd.__class__]
# 初始化所有类
for cmd in self.cmd_list:
await cmd.initialize()
async def _execute(
self,
context: entities.ExecuteContext,
operator_list: list[operator.CommandOperator],
operator: operator.CommandOperator = None
) -> typing.AsyncGenerator[entities.CommandReturn, None]:
"""执行命令
"""
found = False
if len(context.crt_params) > 0:
for oper in operator_list:
if (context.crt_params[0] == oper.name \
or context.crt_params[0] in oper.alias) \
and (oper.parent_class is None or oper.parent_class == operator.__class__):
found = True
context.crt_command = context.crt_params[0]
context.crt_params = context.crt_params[1:]
async for ret in self._execute(
context,
oper.children,
oper
):
yield ret
break
if not found:
if operator is None:
yield entities.CommandReturn(
error=errors.CommandNotFoundError(context.crt_params[0])
)
else:
if operator.lowest_privilege > context.privilege:
yield entities.CommandReturn(
error=errors.CommandPrivilegeError(operator.name)
)
else:
async for ret in operator.execute(context):
yield ret
async def execute(
self,
command_text: str,
query: core_entities.Query,
session: session_entities.Session
) -> typing.AsyncGenerator[entities.CommandReturn, None]:
"""执行命令
"""
privilege = 1
if query.sender_id == self.ap.cfg_mgr.data['admin_qq'] \
or query.sender_id in self.ap.cfg_mgr['admin_qq']:
privilege = 2
ctx = entities.ExecuteContext(
query=query,
session=session,
command_text=command_text,
command='',
crt_command='',
params=command_text.split(' '),
crt_params=command_text.split(' '),
privilege=privilege
)
async for ret in self._execute(
ctx,
self.cmd_list
):
yield ret

43
pkg/command/entities.py Normal file
View File

@@ -0,0 +1,43 @@
from __future__ import annotations
import typing
import pydantic
import mirai
from ..core import app, entities as core_entities
from ..openai.session import entities as session_entities
from . import errors, operator
class CommandReturn(pydantic.BaseModel):
text: typing.Optional[str]
"""文本
"""
image: typing.Optional[mirai.Image]
error: typing.Optional[errors.CommandError]= None
class Config:
arbitrary_types_allowed = True
class ExecuteContext(pydantic.BaseModel):
query: core_entities.Query
session: session_entities.Session
command_text: str
command: str
crt_command: str
params: list[str]
crt_params: list[str]
privilege: int

33
pkg/command/errors.py Normal file
View File

@@ -0,0 +1,33 @@
class CommandError(Exception):
def __init__(self, message: str = None):
self.message = message
def __str__(self):
return self.message
class CommandNotFoundError(CommandError):
def __init__(self, message: str = None):
super().__init__("未知命令: "+message)
class CommandPrivilegeError(CommandError):
def __init__(self, message: str = None):
super().__init__("权限不足: "+message)
class ParamNotEnoughError(CommandError):
def __init__(self, message: str = None):
super().__init__("参数不足: "+message)
class CommandOperationError(CommandError):
def __init__(self, message: str = None):
super().__init__("操作失败: "+message)

75
pkg/command/operator.py Normal file
View File

@@ -0,0 +1,75 @@
from __future__ import annotations
import typing
import abc
from ..core import app, entities as core_entities
from ..openai.session import entities as session_entities
from . import entities
preregistered_operators: list[typing.Type[CommandOperator]] = []
def operator_class(
name: str,
help: str,
usage: str = None,
alias: list[str] = [],
privilege: int=1, # 1为普通用户2为管理员
parent_class: typing.Type[CommandOperator] = None
) -> typing.Callable[[typing.Type[CommandOperator]], typing.Type[CommandOperator]]:
def decorator(cls: typing.Type[CommandOperator]) -> typing.Type[CommandOperator]:
cls.name = name
cls.alias = alias
cls.help = help
cls.usage = usage
cls.parent_class = parent_class
preregistered_operators.append(cls)
return cls
return decorator
class CommandOperator(metaclass=abc.ABCMeta):
"""命令算子
"""
ap: app.Application
name: str
"""名称,搜索到时若符合则使用"""
alias: list[str]
"""同name"""
help: str
"""此节点的帮助信息"""
usage: str = None
parent_class: typing.Type[CommandOperator] | None = None
"""父节点类。标记以供管理器在初始化时编织父子关系。"""
lowest_privilege: int = 0
"""最低权限。若权限低于此值,则不予执行。"""
children: list[CommandOperator]
"""子节点。解析命令时,若节点有子节点,则以下一个参数去匹配子节点,
若有匹配中的,转移到子节点中执行,若没有匹配中的或没有子节点,执行此节点。"""
def __init__(self, ap: app.Application):
self.ap = ap
self.children = []
async def initialize(self):
pass
@abc.abstractmethod
async def execute(
self,
context: entities.ExecuteContext
) -> typing.AsyncGenerator[entities.CommandReturn, None]:
pass

View File

@@ -0,0 +1,98 @@
from __future__ import annotations
import typing
import json
from .. import operator, entities, cmdmgr, errors
@operator.operator_class(
name="cfg",
help="配置项管理",
usage='!cfg <配置项> [配置值]\n!cfg all',
privilege=2
)
class CfgOperator(operator.CommandOperator):
async def execute(
self,
context: entities.ExecuteContext
) -> typing.AsyncGenerator[entities.CommandReturn, None]:
"""执行
"""
reply = ''
params = context.crt_params
cfg_mgr = self.ap.cfg_mgr
false = False
true = True
reply_str = ""
if len(params) == 0:
yield entities.CommandReturn(error=errors.ParamNotEnoughError('请提供配置项名称'))
else:
cfg_name = params[0]
if cfg_name == 'all':
reply_str = "[bot]所有配置项:\n\n"
for cfg in cfg_mgr.data.keys():
if not cfg.startswith('__') and not cfg == 'logging':
# 根据配置项类型进行格式化如果是字典则转换为json并格式化
if isinstance(cfg_mgr.data[cfg], str):
reply_str += "{}: \"{}\"\n".format(cfg, cfg_mgr.data[cfg])
elif isinstance(cfg_mgr.data[cfg], dict):
# 不进行unicode转义并格式化
reply_str += "{}: {}\n".format(cfg,
json.dumps(cfg_mgr.data[cfg],
ensure_ascii=False, indent=4))
else:
reply_str += "{}: {}\n".format(cfg, cfg_mgr.data[cfg])
yield entities.CommandReturn(text=reply_str)
else:
cfg_entry_path = cfg_name.split('.')
try:
if len(params) == 1: # 未指定配置值,返回配置项值
cfg_entry = cfg_mgr.data[cfg_entry_path[0]]
if len(cfg_entry_path) > 1:
for i in range(1, len(cfg_entry_path)):
cfg_entry = cfg_entry[cfg_entry_path[i]]
if isinstance(cfg_entry, str):
reply_str = "[bot]配置项{}: \"{}\"\n".format(cfg_name, cfg_entry)
elif isinstance(cfg_entry, dict):
reply_str = "[bot]配置项{}: {}\n".format(cfg_name,
json.dumps(cfg_entry,
ensure_ascii=False, indent=4))
else:
reply_str = "[bot]配置项{}: {}\n".format(cfg_name, cfg_entry)
yield entities.CommandReturn(text=reply_str)
else:
cfg_value = " ".join(params[1:])
cfg_value = eval(cfg_value)
cfg_entry = cfg_mgr.data[cfg_entry_path[0]]
if len(cfg_entry_path) > 1:
for i in range(1, len(cfg_entry_path) - 1):
cfg_entry = cfg_entry[cfg_entry_path[i]]
if isinstance(cfg_entry[cfg_entry_path[-1]], type(cfg_value)):
cfg_entry[cfg_entry_path[-1]] = cfg_value
yield entities.CommandReturn(text="配置项{}修改成功".format(cfg_name))
else:
# reply = ["[bot]err:配置项{}类型不匹配".format(cfg_name)]
yield entities.CommandReturn(error=errors.CommandOperationError("配置项{}类型不匹配".format(cfg_name)))
else:
cfg_mgr.data[cfg_entry_path[0]] = cfg_value
# reply = ["[bot]配置项{}修改成功".format(cfg_name)]
yield entities.CommandReturn(text="配置项{}修改成功".format(cfg_name))
except KeyError:
# reply = ["[bot]err:未找到配置项 {}".format(cfg_name)]
yield entities.CommandReturn(error=errors.CommandOperationError("未找到配置项 {}".format(cfg_name)))
except NameError:
# reply = ["[bot]err:值{}不合法(字符串需要使用双引号包裹)".format(cfg_value)]
yield entities.CommandReturn(error=errors.CommandOperationError("{}不合法(字符串需要使用双引号包裹)".format(cfg_value)))
except ValueError:
# reply = ["[bot]err:未找到配置项 {}".format(cfg_name)]
yield entities.CommandReturn(error=errors.CommandOperationError("未找到配置项 {}".format(cfg_name)))

View File

@@ -0,0 +1,50 @@
from __future__ import annotations
import typing
from .. import operator, entities, cmdmgr, errors
@operator.operator_class(
name="cmd",
help='显示命令列表',
usage='!cmd\n!cmd <命令名称>'
)
class CmdOperator(operator.CommandOperator):
"""命令列表
"""
async def execute(
self,
context: entities.ExecuteContext
) -> typing.AsyncGenerator[entities.CommandReturn, None]:
"""执行
"""
if len(context.crt_params) == 0:
reply_str = "当前所有命令: \n\n"
for cmd in self.ap.cmd_mgr.cmd_list:
if cmd.parent_class is None:
reply_str += f"{cmd.name}: {cmd.help}\n"
reply_str += "\n使用 !cmd <命令名称> 查看命令的详细帮助"
yield entities.CommandReturn(text=reply_str.strip())
else:
cmd_name = context.crt_params[0]
cmd = None
for _cmd in self.ap.cmd_mgr.cmd_list:
if (cmd_name == _cmd.name or cmd_name in _cmd.alias) and (_cmd.parent_class is None):
cmd = _cmd
break
if cmd is None:
yield entities.CommandReturn(error=errors.CommandNotFoundError(cmd_name))
else:
reply_str = f"{cmd.name}: {cmd.help}\n\n"
reply_str += f"使用方法: \n{cmd.usage}"
yield entities.CommandReturn(text=reply_str.strip())

View File

@@ -0,0 +1,62 @@
from __future__ import annotations
import typing
import traceback
from .. import operator, entities, cmdmgr, errors
@operator.operator_class(
name="default",
help="操作情景预设",
usage='!default\n!default set <指定情景预设为默认>'
)
class DefaultOperator(operator.CommandOperator):
async def execute(
self,
context: entities.ExecuteContext
) -> typing.AsyncGenerator[entities.CommandReturn, None]:
reply_str = "当前所有情景预设: \n\n"
for prompt in self.ap.prompt_mgr.get_all_prompts():
content = ""
for msg in prompt.messages:
content += f" {msg.role}: {msg.content}"
reply_str += f"名称: {prompt.name}\n内容: \n{content}\n\n"
reply_str += f"当前会话使用的是: {context.session.use_prompt_name}"
yield entities.CommandReturn(text=reply_str.strip())
@operator.operator_class(
name="set",
help="设置当前会话默认情景预设",
parent_class=DefaultOperator
)
class DefaultSetOperator(operator.CommandOperator):
async def execute(
self,
context: entities.ExecuteContext
) -> typing.AsyncGenerator[entities.CommandReturn, None]:
if len(context.crt_params) == 0:
yield entities.CommandReturn(error=errors.ParamNotEnoughError('请提供情景预设名称'))
else:
prompt_name = context.crt_params[0]
try:
prompt = await self.ap.prompt_mgr.get_prompt_by_prefix(prompt_name)
if prompt is None:
yield entities.CommandReturn(error=errors.CommandError("设置当前会话默认情景预设失败: 未找到情景预设 {}".format(prompt_name)))
else:
context.session.use_prompt_name = prompt.name
yield entities.CommandReturn(text=f"已设置当前会话默认情景预设为 {prompt_name}, !reset 后生效")
except Exception as e:
traceback.print_exc()
yield entities.CommandReturn(error=errors.CommandError("设置当前会话默认情景预设失败: "+str(e)))

View File

@@ -0,0 +1,62 @@
from __future__ import annotations
import typing
import datetime
from .. import operator, entities, cmdmgr, errors
@operator.operator_class(
name="del",
help="删除当前会话的历史记录",
usage='!del <序号>\n!del all'
)
class DelOperator(operator.CommandOperator):
async def execute(
self,
context: entities.ExecuteContext
) -> typing.AsyncGenerator[entities.CommandReturn, None]:
if context.session.conversations:
delete_index = 0
if len(context.crt_params) > 0:
try:
delete_index = int(context.crt_params[0])
except:
yield entities.CommandReturn(error=errors.CommandOperationError('索引必须是整数'))
return
if delete_index < 0 or delete_index >= len(context.session.conversations):
yield entities.CommandReturn(error=errors.CommandOperationError('索引超出范围'))
return
# 倒序
to_delete_index = len(context.session.conversations)-1-delete_index
if context.session.conversations[to_delete_index] == context.session.using_conversation:
context.session.using_conversation = None
del context.session.conversations[to_delete_index]
yield entities.CommandReturn(text=f"已删除对话: {delete_index}")
else:
yield entities.CommandReturn(error=errors.CommandOperationError('当前没有对话'))
@operator.operator_class(
name="all",
help="删除此会话的所有历史记录",
parent_class=DelOperator
)
class DelAllOperator(operator.CommandOperator):
async def execute(
self,
context: entities.ExecuteContext
) -> typing.AsyncGenerator[entities.CommandReturn, None]:
context.session.conversations = []
context.session.using_conversation = None
yield entities.CommandReturn(text="已删除所有对话")

View File

@@ -0,0 +1,25 @@
from __future__ import annotations
from typing import AsyncGenerator
from .. import operator, entities, cmdmgr
from ...plugin import host as plugin_host
@operator.operator_class(name="func", help="查看所有已注册的内容函数", usage='!func')
class FuncOperator(operator.CommandOperator):
async def execute(
self, context: entities.ExecuteContext
) -> AsyncGenerator[entities.CommandReturn, None]:
reply_str = "当前已加载的内容函数: \n\n"
index = 1
for func in self.ap.tool_mgr.all_functions:
reply_str += "{}. {}{}:\n{}\n\n".format(
index,
("(已禁用) " if not func.enable else ""),
func.name,
func.description,
)
index += 1
yield entities.CommandReturn(text=reply_str)

View File

@@ -0,0 +1,23 @@
from __future__ import annotations
import typing
from .. import operator, entities, cmdmgr, errors
@operator.operator_class(
name='help',
help='显示帮助',
usage='!help\n!help <命令名称>'
)
class HelpOperator(operator.CommandOperator):
async def execute(
self,
context: entities.ExecuteContext
) -> typing.AsyncGenerator[entities.CommandReturn, None]:
help = self.ap.tips_mgr.data['help_message']
help += '\n发送命令 !cmd 可查看命令列表'
yield entities.CommandReturn(text=help)

View File

@@ -0,0 +1,36 @@
from __future__ import annotations
import typing
import datetime
from .. import operator, entities, cmdmgr, errors
@operator.operator_class(
name="last",
help="切换到前一个对话",
usage='!last'
)
class LastOperator(operator.CommandOperator):
async def execute(
self,
context: entities.ExecuteContext
) -> typing.AsyncGenerator[entities.CommandReturn, None]:
if context.session.conversations:
# 找到当前会话的上一个会话
for index in range(len(context.session.conversations)-1, -1, -1):
if context.session.conversations[index] == context.session.using_conversation:
if index == 0:
yield entities.CommandReturn(error=errors.CommandOperationError('已经是第一个对话了'))
return
else:
context.session.using_conversation = context.session.conversations[index-1]
time_str = context.session.using_conversation.create_time.strftime("%Y-%m-%d %H:%M:%S")
yield entities.CommandReturn(text=f"已切换到上一个对话: {index} {time_str}: {context.session.using_conversation.messages[0].content}")
return
else:
yield entities.CommandReturn(error=errors.CommandOperationError('当前没有对话'))

View File

@@ -0,0 +1,56 @@
from __future__ import annotations
import typing
import datetime
from .. import operator, entities, cmdmgr, errors
@operator.operator_class(
name="list",
help="列出此会话中的所有历史对话",
usage='!list\n!list <页码>'
)
class ListOperator(operator.CommandOperator):
async def execute(
self,
context: entities.ExecuteContext
) -> typing.AsyncGenerator[entities.CommandReturn, None]:
page = 0
if len(context.crt_params) > 0:
try:
page = int(context.crt_params[0]-1)
except:
yield entities.CommandReturn(error=errors.CommandOperationError('页码应为整数'))
return
record_per_page = 10
content = ''
index = 0
using_conv_index = 0
for conv in context.session.conversations[::-1]:
time_str = conv.create_time.strftime("%Y-%m-%d %H:%M:%S")
if conv == context.session.using_conversation:
using_conv_index = index
if index >= page * record_per_page and index < (page + 1) * record_per_page:
content += f"{index} {time_str}: {conv.messages[0].content}\n"
index += 1
if content == '':
content = ''
else:
if context.session.using_conversation is None:
content += "\n当前处于新会话"
else:
content += f"\n当前会话: {using_conv_index} {context.session.using_conversation.create_time.strftime('%Y-%m-%d %H:%M:%S')}: {context.session.using_conversation.messages[0].content}"
yield entities.CommandReturn(text=f"{page + 1} 页 (时间倒序):\n{content}")

View File

@@ -0,0 +1,35 @@
from __future__ import annotations
import typing
import datetime
from .. import operator, entities, cmdmgr, errors
@operator.operator_class(
name="next",
help="切换到后一个对话",
usage='!next'
)
class NextOperator(operator.CommandOperator):
async def execute(
self,
context: entities.ExecuteContext
) -> typing.AsyncGenerator[entities.CommandReturn, None]:
if context.session.conversations:
# 找到当前会话的下一个会话
for index in range(len(context.session.conversations)):
if context.session.conversations[index] == context.session.using_conversation:
if index == len(context.session.conversations)-1:
yield entities.CommandReturn(error=errors.CommandOperationError('已经是最后一个对话了'))
return
else:
context.session.using_conversation = context.session.conversations[index+1]
time_str = context.session.using_conversation.create_time.strftime("%Y-%m-%d %H:%M:%S")
yield entities.CommandReturn(text=f"已切换到后一个对话: {index} {time_str}: {context.session.using_conversation.messages[0].content}")
return
else:
yield entities.CommandReturn(error=errors.CommandOperationError('当前没有对话'))

View File

@@ -0,0 +1,239 @@
from __future__ import annotations
import typing
import traceback
from .. import operator, entities, cmdmgr, errors
from ...plugin import host as plugin_host
from ...utils import updater
from ...core import app
@operator.operator_class(
name="plugin",
help="插件操作",
usage="!plugin\n!plugin get <插件仓库地址>\n!plugin update\n!plugin del <插件名>\n!plugin on <插件名>\n!plugin off <插件名>"
)
class PluginOperator(operator.CommandOperator):
async def execute(
self,
context: entities.ExecuteContext
) -> typing.AsyncGenerator[entities.CommandReturn, None]:
plugin_list = plugin_host.__plugins__
reply_str = "所有插件({}):\n".format(len(plugin_host.__plugins__))
idx = 0
for key in plugin_host.iter_plugins_name():
plugin = plugin_list[key]
reply_str += "\n#{} {} {}\n{}\nv{}\n作者: {}\n"\
.format((idx+1), plugin['name'],
"[已禁用]" if not plugin['enabled'] else "",
plugin['description'],
plugin['version'], plugin['author'])
# TODO 从元数据调远程地址
# if updater.is_repo("/".join(plugin['path'].split('/')[:-1])):
# remote_url = updater.get_remote_url("/".join(plugin['path'].split('/')[:-1]))
# if remote_url != "https://github.com/RockChinQ/QChatGPT" and remote_url != "https://gitee.com/RockChin/QChatGPT":
# reply_str += "源码: "+remote_url+"\n"
idx += 1
yield entities.CommandReturn(text=reply_str)
@operator.operator_class(
name="get",
help="安装插件",
privilege=2,
parent_class=PluginOperator
)
class PluginGetOperator(operator.CommandOperator):
async def execute(
self,
context: entities.ExecuteContext
) -> typing.AsyncGenerator[entities.CommandReturn, None]:
if len(context.crt_params) == 0:
yield entities.CommandReturn(error=errors.ParamNotEnoughError('请提供插件仓库地址'))
else:
repo = context.crt_params[0]
yield entities.CommandReturn(text="正在安装插件...")
try:
plugin_host.install_plugin(repo)
yield entities.CommandReturn(text="插件安装成功,请重启程序以加载插件")
except Exception as e:
traceback.print_exc()
yield entities.CommandReturn(error=errors.CommandError("插件安装失败: "+str(e)))
@operator.operator_class(
name="update",
help="更新插件",
privilege=2,
parent_class=PluginOperator
)
class PluginUpdateOperator(operator.CommandOperator):
async def execute(
self,
context: entities.ExecuteContext
) -> typing.AsyncGenerator[entities.CommandReturn, None]:
if len(context.crt_params) == 0:
yield entities.CommandReturn(error=errors.ParamNotEnoughError('请提供插件名称'))
else:
plugin_name = context.crt_params[0]
try:
plugin_path_name = plugin_host.get_plugin_path_name_by_plugin_name(plugin_name)
if plugin_path_name is not None:
yield entities.CommandReturn(text="正在更新插件...")
plugin_host.update_plugin(plugin_name)
yield entities.CommandReturn(text="插件更新成功,请重启程序以加载插件")
else:
yield entities.CommandReturn(error=errors.CommandError("插件更新失败: 未找到插件"))
except Exception as e:
traceback.print_exc()
yield entities.CommandReturn(error=errors.CommandError("插件更新失败: "+str(e)))
@operator.operator_class(
name="all",
help="更新所有插件",
privilege=2,
parent_class=PluginUpdateOperator
)
class PluginUpdateAllOperator(operator.CommandOperator):
async def execute(
self,
context: entities.ExecuteContext
) -> typing.AsyncGenerator[entities.CommandReturn, None]:
try:
plugins = []
for key in plugin_host.__plugins__:
plugins.append(key)
if plugins:
yield entities.CommandReturn(text="正在更新插件...")
updated = []
try:
for plugin_name in plugins:
plugin_host.update_plugin(plugin_name)
updated.append(plugin_name)
except Exception as e:
traceback.print_exc()
yield entities.CommandReturn(error=errors.CommandError("插件更新失败: "+str(e)))
yield entities.CommandReturn(text="已更新插件: {}".format(", ".join(updated)))
else:
yield entities.CommandReturn(text="没有可更新的插件")
except Exception as e:
traceback.print_exc()
yield entities.CommandReturn(error=errors.CommandError("插件更新失败: "+str(e)))
@operator.operator_class(
name="del",
help="删除插件",
privilege=2,
parent_class=PluginOperator
)
class PluginDelOperator(operator.CommandOperator):
async def execute(
self,
context: entities.ExecuteContext
) -> typing.AsyncGenerator[entities.CommandReturn, None]:
if len(context.crt_params) == 0:
yield entities.CommandReturn(error=errors.ParamNotEnoughError('请提供插件名称'))
else:
plugin_name = context.crt_params[0]
try:
plugin_path_name = plugin_host.get_plugin_path_name_by_plugin_name(plugin_name)
if plugin_path_name is not None:
yield entities.CommandReturn(text="正在删除插件...")
plugin_host.uninstall_plugin(plugin_name)
yield entities.CommandReturn(text="插件删除成功,请重启程序以加载插件")
else:
yield entities.CommandReturn(error=errors.CommandError("插件删除失败: 未找到插件"))
except Exception as e:
traceback.print_exc()
yield entities.CommandReturn(error=errors.CommandError("插件删除失败: "+str(e)))
def update_plugin_status(plugin_name: str, new_status: bool, ap: app.Application):
if plugin_name in plugin_host.__plugins__:
plugin_host.__plugins__[plugin_name]['enabled'] = new_status
for func in ap.tool_mgr.all_functions:
if func.name.startswith(plugin_name+'-'):
func.enable = new_status
return True
else:
return False
@operator.operator_class(
name="on",
help="启用插件",
privilege=2,
parent_class=PluginOperator
)
class PluginEnableOperator(operator.CommandOperator):
async def execute(
self,
context: entities.ExecuteContext
) -> typing.AsyncGenerator[entities.CommandReturn, None]:
if len(context.crt_params) == 0:
yield entities.CommandReturn(error=errors.ParamNotEnoughError('请提供插件名称'))
else:
plugin_name = context.crt_params[0]
try:
if update_plugin_status(plugin_name, True, self.ap):
yield entities.CommandReturn(text="已启用插件: {}".format(plugin_name))
else:
yield entities.CommandReturn(error=errors.CommandError("插件状态修改失败: 未找到插件 {}".format(plugin_name)))
except Exception as e:
traceback.print_exc()
yield entities.CommandReturn(error=errors.CommandError("插件状态修改失败: "+str(e)))
@operator.operator_class(
name="off",
help="禁用插件",
privilege=2,
parent_class=PluginOperator
)
class PluginDisableOperator(operator.CommandOperator):
async def execute(
self,
context: entities.ExecuteContext
) -> typing.AsyncGenerator[entities.CommandReturn, None]:
if len(context.crt_params) == 0:
yield entities.CommandReturn(error=errors.ParamNotEnoughError('请提供插件名称'))
else:
plugin_name = context.crt_params[0]
try:
if update_plugin_status(plugin_name, False, self.ap):
yield entities.CommandReturn(text="已禁用插件: {}".format(plugin_name))
else:
yield entities.CommandReturn(error=errors.CommandError("插件状态修改失败: 未找到插件 {}".format(plugin_name)))
except Exception as e:
traceback.print_exc()
yield entities.CommandReturn(error=errors.CommandError("插件状态修改失败: "+str(e)))

View File

@@ -0,0 +1,29 @@
from __future__ import annotations
import typing
from .. import operator, entities, cmdmgr, errors
@operator.operator_class(
name="prompt",
help="查看当前对话的前文",
usage='!prompt'
)
class PromptOperator(operator.CommandOperator):
async def execute(
self,
context: entities.ExecuteContext
) -> typing.AsyncGenerator[entities.CommandReturn, None]:
"""执行
"""
if context.session.using_conversation is None:
yield entities.CommandReturn(error=errors.CommandOperationError('当前没有对话'))
else:
reply_str = '当前对话所有内容:\n\n'
for msg in context.session.using_conversation.messages:
reply_str += f"{msg.role}: {msg.content}\n"
yield entities.CommandReturn(text=reply_str)

View File

@@ -0,0 +1,34 @@
from __future__ import annotations
import typing
from .. import operator, entities, cmdmgr, errors
@operator.operator_class(
name="resend",
help="重发当前会话的最后一条消息",
usage='!resend'
)
class ResendOperator(operator.CommandOperator):
async def execute(
self,
context: entities.ExecuteContext
) -> typing.AsyncGenerator[entities.CommandReturn, None]:
# 回滚到最后一条用户message前
if context.session.using_conversation is None:
yield entities.CommandReturn(error=errors.CommandError("当前没有对话"))
else:
conv_msg = context.session.using_conversation.messages
# 倒序一直删到最后一条用户message
while len(conv_msg) > 0 and conv_msg[-1].role != 'user':
conv_msg.pop()
if len(conv_msg) > 0:
# 删除最后一条用户message
conv_msg.pop()
# 不重发了,提示用户已删除就行了
yield entities.CommandReturn(text="已删除最后一次请求记录")

View File

@@ -0,0 +1,23 @@
from __future__ import annotations
import typing
from .. import operator, entities, cmdmgr, errors
@operator.operator_class(
name="reset",
help="重置当前会话",
usage='!reset'
)
class ResetOperator(operator.CommandOperator):
async def execute(
self,
context: entities.ExecuteContext
) -> typing.AsyncGenerator[entities.CommandReturn, None]:
"""执行
"""
context.session.using_conversation = None
yield entities.CommandReturn(text="已重置当前会话")

View File

@@ -0,0 +1,31 @@
from __future__ import annotations
import typing
import traceback
from .. import operator, entities, cmdmgr, errors
from ...utils import updater
@operator.operator_class(
name="update",
help="更新程序",
usage='!update',
privilege=2
)
class UpdateCommand(operator.CommandOperator):
async def execute(
self,
context: entities.ExecuteContext
) -> typing.AsyncGenerator[entities.CommandReturn, None]:
try:
yield entities.CommandReturn(text="正在进行更新...")
if updater.update_all():
yield entities.CommandReturn(text="更新完成,请重启程序以应用更新")
else:
yield entities.CommandReturn(text="当前已是最新版本")
except Exception as e:
traceback.print_exc()
yield entities.CommandReturn(error=errors.CommandError("更新失败: "+str(e)))

View File

@@ -0,0 +1,28 @@
from __future__ import annotations
import typing
from .. import operator, cmdmgr, entities, errors
from ...utils import updater
@operator.operator_class(
name="version",
help="显示版本信息",
usage='!version'
)
class VersionCommand(operator.CommandOperator):
async def execute(
self,
context: entities.ExecuteContext
) -> typing.AsyncGenerator[entities.CommandReturn, None]:
reply_str = f"当前版本: \n{updater.get_current_version_info()}"
try:
if updater.is_new_version_available():
reply_str += "\n\n有新版本可用, 使用 !update 更新"
except:
pass
yield entities.CommandReturn(text=reply_str.strip())

View File

@@ -1,5 +1,6 @@
from __future__ import annotations
from . import model as file_model
from ..utils import context
from .impls import pymodule, json as json_file

71
pkg/core/app.py Normal file
View File

@@ -0,0 +1,71 @@
from __future__ import annotations
import logging
import asyncio
from ..qqbot import manager as qqbot_mgr
from ..openai.session import sessionmgr as llm_session_mgr
from ..openai.requester import modelmgr as llm_model_mgr
from ..openai.sysprompt import sysprompt as llm_prompt_mgr
from ..openai.tools import toolmgr as llm_tool_mgr
from ..config import manager as config_mgr
from ..database import manager as database_mgr
from ..utils.center import v2 as center_mgr
from ..command import cmdmgr
from ..plugin import host as plugin_host
from . import pool, controller
from ..pipeline import stagemgr
class Application:
im_mgr: qqbot_mgr.QQBotManager = None
cmd_mgr: cmdmgr.CommandManager = None
sess_mgr: llm_session_mgr.SessionManager = None
model_mgr: llm_model_mgr.ModelManager = None
prompt_mgr: llm_prompt_mgr.PromptManager = None
tool_mgr: llm_tool_mgr.ToolManager = None
cfg_mgr: config_mgr.ConfigManager = None
tips_mgr: config_mgr.ConfigManager = None
db_mgr: database_mgr.DatabaseManager = None
ctr_mgr: center_mgr.V2CenterAPI = None
query_pool: pool.QueryPool = None
ctrl: controller.Controller = None
stage_mgr: stagemgr.StageManager = None
logger: logging.Logger = None
def __init__(self):
pass
async def initialize(self):
plugin_host.initialize_plugins()
# 把现有的所有内容函数加到toolmgr里
for func in plugin_host.__callable_functions__:
self.tool_mgr.register_legacy_function(
name=func['name'],
description=func['description'],
parameters=func['parameters'],
func=plugin_host.__function_inst_map__[func['name']]
)
async def run(self):
tasks = [
asyncio.create_task(self.im_mgr.run()),
asyncio.create_task(self.ctrl.run())
]
await asyncio.wait(tasks, return_when=asyncio.FIRST_COMPLETED)

View File

@@ -3,19 +3,23 @@ from __future__ import print_function
import os
import sys
from . import files
from . import deps
from . import log
from . import config
from .bootutils import files
from .bootutils import deps
from .bootutils import log
from .bootutils import config
from . import app
from . import pool
from . import controller
from ..pipeline import stagemgr
from ..audit import identifier
from ..database import manager as db_mgr
from ..openai import manager as llm_mgr
from ..openai import session as llm_session
from ..openai import dprompt as llm_dprompt
from ..openai.session import sessionmgr as llm_session_mgr
from ..openai.requester import modelmgr as llm_model_mgr
from ..openai.sysprompt import sysprompt as llm_prompt_mgr
from ..openai.tools import toolmgr as llm_tool_mgr
from ..qqbot import manager as im_mgr
from ..qqbot.cmds import aamgr as im_cmd_aamgr
from ..command import cmdmgr
from ..plugin import host as plugin_host
from ..utils.center import v2 as center_v2
from ..utils import updater
@@ -75,17 +79,14 @@ async def make_app() -> app.Application:
if cfg_mgr.data['admin_qq'] == 0:
qcg_logger.warning("未设置管理员QQ号将无法使用管理员命令请在 config.py 中修改 admin_qq")
# TODO make it async
llm_dprompt.register_all()
im_cmd_aamgr.register_all()
im_cmd_aamgr.apply_privileges()
# 构建组建实例
ap = app.Application()
ap.logger = qcg_logger
ap.cfg_mgr = cfg_mgr
ap.tips_mgr = tips_mgr
ap.query_pool = pool.QueryPool()
center_v2_api = center_v2.V2CenterAPI(
basic_info={
"host_id": identifier.identifier['host_id'],
@@ -105,22 +106,45 @@ async def make_app() -> app.Application:
db_mgr_inst.initialize_database()
ap.db_mgr = db_mgr_inst
llm_mgr_inst = llm_mgr.OpenAIInteract(ap)
ap.llm_mgr = llm_mgr_inst
# TODO make it async
llm_session.load_sessions()
cmd_mgr_inst = cmdmgr.CommandManager(ap)
await cmd_mgr_inst.initialize()
ap.cmd_mgr = cmd_mgr_inst
im_mgr_inst = im_mgr.QQBotManager(first_time_init=True, ap=ap)
llm_model_mgr_inst = llm_model_mgr.ModelManager(ap)
await llm_model_mgr_inst.initialize()
ap.model_mgr = llm_model_mgr_inst
llm_session_mgr_inst = llm_session_mgr.SessionManager(ap)
await llm_session_mgr_inst.initialize()
ap.sess_mgr = llm_session_mgr_inst
llm_prompt_mgr_inst = llm_prompt_mgr.PromptManager(ap)
await llm_prompt_mgr_inst.initialize()
ap.prompt_mgr = llm_prompt_mgr_inst
llm_tool_mgr_inst = llm_tool_mgr.ToolManager(ap)
await llm_tool_mgr_inst.initialize()
ap.tool_mgr = llm_tool_mgr_inst
im_mgr_inst = im_mgr.QQBotManager(ap=ap)
await im_mgr_inst.initialize()
ap.im_mgr = im_mgr_inst
stage_mgr = stagemgr.StageManager(ap)
await stage_mgr.initialize()
ap.stage_mgr = stage_mgr
ctrl = controller.Controller(ap)
ap.ctrl = ctrl
# TODO make it async
plugin_host.load_plugins()
# plugin_host.initialize_plugins()
await ap.initialize()
return ap
async def main():
app_inst = await make_app()
await app_inst.initialize()
await app_inst.run()

View File

@@ -1,7 +1,9 @@
from __future__ import annotations
import json
from ..config import manager as config_mgr
from ..config.impls import pymodule
from ...config import manager as config_mgr
from ...config.impls import pymodule
load_python_module_config = config_mgr.load_python_module_config

56
pkg/core/bootutils/log.py Normal file
View File

@@ -0,0 +1,56 @@
import logging
import os
import sys
import time
import colorlog
log_colors_config = {
"DEBUG": "green", # cyan white
"INFO": "white",
"WARNING": "yellow",
"ERROR": "red",
"CRITICAL": "cyan",
}
async def init_logging() -> logging.Logger:
level = logging.INFO
if "DEBUG" in os.environ and os.environ["DEBUG"] in ["true", "1"]:
level = logging.DEBUG
log_file_name = "logs/qcg-%s.log" % time.strftime(
"%Y-%m-%d-%H-%M-%S", time.localtime()
)
qcg_logger = logging.getLogger("qcg")
qcg_logger.setLevel(level)
color_formatter = colorlog.ColoredFormatter(
fmt="%(log_color)s[%(asctime)s.%(msecs)03d] %(pathname)s (%(lineno)d) - [%(levelname)s] :\n %(message)s",
datefmt="%Y-%m-%d %H:%M:%S",
log_colors=log_colors_config,
)
stream_handler = logging.StreamHandler(sys.stdout)
log_handlers: logging.Handler = [stream_handler, logging.FileHandler(log_file_name)]
for handler in log_handlers:
handler.setLevel(level)
handler.setFormatter(color_formatter)
qcg_logger.addHandler(handler)
logging.basicConfig(
level=logging.INFO, # 设置日志输出格式
format="[DEPR][%(asctime)s.%(msecs)03d] %(pathname)s (%(lineno)d) - [%(levelname)s] :\n%(message)s",
# 日志输出的格式
# -8表示占位符让输出左对齐输出长度都为8位
datefmt="%Y-%m-%d %H:%M:%S", # 时间输出的格式
handlers=[logging.NullHandler()],
)
return qcg_logger

154
pkg/core/controller.py Normal file
View File

@@ -0,0 +1,154 @@
from __future__ import annotations
import asyncio
import typing
import traceback
from . import app, entities
from ..pipeline import entities as pipeline_entities
DEFAULT_QUERY_CONCURRENCY = 10
class Controller:
"""总控制器
"""
ap: app.Application
semaphore: asyncio.Semaphore = None
"""请求并发控制信号量"""
def __init__(self, ap: app.Application):
self.ap = ap
self.semaphore = asyncio.Semaphore(DEFAULT_QUERY_CONCURRENCY)
async def consumer(self):
"""事件处理循环
"""
try:
while True:
selected_query: entities.Query = None
# 取请求
async with self.ap.query_pool:
queries: list[entities.Query] = self.ap.query_pool.queries
for query in queries:
session = await self.ap.sess_mgr.get_session(query)
self.ap.logger.debug(f"Checking query {query} session {session}")
if not session.semaphore.locked():
selected_query = query
await session.semaphore.acquire()
break
if selected_query: # 找到了
queries.remove(selected_query)
else: # 没找到 说明:没有请求 或者 所有query对应的session都已达到并发上限
await self.ap.query_pool.condition.wait()
continue
if selected_query:
async def _process_query(selected_query):
async with self.semaphore: # 总并发上限
await self.process_query(selected_query)
async with self.ap.query_pool:
(await self.ap.sess_mgr.get_session(selected_query)).semaphore.release()
# 通知其他协程,有新的请求可以处理了
self.ap.query_pool.condition.notify_all()
asyncio.create_task(_process_query(selected_query))
except Exception as e:
self.ap.logger.error(f"事件处理循环出错: {e}")
traceback.print_exc()
async def _check_output(self, result: pipeline_entities.StageProcessResult):
"""检查输出
"""
if result.user_notice:
await self.ap.im_mgr.send(
result.user_notice
)
if result.debug_notice:
self.ap.logger.debug(result.debug_notice)
if result.console_notice:
self.ap.logger.info(result.console_notice)
async def _execute_from_stage(
self,
stage_index: int,
query: entities.Query,
):
"""从指定阶段开始执行
如何看懂这里为什么这么写?
去问 GPT-4:
Q1: 现在有一个责任链其中有多个stagequery对象在其中传递stage.process可能返回Result也有可能返回typing.AsyncGenerator[Result, None]
如果返回的是生成器需要挨个生成result检查是否result中是否要求继续如果要求继续就进行下一个stage。如果此次生成器产生的result处理完了就继续生成下一个result
调用后续的stage直到该生成器全部生成完。责任链中可能有多个stage会返回生成器
Q2: 不是这样的你可能理解有误。如果我们责任链上有这些Stage
A B C D E F G
如果所有的stage都返回Result且所有Result都要求继续那么执行顺序是
A B C D E F G
现在假设C返回的是AsyncGenerator那么执行顺序是
A B C D E F G C D E F G C D E F G ...
Q3: 但是如果不止一个stage会返回生成器呢
"""
i = stage_index
while i < len(self.ap.stage_mgr.stage_containers):
stage_container = self.ap.stage_mgr.stage_containers[i]
result = await stage_container.inst.process(query, stage_container.inst_name)
if isinstance(result, pipeline_entities.StageProcessResult): # 直接返回结果
self.ap.logger.debug(f"Stage {stage_container.inst_name} processed query {query} res {result}")
await self._check_output(result)
if result.result_type == pipeline_entities.ResultType.INTERRUPT:
self.ap.logger.debug(f"Stage {stage_container.inst_name} interrupted query {query}")
break
elif result.result_type == pipeline_entities.ResultType.CONTINUE:
query = result.new_query
elif isinstance(result, typing.AsyncGenerator): # 生成器
self.ap.logger.debug(f"Stage {stage_container.inst_name} processed query {query} gen")
async for sub_result in result:
self.ap.logger.debug(f"Stage {stage_container.inst_name} processed query {query} res {sub_result}")
await self._check_output(sub_result)
if sub_result.result_type == pipeline_entities.ResultType.INTERRUPT:
self.ap.logger.debug(f"Stage {stage_container.inst_name} interrupted query {query}")
break
elif sub_result.result_type == pipeline_entities.ResultType.CONTINUE:
query = sub_result.new_query
await self._execute_from_stage(i + 1, query)
break
i += 1
async def process_query(self, query: entities.Query):
"""处理请求
"""
self.ap.logger.debug(f"Processing query {query}")
try:
await self._execute_from_stage(0, query)
except Exception as e:
self.ap.logger.error(f"处理请求时出错 {query}: {e}")
traceback.print_exc()
finally:
self.ap.logger.debug(f"Query {query} processed")
async def run(self):
"""运行控制器
"""
await self.consumer()

41
pkg/core/entities.py Normal file
View File

@@ -0,0 +1,41 @@
from __future__ import annotations
import enum
import typing
import pydantic
import mirai
class LauncherTypes(enum.Enum):
PERSON = 'person'
"""私聊"""
GROUP = 'group'
"""群聊"""
class Query(pydantic.BaseModel):
"""一次请求的信息封装"""
query_id: int
"""请求ID"""
launcher_type: LauncherTypes
"""会话类型"""
launcher_id: int
"""会话ID"""
sender_id: int
"""发送者ID"""
message_event: mirai.MessageEvent
"""事件"""
message_chain: mirai.MessageChain
"""消息链"""
resp_message_chain: typing.Optional[mirai.MessageChain] = None
"""回复消息链"""

52
pkg/core/pool.py Normal file
View File

@@ -0,0 +1,52 @@
from __future__ import annotations
import asyncio
import mirai
from . import entities
class QueryPool:
query_id_counter: int = 0
pool_lock: asyncio.Lock
queries: list[entities.Query]
condition: asyncio.Condition
def __init__(self):
self.query_id_counter = 0
self.pool_lock = asyncio.Lock()
self.queries = []
self.condition = asyncio.Condition(self.pool_lock)
async def add_query(
self,
launcher_type: entities.LauncherTypes,
launcher_id: int,
sender_id: int,
message_event: mirai.MessageEvent,
message_chain: mirai.MessageChain
) -> entities.Query:
async with self.condition:
query = entities.Query(
query_id=self.query_id_counter,
launcher_type=launcher_type,
launcher_id=launcher_id,
sender_id=sender_id,
message_event=message_event,
message_chain=message_chain
)
self.queries.append(query)
self.query_id_counter += 1
self.condition.notify_all()
async def __aenter__(self):
await self.pool_lock.acquire()
return self
async def __aexit__(self, exc_type, exc_val, exc_tb):
self.pool_lock.release()

View File

@@ -1,134 +0,0 @@
# 多情景预设值管理
import json
import logging
import os
from ..utils import context
# __current__ = "default"
# """当前默认使用的情景预设的名称
# 由管理员使用`!default <名称>`命令切换
# """
# __prompts_from_files__ = {}
# """从文件中读取的情景预设值"""
# __scenario_from_files__ = {}
class ScenarioMode:
"""情景预设模式抽象类"""
using_prompt_name = "default"
"""新session创建时使用的prompt名称"""
prompts: dict[str, list] = {}
def __init__(self):
logging.debug("prompts: {}".format(self.prompts))
def list(self) -> dict[str, list]:
"""获取所有情景预设的名称及内容"""
return self.prompts
def get_prompt(self, name: str) -> tuple[list, str]:
"""获取指定情景预设的名称及内容"""
for key in self.prompts:
if key.startswith(name):
return self.prompts[key], key
raise Exception("没有找到情景预设: {}".format(name))
def set_using_name(self, name: str) -> str:
"""设置默认情景预设"""
for key in self.prompts:
if key.startswith(name):
self.using_prompt_name = key
return key
raise Exception("没有找到情景预设: {}".format(name))
def get_full_name(self, name: str) -> str:
"""获取完整的情景预设名称"""
for key in self.prompts:
if key.startswith(name):
return key
raise Exception("没有找到情景预设: {}".format(name))
def get_using_name(self) -> str:
"""获取默认情景预设"""
return self.using_prompt_name
class NormalScenarioMode(ScenarioMode):
"""普通情景预设模式"""
def __init__(self):
config = context.get_config_manager().data
# 加载config中的default_prompt值
if type(config['default_prompt']) == str:
self.using_prompt_name = "default"
self.prompts = {"default": [
{
"role": "system",
"content": config['default_prompt']
}
]}
elif type(config['default_prompt']) == dict:
for key in config['default_prompt']:
self.prompts[key] = [
{
"role": "system",
"content": config['default_prompt'][key]
}
]
# 从prompts/目录下的文件中载入
# 遍历文件
for file in os.listdir("prompts"):
with open(os.path.join("prompts", file), encoding="utf-8") as f:
self.prompts[file] = [
{
"role": "system",
"content": f.read()
}
]
class FullScenarioMode(ScenarioMode):
"""完整情景预设模式"""
def __init__(self):
"""从json读取所有"""
# 遍历scenario/目录下的所有文件以文件名为键文件内容中的prompt为值
for file in os.listdir("scenario"):
if file == "default-template.json":
continue
with open(os.path.join("scenario", file), encoding="utf-8") as f:
self.prompts[file] = json.load(f)["prompt"]
super().__init__()
scenario_mode_mapping = {}
"""情景预设模式名称与对象的映射"""
def register_all():
"""注册所有情景预设模式,不使用装饰器,因为装饰器的方式不支持热重载"""
global scenario_mode_mapping
scenario_mode_mapping = {
"normal": NormalScenarioMode(),
"full_scenario": FullScenarioMode()
}
def mode_inst() -> ScenarioMode:
"""获取指定名称的情景预设模式对象"""
config = context.get_config_manager().data
if config['preset_mode'] == "default":
config['preset_mode'] = "normal"
return scenario_mode_mapping[config['preset_mode']]

33
pkg/openai/entities.py Normal file
View File

@@ -0,0 +1,33 @@
from __future__ import annotations
import typing
import enum
import pydantic
class FunctionCall(pydantic.BaseModel):
name: str
arguments: str
class ToolCall(pydantic.BaseModel):
id: str
type: str
function: FunctionCall
class Message(pydantic.BaseModel):
role: str
name: typing.Optional[str] = None
content: typing.Optional[str] = None
function_call: typing.Optional[FunctionCall] = None
tool_calls: typing.Optional[list[ToolCall]] = None
tool_call_id: typing.Optional[str] = None

View File

@@ -1,46 +0,0 @@
# 封装了function calling的一些支持函数
import logging
from ..plugin import host
class ContentFunctionNotFoundError(Exception):
pass
def get_func_schema_list() -> list:
"""从plugin包中的函数结构中获取并处理成受GPT支持的格式"""
if not host.__enable_content_functions__:
return []
schemas = []
for func in host.__callable_functions__:
if func['enabled']:
fun_cp = func.copy()
del fun_cp['enabled']
schemas.append(fun_cp)
return schemas
def get_func(name: str) -> callable:
if name not in host.__function_inst_map__:
raise ContentFunctionNotFoundError("没有找到内容函数: {}".format(name))
return host.__function_inst_map__[name]
def get_func_schema(name: str) -> dict:
for func in host.__callable_functions__:
if func['name'] == name:
return func
raise ContentFunctionNotFoundError("没有找到内容函数: {}".format(name))
def execute_function(name: str, kwargs: dict) -> any:
"""执行函数调用"""
logging.debug("executing function: name='{}', kwargs={}".format(name, kwargs))
func = get_func(name)
return func(**kwargs)

View File

@@ -1,103 +0,0 @@
# 此模块提供了维护api-key的各种功能
import hashlib
import logging
from ..plugin import host as plugin_host
from ..plugin import models as plugin_models
class KeysManager:
api_key = {}
"""所有api-key"""
using_key = ""
"""当前使用的api-key"""
alerted = []
"""已提示过超额的key
记录在此以避免重复提示
"""
exceeded = []
"""已超额的key
供自动切换功能识别
"""
def get_using_key(self):
return self.using_key
def get_using_key_md5(self):
return hashlib.md5(self.using_key.encode('utf-8')).hexdigest()
def __init__(self, api_key):
assert type(api_key) == dict
self.api_key = api_key
# 从usage中删除未加载的api-key的记录
# 不删了也许会运行时添加曾经有记录的api-key
self.auto_switch()
def auto_switch(self) -> tuple[bool, str]:
"""尝试切换api-key
Returns:
是否切换成功, 切换后的api-key的别名
"""
index = 0
for key_name in self.api_key:
if self.api_key[key_name] == self.using_key:
break
index += 1
# 从当前key开始向后轮询
start_index = index
index += 1
if index >= len(self.api_key):
index = 0
while index != start_index:
key_name = list(self.api_key.keys())[index]
if self.api_key[key_name] not in self.exceeded:
self.using_key = self.api_key[key_name]
logging.debug("使用api-key:" + key_name)
# 触发插件事件
args = {
"key_name": key_name,
"key_list": self.api_key.keys()
}
_ = plugin_host.emit(plugin_models.KeySwitched, **args)
return True, key_name
index += 1
if index >= len(self.api_key):
index = 0
self.using_key = list(self.api_key.values())[start_index]
logging.debug("使用api-key:" + list(self.api_key.keys())[start_index])
return False, list(self.api_key.keys())[start_index]
def add(self, key_name, key):
self.api_key[key_name] = key
def set_current_exceeded(self):
"""设置当前使用的api-key使用量超限"""
self.exceeded.append(self.using_key)
def get_key_name(self, api_key):
"""根据api-key获取其别名"""
for key_name in self.api_key:
if self.api_key[key_name] == api_key:
return key_name
return ""

View File

@@ -1,108 +0,0 @@
from __future__ import annotations
import logging
import openai
from openai.types import images_response
from ..openai import keymgr
from ..utils import context
from ..audit import gatherer
from ..openai import modelmgr
from ..openai.api import model as api_model
from ..boot import app
class OpenAIInteract:
"""OpenAI 接口封装
将文字接口和图片接口封装供调用方使用
"""
key_mgr: keymgr.KeysManager = None
audit_mgr: gatherer.DataGatherer = None
default_image_api_params = {
"size": "256x256",
}
client: openai.Client = None
def __init__(self, ap: app.Application):
cfg= ap.cfg_mgr.data
api_key = cfg['openai_config']['api_key']
self.key_mgr = keymgr.KeysManager(api_key)
self.audit_mgr = gatherer.DataGatherer()
# 配置OpenAI proxy
openai.proxies = None # 先重置因为重载后可能需要清除proxy
if "http_proxy" in cfg['openai_config'] and cfg['openai_config']["http_proxy"] is not None:
openai.proxies = {
"http": cfg['openai_config']["http_proxy"],
"https": cfg['openai_config']["http_proxy"]
}
# 配置openai api_base
if "reverse_proxy" in cfg['openai_config'] and cfg['openai_config']["reverse_proxy"] is not None:
logging.debug("设置反向代理: "+cfg['openai_config']['reverse_proxy'])
openai.base_url = cfg['openai_config']["reverse_proxy"]
self.client = openai.Client(
api_key=self.key_mgr.get_using_key(),
base_url=openai.base_url
)
context.set_openai_manager(self)
def request_completion(self, messages: list):
"""请求补全接口回复=
"""
# 选择接口请求类
config = context.get_config_manager().data
request: api_model.RequestBase
model: str = config['completion_api_params']['model']
cp_parmas = config['completion_api_params'].copy()
del cp_parmas['model']
request = modelmgr.select_request_cls(self.client, model, messages, cp_parmas)
# 请求接口
for resp in request:
if resp['usage']['total_tokens'] > 0:
self.audit_mgr.report_text_model_usage(
model,
resp['usage']['total_tokens']
)
yield resp
def request_image(self, prompt) -> images_response.ImagesResponse:
"""请求图片接口回复
Parameters:
prompt (str): 提示语
Returns:
dict: 响应
"""
config = context.get_config_manager().data
params = config['image_api_params']
response = self.client.images.generate(
prompt=prompt,
n=1,
**params
)
self.audit_mgr.report_image_model_usage(params['size'])
return response

View File

@@ -0,0 +1,31 @@
from __future__ import annotations
import abc
import typing
from ...core import app
from ...core import entities as core_entities
from .. import entities as llm_entities
from ..session import entities as session_entities
class LLMAPIRequester(metaclass=abc.ABCMeta):
"""LLM API请求器
"""
ap: app.Application
def __init__(self, ap: app.Application):
self.ap = ap
async def initialize(self):
pass
@abc.abstractmethod
async def request(
self,
query: core_entities.Query,
conversation: session_entities.Conversation,
) -> typing.AsyncGenerator[llm_entities.Message, None]:
"""请求
"""
raise NotImplementedError

View File

@@ -0,0 +1,140 @@
from __future__ import annotations
import asyncio
import typing
import json
import openai
import openai.types.chat.chat_completion as chat_completion
from .. import api
from ....core import entities as core_entities
from ... import entities as llm_entities
from ...session import entities as session_entities
class OpenAIChatCompletion(api.LLMAPIRequester):
client: openai.AsyncClient
async def initialize(self):
self.client = openai.AsyncClient(
api_key="",
base_url=self.ap.cfg_mgr.data["openai_config"]["reverse_proxy"],
timeout=self.ap.cfg_mgr.data["process_message_timeout"],
)
async def _req(
self,
args: dict,
) -> chat_completion.ChatCompletion:
self.ap.logger.debug(f"req chat_completion with args {args}")
return await self.client.chat.completions.create(**args)
async def _make_msg(
self,
chat_completion: chat_completion.ChatCompletion,
) -> llm_entities.Message:
chatcmpl_message = chat_completion.choices[0].message.dict()
message = llm_entities.Message(**chatcmpl_message)
return message
async def _closure(
self,
req_messages: list[dict],
conversation: session_entities.Conversation,
user_text: str = None,
function_ret: str = None,
) -> llm_entities.Message:
self.client.api_key = conversation.use_model.token_mgr.get_token()
args = self.ap.cfg_mgr.data["completion_api_params"].copy()
args["model"] = conversation.use_model.name
tools = await self.ap.tool_mgr.generate_tools_for_openai(conversation)
# tools = [
# {
# "type": "function",
# "function": {
# "name": "get_current_weather",
# "description": "Get the current weather in a given location",
# "parameters": {
# "type": "object",
# "properties": {
# "location": {
# "type": "string",
# "description": "The city and state, e.g. San Francisco, CA",
# },
# "unit": {
# "type": "string",
# "enum": ["celsius", "fahrenheit"],
# },
# },
# "required": ["location"],
# },
# },
# }
# ]
if tools:
args["tools"] = tools
# 设置此次请求中的messages
messages = req_messages
args["messages"] = messages
# 发送请求
resp = await self._req(args)
# 处理请求结果
message = await self._make_msg(resp)
return message
async def request(
self, query: core_entities.Query, conversation: session_entities.Conversation
) -> typing.AsyncGenerator[llm_entities.Message, None]:
"""请求"""
pending_tool_calls = []
req_messages = [
m.dict(exclude_none=True) for m in conversation.prompt.messages
] + [m.dict(exclude_none=True) for m in conversation.messages]
# req_messages.append({"role": "user", "content": str(query.message_chain)})
msg = await self._closure(req_messages, conversation)
yield msg
pending_tool_calls = msg.tool_calls
req_messages.append(msg.dict(exclude_none=True))
while pending_tool_calls:
for tool_call in pending_tool_calls:
func = tool_call.function
parameters = json.loads(func.arguments)
func_ret = await self.ap.tool_mgr.execute_func_call(
query, func.name, parameters
)
msg = llm_entities.Message(
role="tool", content=json.dumps(func_ret, ensure_ascii=False), tool_call_id=tool_call.id
)
yield msg
req_messages.append(msg.dict(exclude_none=True))
# 处理完所有调用,继续请求
msg = await self._closure(req_messages, conversation)
yield msg
pending_tool_calls = msg.tool_calls
req_messages.append(msg.dict(exclude_none=True))

View File

@@ -0,0 +1,23 @@
import typing
import pydantic
from . import api
from . import token
class LLMModelInfo(pydantic.BaseModel):
"""模型"""
name: str
provider: str
token_mgr: token.TokenManager
requester: api.LLMAPIRequester
function_call_supported: typing.Optional[bool] = False
class Config:
arbitrary_types_allowed = True

View File

@@ -0,0 +1,41 @@
from __future__ import annotations
from . import entities
from ...core import app
from .apis import chatcmpl
from . import token
class ModelManager:
ap: app.Application
model_list: list[entities.LLMModelInfo]
def __init__(self, ap: app.Application):
self.ap = ap
self.model_list = []
async def initialize(self):
openai_chat_completion = chatcmpl.OpenAIChatCompletion(self.ap)
await openai_chat_completion.initialize()
openai_token_mgr = token.TokenManager(self.ap, list(self.ap.cfg_mgr.data['openai_config']['api_key'].values()))
self.model_list.append(
entities.LLMModelInfo(
name="gpt-3.5-turbo",
provider="openai",
token_mgr=openai_token_mgr,
requester=openai_chat_completion,
function_call_supported=True
)
)
async def get_model_by_name(self, name: str) -> entities.LLMModelInfo:
"""通过名称获取模型
"""
for model in self.model_list:
if model.name == name:
return model
raise ValueError(f"Model {name} not found")

View File

@@ -0,0 +1,25 @@
from __future__ import annotations
import typing
import pydantic
class TokenManager():
provider: str
tokens: list[str]
using_token_index: typing.Optional[int] = 0
def __init__(self, provider: str, tokens: list[str]):
self.provider = provider
self.tokens = tokens
self.using_token_index = 0
def get_token(self) -> str:
return self.tokens[self.using_token_index]
def next_token(self):
self.using_token_index = (self.using_token_index + 1) % len(self.tokens)

View File

@@ -1,504 +0,0 @@
"""主线使用的会话管理模块
每个人、每个群单独一个sessionsession内部保留了对话的上下文
"""
import logging
import threading
import time
import json
from ..openai import manager as openai_manager
from ..openai import modelmgr as openai_modelmgr
from ..database import manager as database_manager
from ..utils import context as context
from ..plugin import host as plugin_host
from ..plugin import models as plugin_models
# 运行时保存的所有session
sessions = {}
class SessionOfflineStatus:
ON_GOING = 'on_going'
EXPLICITLY_CLOSED = 'explicitly_closed'
# 从数据加载session
def load_sessions():
"""从数据库加载sessions"""
global sessions
db_inst = context.get_database_manager()
session_data = db_inst.load_valid_sessions()
for session_name in session_data:
logging.debug('加载session: {}'.format(session_name))
temp_session = Session(session_name)
temp_session.name = session_name
temp_session.create_timestamp = session_data[session_name]['create_timestamp']
temp_session.last_interact_timestamp = session_data[session_name]['last_interact_timestamp']
temp_session.prompt = json.loads(session_data[session_name]['prompt'])
temp_session.token_counts = json.loads(session_data[session_name]['token_counts'])
temp_session.default_prompt = json.loads(session_data[session_name]['default_prompt']) if \
session_data[session_name]['default_prompt'] else []
sessions[session_name] = temp_session
# 获取指定名称的session如果不存在则创建一个新的
def get_session(session_name: str) -> 'Session':
global sessions
if session_name not in sessions:
sessions[session_name] = Session(session_name)
return sessions[session_name]
def dump_session(session_name: str):
global sessions
if session_name in sessions:
assert isinstance(sessions[session_name], Session)
sessions[session_name].persistence()
del sessions[session_name]
# 通用的OpenAI API交互session
# session内部保留了对话的上下文
# 收到用户消息后将上下文提交给OpenAI API生成回复
class Session:
name = ''
prompt = []
"""使用list来保存会话中的回合"""
default_prompt = []
"""本session的默认prompt"""
create_timestamp = 0
"""会话创建时间"""
last_interact_timestamp = 0
"""上次交互(产生回复)时间"""
just_switched_to_exist_session = False
response_lock = None
# 加锁
def acquire_response_lock(self):
logging.debug('{},lock acquire,{}'.format(self.name, self.response_lock))
self.response_lock.acquire()
logging.debug('{},lock acquire successfully,{}'.format(self.name, self.response_lock))
# 释放锁
def release_response_lock(self):
if self.response_lock.locked():
logging.debug('{},lock release,{}'.format(self.name, self.response_lock))
self.response_lock.release()
logging.debug('{},lock release successfully,{}'.format(self.name, self.response_lock))
# 从配置文件获取会话预设信息
def get_default_prompt(self, use_default: str = None):
import pkg.openai.dprompt as dprompt
if use_default is None:
use_default = dprompt.mode_inst().get_using_name()
current_default_prompt, _ = dprompt.mode_inst().get_prompt(use_default)
return current_default_prompt
def __init__(self, name: str):
self.name = name
self.create_timestamp = int(time.time())
self.last_interact_timestamp = int(time.time())
self.prompt = []
self.token_counts = []
self.schedule()
self.response_lock = threading.Lock()
self.default_prompt = self.get_default_prompt()
logging.debug("prompt is: {}".format(self.default_prompt))
# 设定检查session最后一次对话是否超过过期时间的计时器
def schedule(self):
threading.Thread(target=self.expire_check_timer_loop, args=(self.create_timestamp,)).start()
# 检查session是否已经过期
def expire_check_timer_loop(self, create_timestamp: int):
global sessions
while True:
time.sleep(60)
# 不是此session已更换退出
if self.create_timestamp != create_timestamp or self not in sessions.values():
return
config = context.get_config_manager().data
if int(time.time()) - self.last_interact_timestamp > config['session_expire_time']:
logging.info('session {} 已过期'.format(self.name))
# 触发插件事件
args = {
'session_name': self.name,
'session': self,
'session_expire_time': config['session_expire_time']
}
event = plugin_host.emit(plugin_models.SessionExpired, **args)
if event.is_prevented_default():
return
self.reset(expired=True, schedule_new=False)
# 删除此session
del sessions[self.name]
return
# 请求回复
# 这个函数是阻塞的
def query(self, text: str=None) -> tuple[str, str, list[str]]:
"""向session中添加一条消息返回接口回复
Args:
text (str): 用户消息
Returns:
tuple[str, str]: (接口回复, finish_reason, 已调用的函数列表)
"""
self.last_interact_timestamp = int(time.time())
# 触发插件事件
if not self.prompt:
args = {
'session_name': self.name,
'session': self,
'default_prompt': self.default_prompt,
}
event = plugin_host.emit(plugin_models.SessionFirstMessageReceived, **args)
if event.is_prevented_default():
return None, None, None
config = context.get_config_manager().data
max_length = config['prompt_submit_length']
local_default_prompt = self.default_prompt.copy()
local_prompt = self.prompt.copy()
# 触发PromptPreProcessing事件
args = {
'session_name': self.name,
'default_prompt': self.default_prompt,
'prompt': self.prompt,
'text_message': text,
}
event = plugin_host.emit(plugin_models.PromptPreProcessing, **args)
if event.get_return_value('default_prompt') is not None:
local_default_prompt = event.get_return_value('default_prompt')
if event.get_return_value('prompt') is not None:
local_prompt = event.get_return_value('prompt')
if event.get_return_value('text_message') is not None:
text = event.get_return_value('text_message')
# 裁剪messages到合适长度
prompts, _ = self.cut_out(text, max_length, local_default_prompt, local_prompt)
res_text = ""
pending_msgs = []
total_tokens = 0
finish_reason: str = ""
funcs = []
trace_func_calls = config['trace_function_calls']
botmgr = context.get_qqbot_manager()
session_name_spt: list[str] = self.name.split("_")
pending_res_text = ""
start_time = time.time()
# TODO 对不起,我知道这样非常非常屎山,但我之后会重构的
for resp in context.get_openai_manager().request_completion(prompts):
if pending_res_text != "":
botmgr.adapter.send_message(
session_name_spt[0],
session_name_spt[1],
pending_res_text
)
pending_res_text = ""
finish_reason = resp['choices'][0]['finish_reason']
if resp['choices'][0]['message']['role'] == "assistant" and resp['choices'][0]['message']['content'] != None: # 包含纯文本响应
if not trace_func_calls:
res_text += resp['choices'][0]['message']['content']
else:
res_text = resp['choices'][0]['message']['content']
pending_res_text = resp['choices'][0]['message']['content']
total_tokens += resp['usage']['total_tokens']
msg = {
"role": "assistant",
"content": resp['choices'][0]['message']['content']
}
if 'function_call' in resp['choices'][0]['message']:
msg['function_call'] = json.dumps(resp['choices'][0]['message']['function_call'])
pending_msgs.append(msg)
if resp['choices'][0]['message']['type'] == 'function_call':
# self.prompt.append(
# {
# "role": "assistant",
# "content": "function call: "+json.dumps(resp['choices'][0]['message']['function_call'])
# }
# )
if trace_func_calls:
botmgr.adapter.send_message(
session_name_spt[0],
session_name_spt[1],
"调用函数 "+resp['choices'][0]['message']['function_call']['name'] + "..."
)
total_tokens += resp['usage']['total_tokens']
elif resp['choices'][0]['message']['type'] == 'function_return':
# self.prompt.append(
# {
# "role": "function",
# "name": resp['choices'][0]['message']['function_name'],
# "content": json.dumps(resp['choices'][0]['message']['content'])
# }
# )
# total_tokens += resp['usage']['total_tokens']
funcs.append(
resp['choices'][0]['message']['function_name']
)
pass
# 向API请求补全
# message, total_token = pkg.utils.context.get_openai_manager().request_completion(
# prompts,
# )
# 成功获取,处理回复
# res_test = message
res_ans = res_text.strip()
# 将此次对话的双方内容加入到prompt中
# self.prompt.append({'role': 'user', 'content': text})
# self.prompt.append({'role': 'assistant', 'content': res_ans})
if text:
self.prompt.append({'role': 'user', 'content': text})
# 添加pending_msgs
self.prompt += pending_msgs
# 向token_counts中添加本回合的token数量
# self.token_counts.append(total_tokens-total_token_before_query)
# logging.debug("本回合使用token: {}, session counts: {}".format(total_tokens-total_token_before_query, self.token_counts))
if self.just_switched_to_exist_session:
self.just_switched_to_exist_session = False
self.set_ongoing()
# 上报使用量数据
session_type = session_name_spt[0]
session_id = session_name_spt[1]
ability_provider = "QChatGPT.Text"
usage = total_tokens
model_name = context.get_config_manager().data['completion_api_params']['model']
response_seconds = int(time.time() - start_time)
retry_times = -1 # 暂不记录
context.get_center_v2_api().usage.post_query_record(
session_type=session_type,
session_id=session_id,
query_ability_provider=ability_provider,
usage=usage,
model_name=model_name,
response_seconds=response_seconds,
retry_times=retry_times
)
return res_ans if res_ans[0] != '\n' else res_ans[1:], finish_reason, funcs
# 删除上一回合并返回上一回合的问题
def undo(self) -> str:
self.last_interact_timestamp = int(time.time())
# 删除最后两个消息
if len(self.prompt) < 2:
raise Exception('之前无对话,无法撤销')
question = self.prompt[-2]['content']
self.prompt = self.prompt[:-2]
self.token_counts = self.token_counts[:-1]
# 返回上一回合的问题
return question
# 构建对话体
def cut_out(self, msg: str, max_tokens: int, default_prompt: list, prompt: list) -> tuple[list, list]:
"""将现有prompt进行切割处理使得新的prompt长度不超过max_tokens
:return: (新的prompt, 新的token_counts)
"""
# 最终由三个部分组成
# - default_prompt 情景预设固定值
# - changable_prompts 可变部分, 此会话中的历史对话回合
# - current_question 当前问题
# 包装目前的对话回合内容
changable_prompts = []
use_model = context.get_config_manager().data['completion_api_params']['model']
ptr = len(prompt) - 1
# 直接从后向前扫描拼接,不管是否是整回合
while ptr >= 0:
if openai_modelmgr.count_tokens(prompt[ptr:ptr+1]+changable_prompts, use_model) > max_tokens:
break
changable_prompts.insert(0, prompt[ptr])
ptr -= 1
# 将default_prompt和changable_prompts合并
result_prompt = default_prompt + changable_prompts
# 添加当前问题
if msg:
result_prompt.append(
{
'role': 'user',
'content': msg
}
)
logging.debug("cut_out: {}".format(json.dumps(result_prompt, ensure_ascii=False, indent=4)))
return result_prompt, openai_modelmgr.count_tokens(changable_prompts, use_model)
# 持久化session
def persistence(self):
if self.prompt == self.get_default_prompt():
return
db_inst = context.get_database_manager()
name_spt = self.name.split('_')
subject_type = name_spt[0]
subject_number = int(name_spt[1])
db_inst.persistence_session(subject_type, subject_number, self.create_timestamp, self.last_interact_timestamp,
json.dumps(self.prompt), json.dumps(self.default_prompt), json.dumps(self.token_counts))
# 重置session
def reset(self, explicit: bool = False, expired: bool = False, schedule_new: bool = True, use_prompt: str = None, persist: bool = False):
if self.prompt:
self.persistence()
if explicit:
# 触发插件事件
args = {
'session_name': self.name,
'session': self
}
# 此事件不支持阻止默认行为
_ = plugin_host.emit(plugin_models.SessionExplicitReset, **args)
context.get_database_manager().explicit_close_session(self.name, self.create_timestamp)
if expired:
context.get_database_manager().set_session_expired(self.name, self.create_timestamp)
if not persist: # 不要求保持default prompt
self.default_prompt = self.get_default_prompt(use_prompt)
self.prompt = []
self.token_counts = []
self.create_timestamp = int(time.time())
self.last_interact_timestamp = int(time.time())
self.just_switched_to_exist_session = False
# self.response_lock = threading.Lock()
if schedule_new:
self.schedule()
# 将本session的数据库状态设置为on_going
def set_ongoing(self):
context.get_database_manager().set_session_ongoing(self.name, self.create_timestamp)
# 切换到上一个session
def last_session(self):
last_one = context.get_database_manager().last_session(self.name, self.last_interact_timestamp)
if last_one is None:
return None
else:
self.persistence()
self.create_timestamp = last_one['create_timestamp']
self.last_interact_timestamp = last_one['last_interact_timestamp']
self.prompt = json.loads(last_one['prompt'])
self.token_counts = json.loads(last_one['token_counts'])
self.default_prompt = json.loads(last_one['default_prompt']) if last_one['default_prompt'] else []
self.just_switched_to_exist_session = True
return self
# 切换到下一个session
def next_session(self):
next_one = context.get_database_manager().next_session(self.name, self.last_interact_timestamp)
if next_one is None:
return None
else:
self.persistence()
self.create_timestamp = next_one['create_timestamp']
self.last_interact_timestamp = next_one['last_interact_timestamp']
self.prompt = json.loads(next_one['prompt'])
self.token_counts = json.loads(next_one['token_counts'])
self.default_prompt = json.loads(next_one['default_prompt']) if next_one['default_prompt'] else []
self.just_switched_to_exist_session = True
return self
def list_history(self, capacity: int = 10, page: int = 0):
return context.get_database_manager().list_history(self.name, capacity, page)
def delete_history(self, index: int) -> bool:
return context.get_database_manager().delete_history(self.name, index)
def delete_all_history(self) -> bool:
return context.get_database_manager().delete_all_history(self.name)
def draw_image(self, prompt: str):
return context.get_openai_manager().request_image(prompt)

View File

@@ -0,0 +1,53 @@
from __future__ import annotations
import datetime
import asyncio
import typing
import pydantic
from ..sysprompt import entities as sysprompt_entities
from .. import entities as llm_entities
from ..requester import entities
from ...core import entities as core_entities
from ..tools import entities as tools_entities
class Conversation(pydantic.BaseModel):
"""对话"""
prompt: sysprompt_entities.Prompt
messages: list[llm_entities.Message]
create_time: typing.Optional[datetime.datetime] = pydantic.Field(default_factory=datetime.datetime.now)
update_time: typing.Optional[datetime.datetime] = pydantic.Field(default_factory=datetime.datetime.now)
use_model: entities.LLMModelInfo
use_funcs: typing.Optional[list[tools_entities.LLMFunction]]
class Session(pydantic.BaseModel):
"""会话"""
launcher_type: core_entities.LauncherTypes
launcher_id: int
sender_id: typing.Optional[int] = 0
use_prompt_name: typing.Optional[str] = 'default'
using_conversation: typing.Optional[Conversation] = None
conversations: typing.Optional[list[Conversation]] = []
create_time: typing.Optional[datetime.datetime] = pydantic.Field(default_factory=datetime.datetime.now)
update_time: typing.Optional[datetime.datetime] = pydantic.Field(default_factory=datetime.datetime.now)
semaphore: typing.Optional[asyncio.Semaphore] = None
class Config:
arbitrary_types_allowed = True

View File

@@ -0,0 +1,51 @@
from __future__ import annotations
import asyncio
from ...core import app, entities as core_entities
from . import entities
class SessionManager:
ap: app.Application
session_list: list[entities.Session]
def __init__(self, ap: app.Application):
self.ap = ap
self.session_list = []
async def initialize(self):
pass
async def get_session(self, query: core_entities.Query) -> entities.Session:
"""获取会话
"""
for session in self.session_list:
if query.launcher_type == session.launcher_type and query.launcher_id == session.launcher_id:
return session
session = entities.Session(
launcher_type=query.launcher_type,
launcher_id=query.launcher_id,
semaphore=asyncio.Semaphore(1) if self.ap.cfg_mgr.data['wait_last_done'] else asyncio.Semaphore(10000),
)
self.session_list.append(session)
return session
async def get_conversation(self, session: entities.Session) -> entities.Conversation:
if not session.conversations:
session.conversations = []
if session.using_conversation is None:
conversation = entities.Conversation(
prompt=await self.ap.prompt_mgr.get_prompt(session.use_prompt_name),
messages=[],
use_model=await self.ap.model_mgr.get_model_by_name(self.ap.cfg_mgr.data['completion_api_params']['model']),
use_funcs=await self.ap.tool_mgr.get_all_functions(),
)
session.conversations.append(conversation)
session.using_conversation = conversation
return session.using_conversation

View File

@@ -0,0 +1,14 @@
from __future__ import annotations
import typing
import pydantic
from ...openai import entities
class Prompt(pydantic.BaseModel):
"""供AI使用的Prompt"""
name: str
messages: list[entities.Message]

View File

@@ -0,0 +1,32 @@
from __future__ import annotations
import abc
from ...core import app
from . import entities
class PromptLoader(metaclass=abc.ABCMeta):
"""Prompt加载器抽象类
"""
ap: app.Application
prompts: list[entities.Prompt]
def __init__(self, ap: app.Application):
self.ap = ap
self.prompts = []
async def initialize(self):
pass
@abc.abstractmethod
async def load(self):
"""加载Prompt
"""
raise NotImplementedError
def get_prompts(self) -> list[entities.Prompt]:
"""获取Prompt列表
"""
return self.prompts

View File

@@ -0,0 +1,38 @@
from __future__ import annotations
import json
import os
from .. import loader
from .. import entities
from ....openai import entities as llm_entities
class ScenarioPromptLoader(loader.PromptLoader):
"""加载scenario目录下的json"""
async def load(self):
"""加载Prompt
"""
for file in os.listdir("scenarios"):
with open("scenarios/{}".format(file), "r", encoding="utf-8") as f:
file_str = f.read()
file_name = file.split(".")[0]
file_json = json.loads(file_str)
messages = []
for msg in file_json["prompt"]:
role = 'system'
if "role" in msg:
role = msg['role']
messages.append(
llm_entities.Message(
role=role,
content=msg['content'],
)
)
prompt = entities.Prompt(
name=file_name,
messages=messages
)
self.prompts.append(prompt)

View File

@@ -0,0 +1,42 @@
from __future__ import annotations
import os
from .. import loader
from .. import entities
from ....openai import entities as llm_entities
class SingleSystemPromptLoader(loader.PromptLoader):
"""配置文件中的单条system prompt的prompt加载器
"""
async def load(self):
"""加载Prompt
"""
for name, cnt in self.ap.cfg_mgr.data['default_prompt'].items():
prompt = entities.Prompt(
name=name,
messages=[
llm_entities.Message(
role='system',
content=cnt
)
]
)
self.prompts.append(prompt)
for file in os.listdir("prompts"):
with open("prompts/{}".format(file), "r", encoding="utf-8") as f:
file_str = f.read()
file_name = file.split(".")[0]
prompt = entities.Prompt(
name=file_name,
messages=[
llm_entities.Message(
role='system',
content=file_str
)
]
)
self.prompts.append(prompt)

View File

@@ -0,0 +1,50 @@
from __future__ import annotations
from ...core import app
from . import loader
from .loaders import single, scenario
class PromptManager:
ap: app.Application
loader_inst: loader.PromptLoader
default_prompt: str = 'default'
def __init__(self, ap: app.Application):
self.ap = ap
async def initialize(self):
loader_map = {
"normal": single.SingleSystemPromptLoader,
"full_scenario": scenario.ScenarioPromptLoader
}
loader_cls = loader_map[self.ap.cfg_mgr.data['preset_mode']]
self.loader_inst: loader.PromptLoader = loader_cls(self.ap)
await self.loader_inst.initialize()
await self.loader_inst.load()
def get_all_prompts(self) -> list[loader.entities.Prompt]:
"""获取所有Prompt
"""
return self.loader_inst.get_prompts()
async def get_prompt(self, name: str) -> loader.entities.Prompt:
"""获取Prompt
"""
for prompt in self.get_all_prompts():
if prompt.name == name:
return prompt
async def get_prompt_by_prefix(self, prefix: str) -> loader.entities.Prompt:
"""通过前缀获取Prompt
"""
for prompt in self.get_all_prompts():
if prompt.name.startswith(prefix):
return prompt

View File

@@ -0,0 +1,35 @@
from __future__ import annotations
import abc
import typing
import asyncio
import pydantic
class LLMFunction(pydantic.BaseModel):
"""函数"""
name: str
"""函数名"""
human_desc: str
description: str
"""给LLM识别的函数描述"""
enable: typing.Optional[bool] = True
parameters: dict
func: typing.Callable
"""供调用的python异步方法
此异步方法第一个参数接收当前请求的query对象可以从其中取出session等信息。
query参数不在parameters中但在调用时会自动传入。
但在当前版本中,插件提供的内容函数都是同步的,且均为请求无关的,故在此版本的实现(以及考虑了向后兼容性的版本)中,
对插件的内容函数进行封装并存到这里来。
"""
class Config:
arbitrary_types_allowed = True

View File

@@ -0,0 +1,99 @@
from __future__ import annotations
import typing
from ...core import app, entities as core_entities
from . import entities
from ..session import entities as session_entities
class ToolManager:
"""LLM工具管理器
"""
ap: app.Application
all_functions: list[entities.LLMFunction]
def __init__(self, ap: app.Application):
self.ap = ap
self.all_functions = []
async def initialize(self):
pass
def register_legacy_function(self, name: str, description: str, parameters: dict, func: callable):
"""注册函数
"""
async def wrapper(query, **kwargs):
return func(**kwargs)
function = entities.LLMFunction(
name=name,
description=description,
human_desc='',
enable=True,
parameters=parameters,
func=wrapper
)
self.all_functions.append(function)
async def register_function(self, function: entities.LLMFunction):
"""添加函数
"""
self.all_functions.append(function)
async def get_function(self, name: str) -> entities.LLMFunction:
"""获取函数
"""
for function in self.all_functions:
if function.name == name:
return function
return None
async def get_all_functions(self) -> list[entities.LLMFunction]:
"""获取所有函数
"""
return self.all_functions
async def generate_tools_for_openai(self, conversation: session_entities.Conversation) -> str:
"""生成函数列表
"""
tools = []
for function in conversation.use_funcs:
if function.enable:
function_schema = {
"type": "function",
"function": {
"name": function.name,
"description": function.description,
"parameters": function.parameters
}
}
tools.append(function_schema)
return tools
async def execute_func_call(
self,
query: core_entities.Query,
name: str,
parameters: dict
) -> typing.Any:
"""执行函数调用
"""
# return "i'm not sure for the args "+str(parameters)
function = await self.get_function(name)
if function is None:
return None
parameters = parameters.copy()
parameters = {
"query": query,
**parameters
}
return await function.func(**parameters)

View File

@@ -1,70 +1,76 @@
# 处理对会话的禁用配置
# 过去的 banlist
from __future__ import annotations
import re
from ...boot import app
from .. import stage, entities, stagemgr
from ...core import entities as core_entities
from ...config import manager as cfg_mgr
class SessionBanManager:
ap: app.Application = None
@stage.stage_class('BanSessionCheckStage')
class BanSessionCheckStage(stage.PipelineStage):
banlist_mgr: cfg_mgr.ConfigManager
def __init__(self, ap: app.Application):
self.ap = ap
async def initialize(self):
self.banlist_mgr = await cfg_mgr.load_python_module_config(
"banlist.py",
"res/templates/banlist-template.py"
)
async def is_banned(
self, launcher_type: str, launcher_id: int, sender_id: int
) -> bool:
async def process(
self,
query: core_entities.Query,
stage_inst_name: str
) -> entities.StageProcessResult:
if not self.banlist_mgr.data['enable']:
return False
return entities.StageProcessResult(
result_type=entities.ResultType.CONTINUE,
new_query=query
)
result = False
if launcher_type == 'group':
if query.launcher_type == 'group':
if not self.banlist_mgr.data['enable_group']: # 未启用群聊响应
result = True
# 检查是否显式声明发起人QQ要被person忽略
elif sender_id in self.banlist_mgr.data['person']:
elif query.sender_id in self.banlist_mgr.data['person']:
result = True
else:
for group_rule in self.banlist_mgr.data['group']:
if type(group_rule) == int:
if group_rule == launcher_id:
if group_rule == query.launcher_id:
result = True
elif type(group_rule) == str:
if group_rule.startswith('!'):
reg_str = group_rule[1:]
if re.match(reg_str, str(launcher_id)):
if re.match(reg_str, str(query.launcher_id)):
result = False
break
else:
if re.match(group_rule, str(launcher_id)):
if re.match(group_rule, str(query.launcher_id)):
result = True
elif launcher_type == 'person':
elif query.launcher_type == 'person':
if not self.banlist_mgr.data['enable_private']:
result = True
else:
for person_rule in self.banlist_mgr.data['person']:
if type(person_rule) == int:
if person_rule == launcher_id:
if person_rule == query.launcher_id:
result = True
elif type(person_rule) == str:
if person_rule.startswith('!'):
reg_str = person_rule[1:]
if re.match(reg_str, str(launcher_id)):
if re.match(reg_str, str(query.launcher_id)):
result = False
break
else:
if re.match(person_rule, str(launcher_id)):
if re.match(person_rule, str(query.launcher_id)):
result = True
return result
return entities.StageProcessResult(
result_type=entities.ResultType.CONTINUE if not result else entities.ResultType.INTERRUPT,
new_query=query,
debug_notice=f'根据禁用列表忽略消息: {query.launcher_type}_{query.launcher_id}' if result else ''
)

View File

@@ -0,0 +1,128 @@
from __future__ import annotations
import mirai
from ...core import app
from .. import stage, entities, stagemgr
from ...core import entities as core_entities
from ...config import manager as cfg_mgr
from . import filter, entities as filter_entities
from .filters import cntignore, banwords, baiduexamine
@stage.stage_class('PostContentFilterStage')
@stage.stage_class('PreContentFilterStage')
class ContentFilterStage(stage.PipelineStage):
filter_chain: list[filter.ContentFilter]
def __init__(self, ap: app.Application):
self.filter_chain = []
super().__init__(ap)
async def initialize(self):
self.filter_chain.append(cntignore.ContentIgnore(self.ap))
if self.ap.cfg_mgr.data['sensitive_word_filter']:
self.filter_chain.append(banwords.BanWordFilter(self.ap))
if self.ap.cfg_mgr.data['baidu_check']:
self.filter_chain.append(baiduexamine.BaiduCloudExamine(self.ap))
for filter in self.filter_chain:
await filter.initialize()
async def _pre_process(
self,
message: str,
query: core_entities.Query,
) -> entities.StageProcessResult:
"""请求llm前处理消息
只要有一个不通过就不放行,只放行 PASS 的消息
"""
if not self.ap.cfg_mgr.data['income_msg_check']:
return entities.StageProcessResult(
result_type=entities.ResultType.CONTINUE,
new_query=query
)
else:
for filter in self.filter_chain:
if filter_entities.EnableStage.PRE in filter.enable_stages:
result = await filter.process(message)
if result.level in [
filter_entities.ResultLevel.BLOCK,
filter_entities.ResultLevel.MASKED
]:
return entities.StageProcessResult(
result_type=entities.ResultType.INTERRUPT,
new_query=query,
user_notice=result.user_notice,
console_notice=result.console_notice
)
elif result.level == filter_entities.ResultLevel.PASS: # 传到下一个
message = result.replacement
query.message_chain = mirai.MessageChain(
mirai.Plain(message)
)
return entities.StageProcessResult(
result_type=entities.ResultType.CONTINUE,
new_query=query
)
async def _post_process(
self,
message: str,
query: core_entities.Query,
) -> entities.StageProcessResult:
"""请求llm后处理响应
只要是 PASS 或者 MASKED 的就通过此 filter将其 replacement 设置为message进入下一个 filter
"""
for filter in self.filter_chain:
if filter_entities.EnableStage.POST in filter.enable_stages:
result = await filter.process(message)
if result.level == filter_entities.ResultLevel.BLOCK:
return entities.StageProcessResult(
result_type=entities.ResultType.INTERRUPT,
new_query=query,
user_notice=result.user_notice,
console_notice=result.console_notice
)
elif result.level in [
filter_entities.ResultLevel.PASS,
filter_entities.ResultLevel.MASKED
]:
message = result.replacement
query.message_chain = mirai.MessageChain(
mirai.Plain(message)
)
return entities.StageProcessResult(
result_type=entities.ResultType.CONTINUE,
new_query=query
)
async def process(
self,
query: core_entities.Query,
stage_inst_name: str
) -> entities.StageProcessResult:
"""处理
"""
if stage_inst_name == 'PreContentFilterStage':
return await self._pre_process(
str(query.message_chain).strip(),
query
)
elif stage_inst_name == 'PostContentFilterStage':
return await self._post_process(
str(query.message_chain).strip(),
query
)
else:
raise ValueError(f'未知的 stage_inst_name: {stage_inst_name}')

View File

@@ -2,7 +2,7 @@
from __future__ import annotations
import abc
from ...boot import app
from ...core import app
from . import entities

38
pkg/pipeline/entities.py Normal file
View File

@@ -0,0 +1,38 @@
from __future__ import annotations
import enum
import typing
import pydantic
import mirai
import mirai.models.message as mirai_message
from ..core import entities
class ResultType(enum.Enum):
CONTINUE = enum.auto()
"""继续流水线"""
INTERRUPT = enum.auto()
"""中断流水线"""
class StageProcessResult(pydantic.BaseModel):
result_type: ResultType
new_query: entities.Query
user_notice: typing.Optional[typing.Union[str, list[mirai_message.MessageComponent], mirai.MessageChain, None]] = []
"""只要设置了就会发送给用户"""
admin_notice: typing.Optional[typing.Union[str, list[mirai_message.MessageComponent], mirai.MessageChain, None]] = []
"""只要设置了就会发送给管理员"""
console_notice: typing.Optional[str] = ''
"""只要设置了就会输出到控制台"""
debug_notice: typing.Optional[str] = ''

View File

View File

@@ -3,22 +3,21 @@ import os
import traceback
from PIL import Image, ImageDraw, ImageFont
from mirai.models.message import MessageComponent, Plain
from mirai.models.message import MessageComponent, Plain, MessageChain
from ...boot import app
from ...core import app
from . import strategy
from .strategies import image, forward
from .. import stage, entities, stagemgr
from ...core import entities as core_entities
from ...config import manager as cfg_mgr
class LongTextProcessor:
ap: app.Application
@stage.stage_class("LongTextProcessStage")
class LongTextProcessStage(stage.PipelineStage):
strategy_impl: strategy.LongTextStrategy
def __init__(self, ap: app.Application):
self.ap = ap
async def initialize(self):
config = self.ap.cfg_mgr.data
if self.ap.cfg_mgr.data['blob_message_strategy'] == 'image':
@@ -48,9 +47,11 @@ class LongTextProcessor:
elif self.ap.cfg_mgr.data['blob_message_strategy'] == 'forward':
self.strategy_impl = forward.ForwardComponentStrategy(self.ap)
await self.strategy_impl.initialize()
async def check_and_process(self, message: str) -> list[MessageComponent]:
if len(message) > self.ap.cfg_mgr.data['blob_message_threshold']:
return await self.strategy_impl.process(message)
else:
return [Plain(message)]
async def process(self, query: core_entities.Query, stage_inst_name: str) -> entities.StageProcessResult:
if len(str(query.resp_message_chain)) > self.ap.cfg_mgr.data['blob_message_threshold']:
query.resp_message_chain = MessageChain(await self.strategy_impl.process(str(query.resp_message_chain)))
return entities.StageProcessResult(
result_type=entities.ResultType.CONTINUE,
new_query=query
)

View File

@@ -5,7 +5,7 @@ import typing
import mirai
from mirai.models.message import MessageComponent
from ...boot import app
from ...core import app
class LongTextStrategy(metaclass=abc.ABCMeta):

View File

View File

@@ -0,0 +1,25 @@
from __future__ import annotations
import abc
from ...core import app
from ...core import entities as core_entities
from .. import entities
class MessageHandler(metaclass=abc.ABCMeta):
ap: app.Application
def __init__(self, ap: app.Application):
self.ap = ap
async def initialize(self):
pass
@abc.abstractmethod
async def handle(
self,
query: core_entities.Query,
) -> entities.StageProcessResult:
raise NotImplementedError

View File

@@ -0,0 +1,44 @@
from __future__ import annotations
import typing
import mirai
from .. import handler
from ... import entities
from ....core import entities as core_entities
from ....openai import entities as llm_entities
class ChatMessageHandler(handler.MessageHandler):
async def handle(
self,
query: core_entities.Query,
) -> typing.AsyncGenerator[entities.StageProcessResult, None]:
"""处理
"""
# 取session
# 取conversation
# 调API
# 生成器
session = await self.ap.sess_mgr.get_session(query)
conversation = await self.ap.sess_mgr.get_conversation(session)
conversation.messages.append(
llm_entities.Message(
role="user",
content=str(query.message_chain)
)
)
async for result in conversation.use_model.requester.request(query, conversation):
conversation.messages.append(result)
query.resp_message_chain = mirai.MessageChain([mirai.Plain(str(result))])
yield entities.StageProcessResult(
result_type=entities.ResultType.CONTINUE,
new_query=query
)

View File

@@ -0,0 +1,50 @@
from __future__ import annotations
import typing
import mirai
from .. import handler
from ... import entities
from ....core import entities as core_entities
class CommandHandler(handler.MessageHandler):
async def handle(
self,
query: core_entities.Query,
) -> typing.AsyncGenerator[entities.StageProcessResult, None]:
"""处理
"""
session = await self.ap.sess_mgr.get_session(query)
command_text = str(query.message_chain).strip()[1:]
async for ret in self.ap.cmd_mgr.execute(
command_text=command_text,
query=query,
session=session
):
if ret.error is not None:
query.resp_message_chain = mirai.MessageChain([
mirai.Plain(str(ret.error))
])
yield entities.StageProcessResult(
result_type=entities.ResultType.CONTINUE,
new_query=query
)
elif ret.text is not None:
query.resp_message_chain = mirai.MessageChain([
mirai.Plain(ret.text)
])
yield entities.StageProcessResult(
result_type=entities.ResultType.CONTINUE,
new_query=query
)
else:
yield entities.StageProcessResult(
result_type=entities.ResultType.INTERRUPT,
new_query=query
)

View File

@@ -0,0 +1,38 @@
from __future__ import annotations
from ...core import app, entities as core_entities
from . import handler
from .handlers import chat, command
from .. import entities
from .. import stage, entities, stagemgr
from ...core import entities as core_entities
from ...config import manager as cfg_mgr
@stage.stage_class("MessageProcessor")
class Processor(stage.PipelineStage):
cmd_handler: handler.MessageHandler
chat_handler: handler.MessageHandler
async def initialize(self):
self.cmd_handler = command.CommandHandler(self.ap)
self.chat_handler = chat.ChatMessageHandler(self.ap)
await self.cmd_handler.initialize()
await self.chat_handler.initialize()
async def process(
self,
query: core_entities.Query,
stage_inst_name: str,
) -> entities.StageProcessResult:
"""处理
"""
message_text = str(query.message_chain).strip()
if message_text.startswith('!') or message_text.startswith(''):
return self.cmd_handler.handle(query)
else:
return self.chat_handler.handle(query)

View File

View File

@@ -0,0 +1,29 @@
from __future__ import annotations
import mirai
from ...core import app
from .. import stage, entities, stagemgr
from ...core import entities as core_entities
from ...config import manager as cfg_mgr
@stage.stage_class("SendResponseBackStage")
class SendResponseBackStage(stage.PipelineStage):
"""发送响应消息
"""
async def process(self, query: core_entities.Query, stage_inst_name: str) -> entities.StageProcessResult:
"""处理
"""
await self.ap.im_mgr.send(
query.message_event,
query.resp_message_chain
)
return entities.StageProcessResult(
result_type=entities.ResultType.CONTINUE,
new_query=query
)

View File

View File

@@ -0,0 +1,62 @@
from __future__ import annotations
import mirai
from ...core import app
from . import entities as rule_entities, rule
from .rules import atbot, prefix, regexp, random
from .. import stage, entities, stagemgr
from ...core import entities as core_entities
from ...config import manager as cfg_mgr
@stage.stage_class("GroupRespondRuleCheckStage")
class GroupRespondRuleCheckStage(stage.PipelineStage):
"""群组响应规则检查器
"""
rule_matchers: list[rule.GroupRespondRule]
async def initialize(self):
"""初始化检查器
"""
self.rule_matchers = [
atbot.AtBotRule(self.ap),
prefix.PrefixRule(self.ap),
regexp.RegExpRule(self.ap),
random.RandomRespRule(self.ap),
]
for rule_matcher in self.rule_matchers:
await rule_matcher.initialize()
async def process(self, query: core_entities.Query, stage_inst_name: str) -> entities.StageProcessResult:
if query.launcher_type != 'group':
return entities.StageProcessResult(
result_type=entities.ResultType.CONTINUE,
new_query=query
)
rules = self.ap.cfg_mgr.data['response_rules']
use_rule = rules['default']
if str(query.launcher_id) in use_rule:
use_rule = use_rule[str(query.launcher_id)]
for rule_matcher in self.rule_matchers: # 任意一个匹配就放行
res = await rule_matcher.match(str(query.message_chain), query.message_chain, use_rule)
if res.matching:
query.message_chain = res.replacement
return entities.StageProcessResult(
result_type=entities.ResultType.CONTINUE,
new_query=query,
)
return entities.StageProcessResult(
result_type=entities.ResultType.INTERRUPT,
new_query=query
)

View File

@@ -3,7 +3,7 @@ import abc
import mirai
from ...boot import app
from ...core import app
from . import entities

View File

47
pkg/pipeline/stage.py Normal file
View File

@@ -0,0 +1,47 @@
from __future__ import annotations
import abc
import typing
from ..core import app, entities as core_entities
from . import entities
_stage_classes: dict[str, PipelineStage] = {}
def stage_class(name: str):
def decorator(cls):
_stage_classes[name] = cls
return cls
return decorator
class PipelineStage(metaclass=abc.ABCMeta):
"""流水线阶段
"""
ap: app.Application
def __init__(self, ap: app.Application):
self.ap = ap
async def initialize(self):
"""初始化
"""
pass
@abc.abstractmethod
async def process(
self,
query: core_entities.Query,
stage_inst_name: str,
) -> typing.Union[
entities.StageProcessResult,
typing.AsyncGenerator[entities.StageProcessResult, None],
]:
"""处理
"""
raise NotImplementedError

63
pkg/pipeline/stagemgr.py Normal file
View File

@@ -0,0 +1,63 @@
from __future__ import annotations
import pydantic
from ..core import app
from . import stage
from .resprule import resprule
from .bansess import bansess
from .cntfilter import cntfilter
from .process import process
from .longtext import longtext
from .respback import respback
stage_order = [
"GroupRespondRuleCheckStage",
"BanSessionCheckStage",
"PreContentFilterStage",
"MessageProcessor",
"PostContentFilterStage",
"LongTextProcessStage",
"SendResponseBackStage",
]
class StageInstContainer():
"""阶段实例容器
"""
inst_name: str
inst: stage.PipelineStage
def __init__(self, inst_name: str, inst: stage.PipelineStage):
self.inst_name = inst_name
self.inst = inst
class StageManager:
ap: app.Application
stage_containers: list[StageInstContainer]
def __init__(self, ap: app.Application):
self.ap = ap
self.stage_containers = []
async def initialize(self):
"""初始化
"""
for name, cls in stage._stage_classes.items():
self.stage_containers.append(StageInstContainer(
inst_name=name,
inst=cls(self.ap)
))
for stage_containers in self.stage_containers:
await stage_containers.inst.initialize()
# 按照 stage_order 排序
self.stage_containers.sort(key=lambda x: stage_order.index(x.inst_name))

Some files were not shown because too many files have changed in this diff Show More