mirror of
https://github.com/langbot-app/LangBot.git
synced 2026-06-02 12:05:54 +00:00
Merge pull request #673 from RockChinQ/refactor/asyncio/control-flow
Refactor: 请求处理控制流
This commit is contained in:
58
.github/workflows/update-cmdpriv-template.yml
vendored
58
.github/workflows/update-cmdpriv-template.yml
vendored
@@ -1,58 +0,0 @@
|
||||
name: Update cmdpriv-template
|
||||
|
||||
on:
|
||||
push:
|
||||
paths:
|
||||
- 'pkg/qqbot/cmds/**'
|
||||
pull_request:
|
||||
types: [closed]
|
||||
paths:
|
||||
- 'pkg/qqbot/cmds/**'
|
||||
|
||||
jobs:
|
||||
update-cmdpriv-template:
|
||||
if: github.event.pull_request.merged == true || github.event_name == 'push'
|
||||
runs-on: ubuntu-latest
|
||||
|
||||
steps:
|
||||
- name: Checkout repository
|
||||
uses: actions/checkout@v2
|
||||
|
||||
- name: Set up Python
|
||||
uses: actions/setup-python@v2
|
||||
with:
|
||||
python-version: 3.10.13
|
||||
|
||||
- name: Install dependencies
|
||||
run: |
|
||||
python -m pip install --upgrade yiri-mirai-rc openai>=1.0.0 colorlog func_timeout dulwich Pillow CallingGPT tiktoken
|
||||
python -m pip install -U openai>=1.0.0
|
||||
|
||||
- name: Copy Scripts
|
||||
run: |
|
||||
cp res/scripts/generate_cmdpriv_template.py .
|
||||
|
||||
- name: Generate Files
|
||||
run: |
|
||||
python main.py
|
||||
|
||||
- name: Run generate_cmdpriv_template.py
|
||||
run: python3 generate_cmdpriv_template.py
|
||||
|
||||
- name: Check for changes in cmdpriv-template.json
|
||||
id: check_changes
|
||||
run: |
|
||||
if git diff --name-only | grep -q "res/templates/cmdpriv-template.json"; then
|
||||
echo "::set-output name=changes_detected::true"
|
||||
else
|
||||
echo "::set-output name=changes_detected::false"
|
||||
fi
|
||||
|
||||
- name: Commit changes to cmdpriv-template.json
|
||||
if: steps.check_changes.outputs.changes_detected == 'true'
|
||||
run: |
|
||||
git config --global user.name "GitHub Actions Bot"
|
||||
git config --global user.email "<github-actions@github.com>"
|
||||
git add res/templates/cmdpriv-template.json
|
||||
git commit -m "Update cmdpriv-template.json"
|
||||
git push
|
||||
@@ -1,38 +0,0 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
|
||||
from ..qqbot import manager as qqbot_mgr
|
||||
from ..openai import manager as openai_mgr
|
||||
from ..config import manager as config_mgr
|
||||
from ..database import manager as database_mgr
|
||||
from ..utils.center import v2 as center_mgr
|
||||
from ..plugin import host as plugin_host
|
||||
|
||||
|
||||
class Application:
|
||||
im_mgr: qqbot_mgr.QQBotManager = None
|
||||
|
||||
llm_mgr: openai_mgr.OpenAIInteract = None
|
||||
|
||||
cfg_mgr: config_mgr.ConfigManager = None
|
||||
|
||||
tips_mgr: config_mgr.ConfigManager = None
|
||||
|
||||
db_mgr: database_mgr.DatabaseManager = None
|
||||
|
||||
ctr_mgr: center_mgr.V2CenterAPI = None
|
||||
|
||||
logger: logging.Logger = None
|
||||
|
||||
def __init__(self):
|
||||
pass
|
||||
|
||||
async def initialize(self):
|
||||
await self.im_mgr.initialize()
|
||||
|
||||
async def run(self):
|
||||
# TODO make it async
|
||||
plugin_host.initialize_plugins()
|
||||
|
||||
await self.im_mgr.run()
|
||||
@@ -1,54 +0,0 @@
|
||||
import logging
|
||||
import os
|
||||
import sys
|
||||
import time
|
||||
|
||||
import colorlog
|
||||
|
||||
|
||||
log_colors_config = {
|
||||
'DEBUG': 'green', # cyan white
|
||||
'INFO': 'white',
|
||||
'WARNING': 'yellow',
|
||||
'ERROR': 'red',
|
||||
'CRITICAL': 'cyan',
|
||||
}
|
||||
|
||||
|
||||
async def init_logging() -> logging.Logger:
|
||||
|
||||
level = logging.INFO
|
||||
|
||||
if 'DEBUG' in os.environ and os.environ['DEBUG'] in ['true', '1']:
|
||||
level = logging.DEBUG
|
||||
|
||||
log_file_name = "logs/qcg-%s.log" % time.strftime("%Y-%m-%d-%H-%M-%S", time.localtime())
|
||||
|
||||
qcg_logger = logging.getLogger("qcg")
|
||||
|
||||
qcg_logger.setLevel(level)
|
||||
|
||||
log_handlers: logging.Handler = [
|
||||
logging.StreamHandler(sys.stdout),
|
||||
logging.FileHandler(log_file_name)
|
||||
]
|
||||
|
||||
for handler in log_handlers:
|
||||
handler.setLevel(level)
|
||||
handler.setFormatter(
|
||||
colorlog.ColoredFormatter(
|
||||
fmt="[%(asctime)s.%(msecs)03d] %(pathname)s (%(lineno)d) - [%(levelname)s] :\n%(message)s",
|
||||
datefmt="%Y-%m-%d %H:%M:%S",
|
||||
log_colors=log_colors_config
|
||||
)
|
||||
)
|
||||
qcg_logger.addHandler(handler)
|
||||
|
||||
logging.basicConfig(level=level, # 设置日志输出格式
|
||||
format="[DEPR][%(asctime)s.%(msecs)03d] %(pathname)s (%(lineno)d) - [%(levelname)s] :\n%(message)s",
|
||||
# 日志输出的格式
|
||||
# -8表示占位符,让输出左对齐,输出长度都为8位
|
||||
datefmt="%Y-%m-%d %H:%M:%S" # 时间输出的格式
|
||||
)
|
||||
|
||||
return qcg_logger
|
||||
108
pkg/command/cmdmgr.py
Normal file
108
pkg/command/cmdmgr.py
Normal file
@@ -0,0 +1,108 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import typing
|
||||
|
||||
from ..core import app, entities as core_entities
|
||||
from ..openai import entities as llm_entities
|
||||
from ..openai.session import entities as session_entities
|
||||
from . import entities, operator, errors
|
||||
|
||||
from .operators import func, plugin, default, reset, list as list_cmd, last, next, delc, resend, prompt, cfg, cmd, help, version, update
|
||||
|
||||
|
||||
class CommandManager:
|
||||
"""命令管理器
|
||||
"""
|
||||
|
||||
ap: app.Application
|
||||
|
||||
cmd_list: list[operator.CommandOperator]
|
||||
|
||||
def __init__(self, ap: app.Application):
|
||||
self.ap = ap
|
||||
|
||||
async def initialize(self):
|
||||
# 实例化所有类
|
||||
self.cmd_list = [cls(self.ap) for cls in operator.preregistered_operators]
|
||||
|
||||
# 设置所有类的子节点
|
||||
for cmd in self.cmd_list:
|
||||
cmd.children = [child for child in self.cmd_list if child.parent_class == cmd.__class__]
|
||||
|
||||
# 初始化所有类
|
||||
for cmd in self.cmd_list:
|
||||
await cmd.initialize()
|
||||
|
||||
async def _execute(
|
||||
self,
|
||||
context: entities.ExecuteContext,
|
||||
operator_list: list[operator.CommandOperator],
|
||||
operator: operator.CommandOperator = None
|
||||
) -> typing.AsyncGenerator[entities.CommandReturn, None]:
|
||||
"""执行命令
|
||||
"""
|
||||
|
||||
found = False
|
||||
if len(context.crt_params) > 0:
|
||||
for oper in operator_list:
|
||||
if (context.crt_params[0] == oper.name \
|
||||
or context.crt_params[0] in oper.alias) \
|
||||
and (oper.parent_class is None or oper.parent_class == operator.__class__):
|
||||
found = True
|
||||
|
||||
context.crt_command = context.crt_params[0]
|
||||
context.crt_params = context.crt_params[1:]
|
||||
|
||||
async for ret in self._execute(
|
||||
context,
|
||||
oper.children,
|
||||
oper
|
||||
):
|
||||
yield ret
|
||||
break
|
||||
|
||||
if not found:
|
||||
if operator is None:
|
||||
yield entities.CommandReturn(
|
||||
error=errors.CommandNotFoundError(context.crt_params[0])
|
||||
)
|
||||
else:
|
||||
if operator.lowest_privilege > context.privilege:
|
||||
yield entities.CommandReturn(
|
||||
error=errors.CommandPrivilegeError(operator.name)
|
||||
)
|
||||
else:
|
||||
async for ret in operator.execute(context):
|
||||
yield ret
|
||||
|
||||
|
||||
async def execute(
|
||||
self,
|
||||
command_text: str,
|
||||
query: core_entities.Query,
|
||||
session: session_entities.Session
|
||||
) -> typing.AsyncGenerator[entities.CommandReturn, None]:
|
||||
"""执行命令
|
||||
"""
|
||||
|
||||
privilege = 1
|
||||
if query.sender_id == self.ap.cfg_mgr.data['admin_qq'] \
|
||||
or query.sender_id in self.ap.cfg_mgr['admin_qq']:
|
||||
privilege = 2
|
||||
|
||||
ctx = entities.ExecuteContext(
|
||||
query=query,
|
||||
session=session,
|
||||
command_text=command_text,
|
||||
command='',
|
||||
crt_command='',
|
||||
params=command_text.split(' '),
|
||||
crt_params=command_text.split(' '),
|
||||
privilege=privilege
|
||||
)
|
||||
|
||||
async for ret in self._execute(
|
||||
ctx,
|
||||
self.cmd_list
|
||||
):
|
||||
yield ret
|
||||
43
pkg/command/entities.py
Normal file
43
pkg/command/entities.py
Normal file
@@ -0,0 +1,43 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import typing
|
||||
|
||||
import pydantic
|
||||
import mirai
|
||||
|
||||
from ..core import app, entities as core_entities
|
||||
from ..openai.session import entities as session_entities
|
||||
from . import errors, operator
|
||||
|
||||
|
||||
class CommandReturn(pydantic.BaseModel):
|
||||
|
||||
text: typing.Optional[str]
|
||||
"""文本
|
||||
"""
|
||||
|
||||
image: typing.Optional[mirai.Image]
|
||||
|
||||
error: typing.Optional[errors.CommandError]= None
|
||||
|
||||
class Config:
|
||||
arbitrary_types_allowed = True
|
||||
|
||||
|
||||
class ExecuteContext(pydantic.BaseModel):
|
||||
|
||||
query: core_entities.Query
|
||||
|
||||
session: session_entities.Session
|
||||
|
||||
command_text: str
|
||||
|
||||
command: str
|
||||
|
||||
crt_command: str
|
||||
|
||||
params: list[str]
|
||||
|
||||
crt_params: list[str]
|
||||
|
||||
privilege: int
|
||||
33
pkg/command/errors.py
Normal file
33
pkg/command/errors.py
Normal file
@@ -0,0 +1,33 @@
|
||||
|
||||
|
||||
class CommandError(Exception):
|
||||
|
||||
def __init__(self, message: str = None):
|
||||
self.message = message
|
||||
|
||||
def __str__(self):
|
||||
return self.message
|
||||
|
||||
|
||||
class CommandNotFoundError(CommandError):
|
||||
|
||||
def __init__(self, message: str = None):
|
||||
super().__init__("未知命令: "+message)
|
||||
|
||||
|
||||
class CommandPrivilegeError(CommandError):
|
||||
|
||||
def __init__(self, message: str = None):
|
||||
super().__init__("权限不足: "+message)
|
||||
|
||||
|
||||
class ParamNotEnoughError(CommandError):
|
||||
|
||||
def __init__(self, message: str = None):
|
||||
super().__init__("参数不足: "+message)
|
||||
|
||||
|
||||
class CommandOperationError(CommandError):
|
||||
|
||||
def __init__(self, message: str = None):
|
||||
super().__init__("操作失败: "+message)
|
||||
75
pkg/command/operator.py
Normal file
75
pkg/command/operator.py
Normal file
@@ -0,0 +1,75 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import typing
|
||||
import abc
|
||||
|
||||
from ..core import app, entities as core_entities
|
||||
from ..openai.session import entities as session_entities
|
||||
from . import entities
|
||||
|
||||
|
||||
preregistered_operators: list[typing.Type[CommandOperator]] = []
|
||||
|
||||
|
||||
def operator_class(
|
||||
name: str,
|
||||
help: str,
|
||||
usage: str = None,
|
||||
alias: list[str] = [],
|
||||
privilege: int=1, # 1为普通用户,2为管理员
|
||||
parent_class: typing.Type[CommandOperator] = None
|
||||
) -> typing.Callable[[typing.Type[CommandOperator]], typing.Type[CommandOperator]]:
|
||||
def decorator(cls: typing.Type[CommandOperator]) -> typing.Type[CommandOperator]:
|
||||
cls.name = name
|
||||
cls.alias = alias
|
||||
cls.help = help
|
||||
cls.usage = usage
|
||||
cls.parent_class = parent_class
|
||||
|
||||
preregistered_operators.append(cls)
|
||||
|
||||
return cls
|
||||
|
||||
return decorator
|
||||
|
||||
|
||||
class CommandOperator(metaclass=abc.ABCMeta):
|
||||
"""命令算子
|
||||
"""
|
||||
|
||||
ap: app.Application
|
||||
|
||||
name: str
|
||||
"""名称,搜索到时若符合则使用"""
|
||||
|
||||
alias: list[str]
|
||||
"""同name"""
|
||||
|
||||
help: str
|
||||
"""此节点的帮助信息"""
|
||||
|
||||
usage: str = None
|
||||
|
||||
parent_class: typing.Type[CommandOperator] | None = None
|
||||
"""父节点类。标记以供管理器在初始化时编织父子关系。"""
|
||||
|
||||
lowest_privilege: int = 0
|
||||
"""最低权限。若权限低于此值,则不予执行。"""
|
||||
|
||||
children: list[CommandOperator]
|
||||
"""子节点。解析命令时,若节点有子节点,则以下一个参数去匹配子节点,
|
||||
若有匹配中的,转移到子节点中执行,若没有匹配中的或没有子节点,执行此节点。"""
|
||||
|
||||
def __init__(self, ap: app.Application):
|
||||
self.ap = ap
|
||||
self.children = []
|
||||
|
||||
async def initialize(self):
|
||||
pass
|
||||
|
||||
@abc.abstractmethod
|
||||
async def execute(
|
||||
self,
|
||||
context: entities.ExecuteContext
|
||||
) -> typing.AsyncGenerator[entities.CommandReturn, None]:
|
||||
pass
|
||||
98
pkg/command/operators/cfg.py
Normal file
98
pkg/command/operators/cfg.py
Normal file
@@ -0,0 +1,98 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import typing
|
||||
import json
|
||||
|
||||
from .. import operator, entities, cmdmgr, errors
|
||||
|
||||
|
||||
@operator.operator_class(
|
||||
name="cfg",
|
||||
help="配置项管理",
|
||||
usage='!cfg <配置项> [配置值]\n!cfg all',
|
||||
privilege=2
|
||||
)
|
||||
class CfgOperator(operator.CommandOperator):
|
||||
|
||||
async def execute(
|
||||
self,
|
||||
context: entities.ExecuteContext
|
||||
) -> typing.AsyncGenerator[entities.CommandReturn, None]:
|
||||
"""执行
|
||||
"""
|
||||
reply = ''
|
||||
|
||||
params = context.crt_params
|
||||
cfg_mgr = self.ap.cfg_mgr
|
||||
|
||||
false = False
|
||||
true = True
|
||||
|
||||
reply_str = ""
|
||||
if len(params) == 0:
|
||||
yield entities.CommandReturn(error=errors.ParamNotEnoughError('请提供配置项名称'))
|
||||
else:
|
||||
cfg_name = params[0]
|
||||
if cfg_name == 'all':
|
||||
reply_str = "[bot]所有配置项:\n\n"
|
||||
for cfg in cfg_mgr.data.keys():
|
||||
if not cfg.startswith('__') and not cfg == 'logging':
|
||||
# 根据配置项类型进行格式化,如果是字典则转换为json并格式化
|
||||
if isinstance(cfg_mgr.data[cfg], str):
|
||||
reply_str += "{}: \"{}\"\n".format(cfg, cfg_mgr.data[cfg])
|
||||
elif isinstance(cfg_mgr.data[cfg], dict):
|
||||
# 不进行unicode转义,并格式化
|
||||
reply_str += "{}: {}\n".format(cfg,
|
||||
json.dumps(cfg_mgr.data[cfg],
|
||||
ensure_ascii=False, indent=4))
|
||||
else:
|
||||
reply_str += "{}: {}\n".format(cfg, cfg_mgr.data[cfg])
|
||||
yield entities.CommandReturn(text=reply_str)
|
||||
else:
|
||||
cfg_entry_path = cfg_name.split('.')
|
||||
|
||||
try:
|
||||
if len(params) == 1: # 未指定配置值,返回配置项值
|
||||
cfg_entry = cfg_mgr.data[cfg_entry_path[0]]
|
||||
if len(cfg_entry_path) > 1:
|
||||
for i in range(1, len(cfg_entry_path)):
|
||||
cfg_entry = cfg_entry[cfg_entry_path[i]]
|
||||
|
||||
if isinstance(cfg_entry, str):
|
||||
reply_str = "[bot]配置项{}: \"{}\"\n".format(cfg_name, cfg_entry)
|
||||
elif isinstance(cfg_entry, dict):
|
||||
reply_str = "[bot]配置项{}: {}\n".format(cfg_name,
|
||||
json.dumps(cfg_entry,
|
||||
ensure_ascii=False, indent=4))
|
||||
else:
|
||||
reply_str = "[bot]配置项{}: {}\n".format(cfg_name, cfg_entry)
|
||||
|
||||
yield entities.CommandReturn(text=reply_str)
|
||||
else:
|
||||
cfg_value = " ".join(params[1:])
|
||||
|
||||
cfg_value = eval(cfg_value)
|
||||
|
||||
cfg_entry = cfg_mgr.data[cfg_entry_path[0]]
|
||||
if len(cfg_entry_path) > 1:
|
||||
for i in range(1, len(cfg_entry_path) - 1):
|
||||
cfg_entry = cfg_entry[cfg_entry_path[i]]
|
||||
if isinstance(cfg_entry[cfg_entry_path[-1]], type(cfg_value)):
|
||||
cfg_entry[cfg_entry_path[-1]] = cfg_value
|
||||
yield entities.CommandReturn(text="配置项{}修改成功".format(cfg_name))
|
||||
else:
|
||||
# reply = ["[bot]err:配置项{}类型不匹配".format(cfg_name)]
|
||||
yield entities.CommandReturn(error=errors.CommandOperationError("配置项{}类型不匹配".format(cfg_name)))
|
||||
else:
|
||||
cfg_mgr.data[cfg_entry_path[0]] = cfg_value
|
||||
# reply = ["[bot]配置项{}修改成功".format(cfg_name)]
|
||||
yield entities.CommandReturn(text="配置项{}修改成功".format(cfg_name))
|
||||
except KeyError:
|
||||
# reply = ["[bot]err:未找到配置项 {}".format(cfg_name)]
|
||||
yield entities.CommandReturn(error=errors.CommandOperationError("未找到配置项 {}".format(cfg_name)))
|
||||
except NameError:
|
||||
# reply = ["[bot]err:值{}不合法(字符串需要使用双引号包裹)".format(cfg_value)]
|
||||
yield entities.CommandReturn(error=errors.CommandOperationError("值{}不合法(字符串需要使用双引号包裹)".format(cfg_value)))
|
||||
except ValueError:
|
||||
# reply = ["[bot]err:未找到配置项 {}".format(cfg_name)]
|
||||
yield entities.CommandReturn(error=errors.CommandOperationError("未找到配置项 {}".format(cfg_name)))
|
||||
50
pkg/command/operators/cmd.py
Normal file
50
pkg/command/operators/cmd.py
Normal file
@@ -0,0 +1,50 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import typing
|
||||
|
||||
from .. import operator, entities, cmdmgr, errors
|
||||
|
||||
|
||||
@operator.operator_class(
|
||||
name="cmd",
|
||||
help='显示命令列表',
|
||||
usage='!cmd\n!cmd <命令名称>'
|
||||
)
|
||||
class CmdOperator(operator.CommandOperator):
|
||||
"""命令列表
|
||||
"""
|
||||
|
||||
async def execute(
|
||||
self,
|
||||
context: entities.ExecuteContext
|
||||
) -> typing.AsyncGenerator[entities.CommandReturn, None]:
|
||||
"""执行
|
||||
"""
|
||||
if len(context.crt_params) == 0:
|
||||
reply_str = "当前所有命令: \n\n"
|
||||
|
||||
for cmd in self.ap.cmd_mgr.cmd_list:
|
||||
if cmd.parent_class is None:
|
||||
reply_str += f"{cmd.name}: {cmd.help}\n"
|
||||
|
||||
reply_str += "\n使用 !cmd <命令名称> 查看命令的详细帮助"
|
||||
|
||||
yield entities.CommandReturn(text=reply_str.strip())
|
||||
|
||||
else:
|
||||
cmd_name = context.crt_params[0]
|
||||
|
||||
cmd = None
|
||||
|
||||
for _cmd in self.ap.cmd_mgr.cmd_list:
|
||||
if (cmd_name == _cmd.name or cmd_name in _cmd.alias) and (_cmd.parent_class is None):
|
||||
cmd = _cmd
|
||||
break
|
||||
|
||||
if cmd is None:
|
||||
yield entities.CommandReturn(error=errors.CommandNotFoundError(cmd_name))
|
||||
else:
|
||||
reply_str = f"{cmd.name}: {cmd.help}\n\n"
|
||||
reply_str += f"使用方法: \n{cmd.usage}"
|
||||
|
||||
yield entities.CommandReturn(text=reply_str.strip())
|
||||
62
pkg/command/operators/default.py
Normal file
62
pkg/command/operators/default.py
Normal file
@@ -0,0 +1,62 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import typing
|
||||
import traceback
|
||||
|
||||
from .. import operator, entities, cmdmgr, errors
|
||||
|
||||
|
||||
@operator.operator_class(
|
||||
name="default",
|
||||
help="操作情景预设",
|
||||
usage='!default\n!default set <指定情景预设为默认>'
|
||||
)
|
||||
class DefaultOperator(operator.CommandOperator):
|
||||
|
||||
async def execute(
|
||||
self,
|
||||
context: entities.ExecuteContext
|
||||
) -> typing.AsyncGenerator[entities.CommandReturn, None]:
|
||||
|
||||
reply_str = "当前所有情景预设: \n\n"
|
||||
|
||||
for prompt in self.ap.prompt_mgr.get_all_prompts():
|
||||
|
||||
content = ""
|
||||
for msg in prompt.messages:
|
||||
content += f" {msg.role}: {msg.content}"
|
||||
|
||||
reply_str += f"名称: {prompt.name}\n内容: \n{content}\n\n"
|
||||
|
||||
reply_str += f"当前会话使用的是: {context.session.use_prompt_name}"
|
||||
|
||||
yield entities.CommandReturn(text=reply_str.strip())
|
||||
|
||||
|
||||
@operator.operator_class(
|
||||
name="set",
|
||||
help="设置当前会话默认情景预设",
|
||||
parent_class=DefaultOperator
|
||||
)
|
||||
class DefaultSetOperator(operator.CommandOperator):
|
||||
|
||||
async def execute(
|
||||
self,
|
||||
context: entities.ExecuteContext
|
||||
) -> typing.AsyncGenerator[entities.CommandReturn, None]:
|
||||
|
||||
if len(context.crt_params) == 0:
|
||||
yield entities.CommandReturn(error=errors.ParamNotEnoughError('请提供情景预设名称'))
|
||||
else:
|
||||
prompt_name = context.crt_params[0]
|
||||
|
||||
try:
|
||||
prompt = await self.ap.prompt_mgr.get_prompt_by_prefix(prompt_name)
|
||||
if prompt is None:
|
||||
yield entities.CommandReturn(error=errors.CommandError("设置当前会话默认情景预设失败: 未找到情景预设 {}".format(prompt_name)))
|
||||
else:
|
||||
context.session.use_prompt_name = prompt.name
|
||||
yield entities.CommandReturn(text=f"已设置当前会话默认情景预设为 {prompt_name}, !reset 后生效")
|
||||
except Exception as e:
|
||||
traceback.print_exc()
|
||||
yield entities.CommandReturn(error=errors.CommandError("设置当前会话默认情景预设失败: "+str(e)))
|
||||
62
pkg/command/operators/delc.py
Normal file
62
pkg/command/operators/delc.py
Normal file
@@ -0,0 +1,62 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import typing
|
||||
import datetime
|
||||
|
||||
from .. import operator, entities, cmdmgr, errors
|
||||
|
||||
|
||||
@operator.operator_class(
|
||||
name="del",
|
||||
help="删除当前会话的历史记录",
|
||||
usage='!del <序号>\n!del all'
|
||||
)
|
||||
class DelOperator(operator.CommandOperator):
|
||||
|
||||
async def execute(
|
||||
self,
|
||||
context: entities.ExecuteContext
|
||||
) -> typing.AsyncGenerator[entities.CommandReturn, None]:
|
||||
|
||||
if context.session.conversations:
|
||||
delete_index = 0
|
||||
if len(context.crt_params) > 0:
|
||||
try:
|
||||
delete_index = int(context.crt_params[0])
|
||||
except:
|
||||
yield entities.CommandReturn(error=errors.CommandOperationError('索引必须是整数'))
|
||||
return
|
||||
|
||||
if delete_index < 0 or delete_index >= len(context.session.conversations):
|
||||
yield entities.CommandReturn(error=errors.CommandOperationError('索引超出范围'))
|
||||
return
|
||||
|
||||
# 倒序
|
||||
to_delete_index = len(context.session.conversations)-1-delete_index
|
||||
|
||||
if context.session.conversations[to_delete_index] == context.session.using_conversation:
|
||||
context.session.using_conversation = None
|
||||
|
||||
del context.session.conversations[to_delete_index]
|
||||
|
||||
yield entities.CommandReturn(text=f"已删除对话: {delete_index}")
|
||||
else:
|
||||
yield entities.CommandReturn(error=errors.CommandOperationError('当前没有对话'))
|
||||
|
||||
|
||||
@operator.operator_class(
|
||||
name="all",
|
||||
help="删除此会话的所有历史记录",
|
||||
parent_class=DelOperator
|
||||
)
|
||||
class DelAllOperator(operator.CommandOperator):
|
||||
|
||||
async def execute(
|
||||
self,
|
||||
context: entities.ExecuteContext
|
||||
) -> typing.AsyncGenerator[entities.CommandReturn, None]:
|
||||
|
||||
context.session.conversations = []
|
||||
context.session.using_conversation = None
|
||||
|
||||
yield entities.CommandReturn(text="已删除所有对话")
|
||||
25
pkg/command/operators/func.py
Normal file
25
pkg/command/operators/func.py
Normal file
@@ -0,0 +1,25 @@
|
||||
from __future__ import annotations
|
||||
from typing import AsyncGenerator
|
||||
|
||||
from .. import operator, entities, cmdmgr
|
||||
from ...plugin import host as plugin_host
|
||||
|
||||
|
||||
@operator.operator_class(name="func", help="查看所有已注册的内容函数", usage='!func')
|
||||
class FuncOperator(operator.CommandOperator):
|
||||
async def execute(
|
||||
self, context: entities.ExecuteContext
|
||||
) -> AsyncGenerator[entities.CommandReturn, None]:
|
||||
reply_str = "当前已加载的内容函数: \n\n"
|
||||
|
||||
index = 1
|
||||
for func in self.ap.tool_mgr.all_functions:
|
||||
reply_str += "{}. {}{}:\n{}\n\n".format(
|
||||
index,
|
||||
("(已禁用) " if not func.enable else ""),
|
||||
func.name,
|
||||
func.description,
|
||||
)
|
||||
index += 1
|
||||
|
||||
yield entities.CommandReturn(text=reply_str)
|
||||
23
pkg/command/operators/help.py
Normal file
23
pkg/command/operators/help.py
Normal file
@@ -0,0 +1,23 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import typing
|
||||
|
||||
from .. import operator, entities, cmdmgr, errors
|
||||
|
||||
|
||||
@operator.operator_class(
|
||||
name='help',
|
||||
help='显示帮助',
|
||||
usage='!help\n!help <命令名称>'
|
||||
)
|
||||
class HelpOperator(operator.CommandOperator):
|
||||
|
||||
async def execute(
|
||||
self,
|
||||
context: entities.ExecuteContext
|
||||
) -> typing.AsyncGenerator[entities.CommandReturn, None]:
|
||||
help = self.ap.tips_mgr.data['help_message']
|
||||
|
||||
help += '\n发送命令 !cmd 可查看命令列表'
|
||||
|
||||
yield entities.CommandReturn(text=help)
|
||||
36
pkg/command/operators/last.py
Normal file
36
pkg/command/operators/last.py
Normal file
@@ -0,0 +1,36 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import typing
|
||||
import datetime
|
||||
|
||||
|
||||
from .. import operator, entities, cmdmgr, errors
|
||||
|
||||
|
||||
@operator.operator_class(
|
||||
name="last",
|
||||
help="切换到前一个对话",
|
||||
usage='!last'
|
||||
)
|
||||
class LastOperator(operator.CommandOperator):
|
||||
|
||||
async def execute(
|
||||
self,
|
||||
context: entities.ExecuteContext
|
||||
) -> typing.AsyncGenerator[entities.CommandReturn, None]:
|
||||
|
||||
if context.session.conversations:
|
||||
# 找到当前会话的上一个会话
|
||||
for index in range(len(context.session.conversations)-1, -1, -1):
|
||||
if context.session.conversations[index] == context.session.using_conversation:
|
||||
if index == 0:
|
||||
yield entities.CommandReturn(error=errors.CommandOperationError('已经是第一个对话了'))
|
||||
return
|
||||
else:
|
||||
context.session.using_conversation = context.session.conversations[index-1]
|
||||
time_str = context.session.using_conversation.create_time.strftime("%Y-%m-%d %H:%M:%S")
|
||||
|
||||
yield entities.CommandReturn(text=f"已切换到上一个对话: {index} {time_str}: {context.session.using_conversation.messages[0].content}")
|
||||
return
|
||||
else:
|
||||
yield entities.CommandReturn(error=errors.CommandOperationError('当前没有对话'))
|
||||
56
pkg/command/operators/list.py
Normal file
56
pkg/command/operators/list.py
Normal file
@@ -0,0 +1,56 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import typing
|
||||
import datetime
|
||||
|
||||
from .. import operator, entities, cmdmgr, errors
|
||||
|
||||
|
||||
@operator.operator_class(
|
||||
name="list",
|
||||
help="列出此会话中的所有历史对话",
|
||||
usage='!list\n!list <页码>'
|
||||
)
|
||||
class ListOperator(operator.CommandOperator):
|
||||
|
||||
async def execute(
|
||||
self,
|
||||
context: entities.ExecuteContext
|
||||
) -> typing.AsyncGenerator[entities.CommandReturn, None]:
|
||||
|
||||
page = 0
|
||||
|
||||
if len(context.crt_params) > 0:
|
||||
try:
|
||||
page = int(context.crt_params[0]-1)
|
||||
except:
|
||||
yield entities.CommandReturn(error=errors.CommandOperationError('页码应为整数'))
|
||||
return
|
||||
|
||||
record_per_page = 10
|
||||
|
||||
content = ''
|
||||
|
||||
index = 0
|
||||
|
||||
using_conv_index = 0
|
||||
|
||||
for conv in context.session.conversations[::-1]:
|
||||
time_str = conv.create_time.strftime("%Y-%m-%d %H:%M:%S")
|
||||
|
||||
if conv == context.session.using_conversation:
|
||||
using_conv_index = index
|
||||
|
||||
if index >= page * record_per_page and index < (page + 1) * record_per_page:
|
||||
content += f"{index} {time_str}: {conv.messages[0].content}\n"
|
||||
index += 1
|
||||
|
||||
if content == '':
|
||||
content = '无'
|
||||
else:
|
||||
if context.session.using_conversation is None:
|
||||
content += "\n当前处于新会话"
|
||||
else:
|
||||
content += f"\n当前会话: {using_conv_index} {context.session.using_conversation.create_time.strftime('%Y-%m-%d %H:%M:%S')}: {context.session.using_conversation.messages[0].content}"
|
||||
|
||||
yield entities.CommandReturn(text=f"第 {page + 1} 页 (时间倒序):\n{content}")
|
||||
35
pkg/command/operators/next.py
Normal file
35
pkg/command/operators/next.py
Normal file
@@ -0,0 +1,35 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import typing
|
||||
import datetime
|
||||
|
||||
from .. import operator, entities, cmdmgr, errors
|
||||
|
||||
|
||||
@operator.operator_class(
|
||||
name="next",
|
||||
help="切换到后一个对话",
|
||||
usage='!next'
|
||||
)
|
||||
class NextOperator(operator.CommandOperator):
|
||||
|
||||
async def execute(
|
||||
self,
|
||||
context: entities.ExecuteContext
|
||||
) -> typing.AsyncGenerator[entities.CommandReturn, None]:
|
||||
|
||||
if context.session.conversations:
|
||||
# 找到当前会话的下一个会话
|
||||
for index in range(len(context.session.conversations)):
|
||||
if context.session.conversations[index] == context.session.using_conversation:
|
||||
if index == len(context.session.conversations)-1:
|
||||
yield entities.CommandReturn(error=errors.CommandOperationError('已经是最后一个对话了'))
|
||||
return
|
||||
else:
|
||||
context.session.using_conversation = context.session.conversations[index+1]
|
||||
time_str = context.session.using_conversation.create_time.strftime("%Y-%m-%d %H:%M:%S")
|
||||
|
||||
yield entities.CommandReturn(text=f"已切换到后一个对话: {index} {time_str}: {context.session.using_conversation.messages[0].content}")
|
||||
return
|
||||
else:
|
||||
yield entities.CommandReturn(error=errors.CommandOperationError('当前没有对话'))
|
||||
239
pkg/command/operators/plugin.py
Normal file
239
pkg/command/operators/plugin.py
Normal file
@@ -0,0 +1,239 @@
|
||||
from __future__ import annotations
|
||||
import typing
|
||||
import traceback
|
||||
|
||||
from .. import operator, entities, cmdmgr, errors
|
||||
from ...plugin import host as plugin_host
|
||||
from ...utils import updater
|
||||
from ...core import app
|
||||
|
||||
|
||||
@operator.operator_class(
|
||||
name="plugin",
|
||||
help="插件操作",
|
||||
usage="!plugin\n!plugin get <插件仓库地址>\n!plugin update\n!plugin del <插件名>\n!plugin on <插件名>\n!plugin off <插件名>"
|
||||
)
|
||||
class PluginOperator(operator.CommandOperator):
|
||||
|
||||
async def execute(
|
||||
self,
|
||||
context: entities.ExecuteContext
|
||||
) -> typing.AsyncGenerator[entities.CommandReturn, None]:
|
||||
|
||||
plugin_list = plugin_host.__plugins__
|
||||
reply_str = "所有插件({}):\n".format(len(plugin_host.__plugins__))
|
||||
idx = 0
|
||||
for key in plugin_host.iter_plugins_name():
|
||||
plugin = plugin_list[key]
|
||||
reply_str += "\n#{} {} {}\n{}\nv{}\n作者: {}\n"\
|
||||
.format((idx+1), plugin['name'],
|
||||
"[已禁用]" if not plugin['enabled'] else "",
|
||||
plugin['description'],
|
||||
plugin['version'], plugin['author'])
|
||||
|
||||
# TODO 从元数据调远程地址
|
||||
# if updater.is_repo("/".join(plugin['path'].split('/')[:-1])):
|
||||
# remote_url = updater.get_remote_url("/".join(plugin['path'].split('/')[:-1]))
|
||||
# if remote_url != "https://github.com/RockChinQ/QChatGPT" and remote_url != "https://gitee.com/RockChin/QChatGPT":
|
||||
# reply_str += "源码: "+remote_url+"\n"
|
||||
|
||||
idx += 1
|
||||
|
||||
yield entities.CommandReturn(text=reply_str)
|
||||
|
||||
|
||||
@operator.operator_class(
|
||||
name="get",
|
||||
help="安装插件",
|
||||
privilege=2,
|
||||
parent_class=PluginOperator
|
||||
)
|
||||
class PluginGetOperator(operator.CommandOperator):
|
||||
|
||||
async def execute(
|
||||
self,
|
||||
context: entities.ExecuteContext
|
||||
) -> typing.AsyncGenerator[entities.CommandReturn, None]:
|
||||
|
||||
if len(context.crt_params) == 0:
|
||||
yield entities.CommandReturn(error=errors.ParamNotEnoughError('请提供插件仓库地址'))
|
||||
else:
|
||||
repo = context.crt_params[0]
|
||||
|
||||
yield entities.CommandReturn(text="正在安装插件...")
|
||||
|
||||
try:
|
||||
plugin_host.install_plugin(repo)
|
||||
yield entities.CommandReturn(text="插件安装成功,请重启程序以加载插件")
|
||||
except Exception as e:
|
||||
traceback.print_exc()
|
||||
yield entities.CommandReturn(error=errors.CommandError("插件安装失败: "+str(e)))
|
||||
|
||||
|
||||
@operator.operator_class(
|
||||
name="update",
|
||||
help="更新插件",
|
||||
privilege=2,
|
||||
parent_class=PluginOperator
|
||||
)
|
||||
class PluginUpdateOperator(operator.CommandOperator):
|
||||
|
||||
async def execute(
|
||||
self,
|
||||
context: entities.ExecuteContext
|
||||
) -> typing.AsyncGenerator[entities.CommandReturn, None]:
|
||||
|
||||
if len(context.crt_params) == 0:
|
||||
yield entities.CommandReturn(error=errors.ParamNotEnoughError('请提供插件名称'))
|
||||
else:
|
||||
plugin_name = context.crt_params[0]
|
||||
|
||||
try:
|
||||
plugin_path_name = plugin_host.get_plugin_path_name_by_plugin_name(plugin_name)
|
||||
|
||||
if plugin_path_name is not None:
|
||||
yield entities.CommandReturn(text="正在更新插件...")
|
||||
plugin_host.update_plugin(plugin_name)
|
||||
yield entities.CommandReturn(text="插件更新成功,请重启程序以加载插件")
|
||||
else:
|
||||
yield entities.CommandReturn(error=errors.CommandError("插件更新失败: 未找到插件"))
|
||||
except Exception as e:
|
||||
traceback.print_exc()
|
||||
yield entities.CommandReturn(error=errors.CommandError("插件更新失败: "+str(e)))
|
||||
|
||||
@operator.operator_class(
|
||||
name="all",
|
||||
help="更新所有插件",
|
||||
privilege=2,
|
||||
parent_class=PluginUpdateOperator
|
||||
)
|
||||
class PluginUpdateAllOperator(operator.CommandOperator):
|
||||
|
||||
async def execute(
|
||||
self,
|
||||
context: entities.ExecuteContext
|
||||
) -> typing.AsyncGenerator[entities.CommandReturn, None]:
|
||||
|
||||
try:
|
||||
plugins = []
|
||||
|
||||
for key in plugin_host.__plugins__:
|
||||
plugins.append(key)
|
||||
|
||||
if plugins:
|
||||
yield entities.CommandReturn(text="正在更新插件...")
|
||||
updated = []
|
||||
try:
|
||||
for plugin_name in plugins:
|
||||
plugin_host.update_plugin(plugin_name)
|
||||
updated.append(plugin_name)
|
||||
except Exception as e:
|
||||
traceback.print_exc()
|
||||
yield entities.CommandReturn(error=errors.CommandError("插件更新失败: "+str(e)))
|
||||
yield entities.CommandReturn(text="已更新插件: {}".format(", ".join(updated)))
|
||||
else:
|
||||
yield entities.CommandReturn(text="没有可更新的插件")
|
||||
except Exception as e:
|
||||
traceback.print_exc()
|
||||
yield entities.CommandReturn(error=errors.CommandError("插件更新失败: "+str(e)))
|
||||
|
||||
|
||||
@operator.operator_class(
|
||||
name="del",
|
||||
help="删除插件",
|
||||
privilege=2,
|
||||
parent_class=PluginOperator
|
||||
)
|
||||
class PluginDelOperator(operator.CommandOperator):
|
||||
|
||||
async def execute(
|
||||
self,
|
||||
context: entities.ExecuteContext
|
||||
) -> typing.AsyncGenerator[entities.CommandReturn, None]:
|
||||
|
||||
if len(context.crt_params) == 0:
|
||||
yield entities.CommandReturn(error=errors.ParamNotEnoughError('请提供插件名称'))
|
||||
else:
|
||||
plugin_name = context.crt_params[0]
|
||||
|
||||
try:
|
||||
plugin_path_name = plugin_host.get_plugin_path_name_by_plugin_name(plugin_name)
|
||||
|
||||
if plugin_path_name is not None:
|
||||
yield entities.CommandReturn(text="正在删除插件...")
|
||||
plugin_host.uninstall_plugin(plugin_name)
|
||||
yield entities.CommandReturn(text="插件删除成功,请重启程序以加载插件")
|
||||
else:
|
||||
yield entities.CommandReturn(error=errors.CommandError("插件删除失败: 未找到插件"))
|
||||
except Exception as e:
|
||||
traceback.print_exc()
|
||||
yield entities.CommandReturn(error=errors.CommandError("插件删除失败: "+str(e)))
|
||||
|
||||
|
||||
def update_plugin_status(plugin_name: str, new_status: bool, ap: app.Application):
|
||||
if plugin_name in plugin_host.__plugins__:
|
||||
plugin_host.__plugins__[plugin_name]['enabled'] = new_status
|
||||
|
||||
for func in ap.tool_mgr.all_functions:
|
||||
if func.name.startswith(plugin_name+'-'):
|
||||
func.enable = new_status
|
||||
|
||||
return True
|
||||
else:
|
||||
return False
|
||||
|
||||
|
||||
@operator.operator_class(
|
||||
name="on",
|
||||
help="启用插件",
|
||||
privilege=2,
|
||||
parent_class=PluginOperator
|
||||
)
|
||||
class PluginEnableOperator(operator.CommandOperator):
|
||||
|
||||
async def execute(
|
||||
self,
|
||||
context: entities.ExecuteContext
|
||||
) -> typing.AsyncGenerator[entities.CommandReturn, None]:
|
||||
|
||||
if len(context.crt_params) == 0:
|
||||
yield entities.CommandReturn(error=errors.ParamNotEnoughError('请提供插件名称'))
|
||||
else:
|
||||
plugin_name = context.crt_params[0]
|
||||
|
||||
try:
|
||||
if update_plugin_status(plugin_name, True, self.ap):
|
||||
yield entities.CommandReturn(text="已启用插件: {}".format(plugin_name))
|
||||
else:
|
||||
yield entities.CommandReturn(error=errors.CommandError("插件状态修改失败: 未找到插件 {}".format(plugin_name)))
|
||||
except Exception as e:
|
||||
traceback.print_exc()
|
||||
yield entities.CommandReturn(error=errors.CommandError("插件状态修改失败: "+str(e)))
|
||||
|
||||
|
||||
@operator.operator_class(
|
||||
name="off",
|
||||
help="禁用插件",
|
||||
privilege=2,
|
||||
parent_class=PluginOperator
|
||||
)
|
||||
class PluginDisableOperator(operator.CommandOperator):
|
||||
|
||||
async def execute(
|
||||
self,
|
||||
context: entities.ExecuteContext
|
||||
) -> typing.AsyncGenerator[entities.CommandReturn, None]:
|
||||
|
||||
if len(context.crt_params) == 0:
|
||||
yield entities.CommandReturn(error=errors.ParamNotEnoughError('请提供插件名称'))
|
||||
else:
|
||||
plugin_name = context.crt_params[0]
|
||||
|
||||
try:
|
||||
if update_plugin_status(plugin_name, False, self.ap):
|
||||
yield entities.CommandReturn(text="已禁用插件: {}".format(plugin_name))
|
||||
else:
|
||||
yield entities.CommandReturn(error=errors.CommandError("插件状态修改失败: 未找到插件 {}".format(plugin_name)))
|
||||
except Exception as e:
|
||||
traceback.print_exc()
|
||||
yield entities.CommandReturn(error=errors.CommandError("插件状态修改失败: "+str(e)))
|
||||
29
pkg/command/operators/prompt.py
Normal file
29
pkg/command/operators/prompt.py
Normal file
@@ -0,0 +1,29 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import typing
|
||||
|
||||
from .. import operator, entities, cmdmgr, errors
|
||||
|
||||
|
||||
@operator.operator_class(
|
||||
name="prompt",
|
||||
help="查看当前对话的前文",
|
||||
usage='!prompt'
|
||||
)
|
||||
class PromptOperator(operator.CommandOperator):
|
||||
|
||||
async def execute(
|
||||
self,
|
||||
context: entities.ExecuteContext
|
||||
) -> typing.AsyncGenerator[entities.CommandReturn, None]:
|
||||
"""执行
|
||||
"""
|
||||
if context.session.using_conversation is None:
|
||||
yield entities.CommandReturn(error=errors.CommandOperationError('当前没有对话'))
|
||||
else:
|
||||
reply_str = '当前对话所有内容:\n\n'
|
||||
|
||||
for msg in context.session.using_conversation.messages:
|
||||
reply_str += f"{msg.role}: {msg.content}\n"
|
||||
|
||||
yield entities.CommandReturn(text=reply_str)
|
||||
34
pkg/command/operators/resend.py
Normal file
34
pkg/command/operators/resend.py
Normal file
@@ -0,0 +1,34 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import typing
|
||||
|
||||
from .. import operator, entities, cmdmgr, errors
|
||||
|
||||
|
||||
@operator.operator_class(
|
||||
name="resend",
|
||||
help="重发当前会话的最后一条消息",
|
||||
usage='!resend'
|
||||
)
|
||||
class ResendOperator(operator.CommandOperator):
|
||||
|
||||
async def execute(
|
||||
self,
|
||||
context: entities.ExecuteContext
|
||||
) -> typing.AsyncGenerator[entities.CommandReturn, None]:
|
||||
# 回滚到最后一条用户message前
|
||||
if context.session.using_conversation is None:
|
||||
yield entities.CommandReturn(error=errors.CommandError("当前没有对话"))
|
||||
else:
|
||||
conv_msg = context.session.using_conversation.messages
|
||||
|
||||
# 倒序一直删到最后一条用户message
|
||||
while len(conv_msg) > 0 and conv_msg[-1].role != 'user':
|
||||
conv_msg.pop()
|
||||
|
||||
if len(conv_msg) > 0:
|
||||
# 删除最后一条用户message
|
||||
conv_msg.pop()
|
||||
|
||||
# 不重发了,提示用户已删除就行了
|
||||
yield entities.CommandReturn(text="已删除最后一次请求记录")
|
||||
23
pkg/command/operators/reset.py
Normal file
23
pkg/command/operators/reset.py
Normal file
@@ -0,0 +1,23 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import typing
|
||||
|
||||
from .. import operator, entities, cmdmgr, errors
|
||||
|
||||
|
||||
@operator.operator_class(
|
||||
name="reset",
|
||||
help="重置当前会话",
|
||||
usage='!reset'
|
||||
)
|
||||
class ResetOperator(operator.CommandOperator):
|
||||
|
||||
async def execute(
|
||||
self,
|
||||
context: entities.ExecuteContext
|
||||
) -> typing.AsyncGenerator[entities.CommandReturn, None]:
|
||||
"""执行
|
||||
"""
|
||||
context.session.using_conversation = None
|
||||
|
||||
yield entities.CommandReturn(text="已重置当前会话")
|
||||
31
pkg/command/operators/update.py
Normal file
31
pkg/command/operators/update.py
Normal file
@@ -0,0 +1,31 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import typing
|
||||
import traceback
|
||||
|
||||
from .. import operator, entities, cmdmgr, errors
|
||||
from ...utils import updater
|
||||
|
||||
|
||||
@operator.operator_class(
|
||||
name="update",
|
||||
help="更新程序",
|
||||
usage='!update',
|
||||
privilege=2
|
||||
)
|
||||
class UpdateCommand(operator.CommandOperator):
|
||||
|
||||
async def execute(
|
||||
self,
|
||||
context: entities.ExecuteContext
|
||||
) -> typing.AsyncGenerator[entities.CommandReturn, None]:
|
||||
|
||||
try:
|
||||
yield entities.CommandReturn(text="正在进行更新...")
|
||||
if updater.update_all():
|
||||
yield entities.CommandReturn(text="更新完成,请重启程序以应用更新")
|
||||
else:
|
||||
yield entities.CommandReturn(text="当前已是最新版本")
|
||||
except Exception as e:
|
||||
traceback.print_exc()
|
||||
yield entities.CommandReturn(error=errors.CommandError("更新失败: "+str(e)))
|
||||
28
pkg/command/operators/version.py
Normal file
28
pkg/command/operators/version.py
Normal file
@@ -0,0 +1,28 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import typing
|
||||
|
||||
from .. import operator, cmdmgr, entities, errors
|
||||
from ...utils import updater
|
||||
|
||||
|
||||
@operator.operator_class(
|
||||
name="version",
|
||||
help="显示版本信息",
|
||||
usage='!version'
|
||||
)
|
||||
class VersionCommand(operator.CommandOperator):
|
||||
|
||||
async def execute(
|
||||
self,
|
||||
context: entities.ExecuteContext
|
||||
) -> typing.AsyncGenerator[entities.CommandReturn, None]:
|
||||
reply_str = f"当前版本: \n{updater.get_current_version_info()}"
|
||||
|
||||
try:
|
||||
if updater.is_new_version_available():
|
||||
reply_str += "\n\n有新版本可用, 使用 !update 更新"
|
||||
except:
|
||||
pass
|
||||
|
||||
yield entities.CommandReturn(text=reply_str.strip())
|
||||
@@ -1,5 +1,6 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from . import model as file_model
|
||||
from ..utils import context
|
||||
from .impls import pymodule, json as json_file
|
||||
|
||||
|
||||
|
||||
71
pkg/core/app.py
Normal file
71
pkg/core/app.py
Normal file
@@ -0,0 +1,71 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
import asyncio
|
||||
|
||||
from ..qqbot import manager as qqbot_mgr
|
||||
from ..openai.session import sessionmgr as llm_session_mgr
|
||||
from ..openai.requester import modelmgr as llm_model_mgr
|
||||
from ..openai.sysprompt import sysprompt as llm_prompt_mgr
|
||||
from ..openai.tools import toolmgr as llm_tool_mgr
|
||||
from ..config import manager as config_mgr
|
||||
from ..database import manager as database_mgr
|
||||
from ..utils.center import v2 as center_mgr
|
||||
from ..command import cmdmgr
|
||||
from ..plugin import host as plugin_host
|
||||
from . import pool, controller
|
||||
from ..pipeline import stagemgr
|
||||
|
||||
|
||||
class Application:
|
||||
im_mgr: qqbot_mgr.QQBotManager = None
|
||||
|
||||
cmd_mgr: cmdmgr.CommandManager = None
|
||||
|
||||
sess_mgr: llm_session_mgr.SessionManager = None
|
||||
|
||||
model_mgr: llm_model_mgr.ModelManager = None
|
||||
|
||||
prompt_mgr: llm_prompt_mgr.PromptManager = None
|
||||
|
||||
tool_mgr: llm_tool_mgr.ToolManager = None
|
||||
|
||||
cfg_mgr: config_mgr.ConfigManager = None
|
||||
|
||||
tips_mgr: config_mgr.ConfigManager = None
|
||||
|
||||
db_mgr: database_mgr.DatabaseManager = None
|
||||
|
||||
ctr_mgr: center_mgr.V2CenterAPI = None
|
||||
|
||||
query_pool: pool.QueryPool = None
|
||||
|
||||
ctrl: controller.Controller = None
|
||||
|
||||
stage_mgr: stagemgr.StageManager = None
|
||||
|
||||
logger: logging.Logger = None
|
||||
|
||||
def __init__(self):
|
||||
pass
|
||||
|
||||
async def initialize(self):
|
||||
plugin_host.initialize_plugins()
|
||||
|
||||
# 把现有的所有内容函数加到toolmgr里
|
||||
for func in plugin_host.__callable_functions__:
|
||||
self.tool_mgr.register_legacy_function(
|
||||
name=func['name'],
|
||||
description=func['description'],
|
||||
parameters=func['parameters'],
|
||||
func=plugin_host.__function_inst_map__[func['name']]
|
||||
)
|
||||
|
||||
async def run(self):
|
||||
|
||||
tasks = [
|
||||
asyncio.create_task(self.im_mgr.run()),
|
||||
asyncio.create_task(self.ctrl.run())
|
||||
]
|
||||
|
||||
await asyncio.wait(tasks, return_when=asyncio.FIRST_COMPLETED)
|
||||
@@ -3,19 +3,23 @@ from __future__ import print_function
|
||||
import os
|
||||
import sys
|
||||
|
||||
from . import files
|
||||
from . import deps
|
||||
from . import log
|
||||
from . import config
|
||||
from .bootutils import files
|
||||
from .bootutils import deps
|
||||
from .bootutils import log
|
||||
from .bootutils import config
|
||||
|
||||
from . import app
|
||||
from . import pool
|
||||
from . import controller
|
||||
from ..pipeline import stagemgr
|
||||
from ..audit import identifier
|
||||
from ..database import manager as db_mgr
|
||||
from ..openai import manager as llm_mgr
|
||||
from ..openai import session as llm_session
|
||||
from ..openai import dprompt as llm_dprompt
|
||||
from ..openai.session import sessionmgr as llm_session_mgr
|
||||
from ..openai.requester import modelmgr as llm_model_mgr
|
||||
from ..openai.sysprompt import sysprompt as llm_prompt_mgr
|
||||
from ..openai.tools import toolmgr as llm_tool_mgr
|
||||
from ..qqbot import manager as im_mgr
|
||||
from ..qqbot.cmds import aamgr as im_cmd_aamgr
|
||||
from ..command import cmdmgr
|
||||
from ..plugin import host as plugin_host
|
||||
from ..utils.center import v2 as center_v2
|
||||
from ..utils import updater
|
||||
@@ -75,17 +79,14 @@ async def make_app() -> app.Application:
|
||||
if cfg_mgr.data['admin_qq'] == 0:
|
||||
qcg_logger.warning("未设置管理员QQ号,将无法使用管理员命令,请在 config.py 中修改 admin_qq")
|
||||
|
||||
# TODO make it async
|
||||
llm_dprompt.register_all()
|
||||
im_cmd_aamgr.register_all()
|
||||
im_cmd_aamgr.apply_privileges()
|
||||
|
||||
# 构建组建实例
|
||||
ap = app.Application()
|
||||
ap.logger = qcg_logger
|
||||
ap.cfg_mgr = cfg_mgr
|
||||
ap.tips_mgr = tips_mgr
|
||||
|
||||
ap.query_pool = pool.QueryPool()
|
||||
|
||||
center_v2_api = center_v2.V2CenterAPI(
|
||||
basic_info={
|
||||
"host_id": identifier.identifier['host_id'],
|
||||
@@ -105,22 +106,45 @@ async def make_app() -> app.Application:
|
||||
db_mgr_inst.initialize_database()
|
||||
ap.db_mgr = db_mgr_inst
|
||||
|
||||
llm_mgr_inst = llm_mgr.OpenAIInteract(ap)
|
||||
ap.llm_mgr = llm_mgr_inst
|
||||
# TODO make it async
|
||||
llm_session.load_sessions()
|
||||
cmd_mgr_inst = cmdmgr.CommandManager(ap)
|
||||
await cmd_mgr_inst.initialize()
|
||||
ap.cmd_mgr = cmd_mgr_inst
|
||||
|
||||
im_mgr_inst = im_mgr.QQBotManager(first_time_init=True, ap=ap)
|
||||
llm_model_mgr_inst = llm_model_mgr.ModelManager(ap)
|
||||
await llm_model_mgr_inst.initialize()
|
||||
ap.model_mgr = llm_model_mgr_inst
|
||||
|
||||
llm_session_mgr_inst = llm_session_mgr.SessionManager(ap)
|
||||
await llm_session_mgr_inst.initialize()
|
||||
ap.sess_mgr = llm_session_mgr_inst
|
||||
|
||||
llm_prompt_mgr_inst = llm_prompt_mgr.PromptManager(ap)
|
||||
await llm_prompt_mgr_inst.initialize()
|
||||
ap.prompt_mgr = llm_prompt_mgr_inst
|
||||
|
||||
llm_tool_mgr_inst = llm_tool_mgr.ToolManager(ap)
|
||||
await llm_tool_mgr_inst.initialize()
|
||||
ap.tool_mgr = llm_tool_mgr_inst
|
||||
|
||||
im_mgr_inst = im_mgr.QQBotManager(ap=ap)
|
||||
await im_mgr_inst.initialize()
|
||||
ap.im_mgr = im_mgr_inst
|
||||
|
||||
stage_mgr = stagemgr.StageManager(ap)
|
||||
await stage_mgr.initialize()
|
||||
ap.stage_mgr = stage_mgr
|
||||
|
||||
ctrl = controller.Controller(ap)
|
||||
ap.ctrl = ctrl
|
||||
|
||||
# TODO make it async
|
||||
plugin_host.load_plugins()
|
||||
# plugin_host.initialize_plugins()
|
||||
|
||||
await ap.initialize()
|
||||
|
||||
return ap
|
||||
|
||||
|
||||
async def main():
|
||||
app_inst = await make_app()
|
||||
await app_inst.initialize()
|
||||
await app_inst.run()
|
||||
@@ -1,7 +1,9 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
|
||||
from ..config import manager as config_mgr
|
||||
from ..config.impls import pymodule
|
||||
from ...config import manager as config_mgr
|
||||
from ...config.impls import pymodule
|
||||
|
||||
|
||||
load_python_module_config = config_mgr.load_python_module_config
|
||||
56
pkg/core/bootutils/log.py
Normal file
56
pkg/core/bootutils/log.py
Normal file
@@ -0,0 +1,56 @@
|
||||
import logging
|
||||
import os
|
||||
import sys
|
||||
import time
|
||||
|
||||
import colorlog
|
||||
|
||||
|
||||
log_colors_config = {
|
||||
"DEBUG": "green", # cyan white
|
||||
"INFO": "white",
|
||||
"WARNING": "yellow",
|
||||
"ERROR": "red",
|
||||
"CRITICAL": "cyan",
|
||||
}
|
||||
|
||||
|
||||
async def init_logging() -> logging.Logger:
|
||||
level = logging.INFO
|
||||
|
||||
if "DEBUG" in os.environ and os.environ["DEBUG"] in ["true", "1"]:
|
||||
level = logging.DEBUG
|
||||
|
||||
log_file_name = "logs/qcg-%s.log" % time.strftime(
|
||||
"%Y-%m-%d-%H-%M-%S", time.localtime()
|
||||
)
|
||||
|
||||
qcg_logger = logging.getLogger("qcg")
|
||||
|
||||
qcg_logger.setLevel(level)
|
||||
|
||||
color_formatter = colorlog.ColoredFormatter(
|
||||
fmt="%(log_color)s[%(asctime)s.%(msecs)03d] %(pathname)s (%(lineno)d) - [%(levelname)s] :\n %(message)s",
|
||||
datefmt="%Y-%m-%d %H:%M:%S",
|
||||
log_colors=log_colors_config,
|
||||
)
|
||||
|
||||
stream_handler = logging.StreamHandler(sys.stdout)
|
||||
|
||||
log_handlers: logging.Handler = [stream_handler, logging.FileHandler(log_file_name)]
|
||||
|
||||
for handler in log_handlers:
|
||||
handler.setLevel(level)
|
||||
handler.setFormatter(color_formatter)
|
||||
qcg_logger.addHandler(handler)
|
||||
|
||||
logging.basicConfig(
|
||||
level=logging.INFO, # 设置日志输出格式
|
||||
format="[DEPR][%(asctime)s.%(msecs)03d] %(pathname)s (%(lineno)d) - [%(levelname)s] :\n%(message)s",
|
||||
# 日志输出的格式
|
||||
# -8表示占位符,让输出左对齐,输出长度都为8位
|
||||
datefmt="%Y-%m-%d %H:%M:%S", # 时间输出的格式
|
||||
handlers=[logging.NullHandler()],
|
||||
)
|
||||
|
||||
return qcg_logger
|
||||
154
pkg/core/controller.py
Normal file
154
pkg/core/controller.py
Normal file
@@ -0,0 +1,154 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import typing
|
||||
import traceback
|
||||
|
||||
from . import app, entities
|
||||
from ..pipeline import entities as pipeline_entities
|
||||
|
||||
DEFAULT_QUERY_CONCURRENCY = 10
|
||||
|
||||
|
||||
class Controller:
|
||||
"""总控制器
|
||||
"""
|
||||
ap: app.Application
|
||||
|
||||
semaphore: asyncio.Semaphore = None
|
||||
"""请求并发控制信号量"""
|
||||
|
||||
def __init__(self, ap: app.Application):
|
||||
self.ap = ap
|
||||
self.semaphore = asyncio.Semaphore(DEFAULT_QUERY_CONCURRENCY)
|
||||
|
||||
async def consumer(self):
|
||||
"""事件处理循环
|
||||
"""
|
||||
try:
|
||||
while True:
|
||||
selected_query: entities.Query = None
|
||||
|
||||
# 取请求
|
||||
async with self.ap.query_pool:
|
||||
queries: list[entities.Query] = self.ap.query_pool.queries
|
||||
|
||||
for query in queries:
|
||||
session = await self.ap.sess_mgr.get_session(query)
|
||||
self.ap.logger.debug(f"Checking query {query} session {session}")
|
||||
|
||||
if not session.semaphore.locked():
|
||||
selected_query = query
|
||||
await session.semaphore.acquire()
|
||||
|
||||
break
|
||||
|
||||
if selected_query: # 找到了
|
||||
queries.remove(selected_query)
|
||||
else: # 没找到 说明:没有请求 或者 所有query对应的session都已达到并发上限
|
||||
await self.ap.query_pool.condition.wait()
|
||||
continue
|
||||
|
||||
if selected_query:
|
||||
async def _process_query(selected_query):
|
||||
async with self.semaphore: # 总并发上限
|
||||
await self.process_query(selected_query)
|
||||
|
||||
async with self.ap.query_pool:
|
||||
(await self.ap.sess_mgr.get_session(selected_query)).semaphore.release()
|
||||
# 通知其他协程,有新的请求可以处理了
|
||||
self.ap.query_pool.condition.notify_all()
|
||||
|
||||
asyncio.create_task(_process_query(selected_query))
|
||||
except Exception as e:
|
||||
self.ap.logger.error(f"事件处理循环出错: {e}")
|
||||
traceback.print_exc()
|
||||
|
||||
async def _check_output(self, result: pipeline_entities.StageProcessResult):
|
||||
"""检查输出
|
||||
"""
|
||||
if result.user_notice:
|
||||
await self.ap.im_mgr.send(
|
||||
result.user_notice
|
||||
)
|
||||
if result.debug_notice:
|
||||
self.ap.logger.debug(result.debug_notice)
|
||||
if result.console_notice:
|
||||
self.ap.logger.info(result.console_notice)
|
||||
|
||||
async def _execute_from_stage(
|
||||
self,
|
||||
stage_index: int,
|
||||
query: entities.Query,
|
||||
):
|
||||
"""从指定阶段开始执行
|
||||
|
||||
如何看懂这里为什么这么写?
|
||||
去问 GPT-4:
|
||||
Q1: 现在有一个责任链,其中有多个stage,query对象在其中传递,stage.process可能返回Result也有可能返回typing.AsyncGenerator[Result, None],
|
||||
如果返回的是生成器,需要挨个生成result,检查是否result中是否要求继续,如果要求继续就进行下一个stage。如果此次生成器产生的result处理完了,就继续生成下一个result,
|
||||
调用后续的stage,直到该生成器全部生成完。责任链中可能有多个stage会返回生成器
|
||||
Q2: 不是这样的,你可能理解有误。如果我们责任链上有这些Stage:
|
||||
|
||||
A B C D E F G
|
||||
|
||||
如果所有的stage都返回Result,且所有Result都要求继续,那么执行顺序是:
|
||||
|
||||
A B C D E F G
|
||||
|
||||
现在假设C返回的是AsyncGenerator,那么执行顺序是:
|
||||
|
||||
A B C D E F G C D E F G C D E F G ...
|
||||
Q3: 但是如果不止一个stage会返回生成器呢?
|
||||
"""
|
||||
i = stage_index
|
||||
|
||||
while i < len(self.ap.stage_mgr.stage_containers):
|
||||
stage_container = self.ap.stage_mgr.stage_containers[i]
|
||||
|
||||
result = await stage_container.inst.process(query, stage_container.inst_name)
|
||||
|
||||
|
||||
if isinstance(result, pipeline_entities.StageProcessResult): # 直接返回结果
|
||||
self.ap.logger.debug(f"Stage {stage_container.inst_name} processed query {query} res {result}")
|
||||
await self._check_output(result)
|
||||
|
||||
if result.result_type == pipeline_entities.ResultType.INTERRUPT:
|
||||
self.ap.logger.debug(f"Stage {stage_container.inst_name} interrupted query {query}")
|
||||
break
|
||||
elif result.result_type == pipeline_entities.ResultType.CONTINUE:
|
||||
query = result.new_query
|
||||
elif isinstance(result, typing.AsyncGenerator): # 生成器
|
||||
self.ap.logger.debug(f"Stage {stage_container.inst_name} processed query {query} gen")
|
||||
|
||||
async for sub_result in result:
|
||||
self.ap.logger.debug(f"Stage {stage_container.inst_name} processed query {query} res {sub_result}")
|
||||
await self._check_output(sub_result)
|
||||
|
||||
if sub_result.result_type == pipeline_entities.ResultType.INTERRUPT:
|
||||
self.ap.logger.debug(f"Stage {stage_container.inst_name} interrupted query {query}")
|
||||
break
|
||||
elif sub_result.result_type == pipeline_entities.ResultType.CONTINUE:
|
||||
query = sub_result.new_query
|
||||
await self._execute_from_stage(i + 1, query)
|
||||
break
|
||||
|
||||
i += 1
|
||||
|
||||
async def process_query(self, query: entities.Query):
|
||||
"""处理请求
|
||||
"""
|
||||
self.ap.logger.debug(f"Processing query {query}")
|
||||
|
||||
try:
|
||||
await self._execute_from_stage(0, query)
|
||||
except Exception as e:
|
||||
self.ap.logger.error(f"处理请求时出错 {query}: {e}")
|
||||
traceback.print_exc()
|
||||
finally:
|
||||
self.ap.logger.debug(f"Query {query} processed")
|
||||
|
||||
async def run(self):
|
||||
"""运行控制器
|
||||
"""
|
||||
await self.consumer()
|
||||
41
pkg/core/entities.py
Normal file
41
pkg/core/entities.py
Normal file
@@ -0,0 +1,41 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import enum
|
||||
import typing
|
||||
|
||||
import pydantic
|
||||
import mirai
|
||||
|
||||
|
||||
class LauncherTypes(enum.Enum):
|
||||
|
||||
PERSON = 'person'
|
||||
"""私聊"""
|
||||
|
||||
GROUP = 'group'
|
||||
"""群聊"""
|
||||
|
||||
|
||||
class Query(pydantic.BaseModel):
|
||||
"""一次请求的信息封装"""
|
||||
|
||||
query_id: int
|
||||
"""请求ID"""
|
||||
|
||||
launcher_type: LauncherTypes
|
||||
"""会话类型"""
|
||||
|
||||
launcher_id: int
|
||||
"""会话ID"""
|
||||
|
||||
sender_id: int
|
||||
"""发送者ID"""
|
||||
|
||||
message_event: mirai.MessageEvent
|
||||
"""事件"""
|
||||
|
||||
message_chain: mirai.MessageChain
|
||||
"""消息链"""
|
||||
|
||||
resp_message_chain: typing.Optional[mirai.MessageChain] = None
|
||||
"""回复消息链"""
|
||||
52
pkg/core/pool.py
Normal file
52
pkg/core/pool.py
Normal file
@@ -0,0 +1,52 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
|
||||
import mirai
|
||||
|
||||
from . import entities
|
||||
|
||||
|
||||
class QueryPool:
|
||||
|
||||
query_id_counter: int = 0
|
||||
|
||||
pool_lock: asyncio.Lock
|
||||
|
||||
queries: list[entities.Query]
|
||||
|
||||
condition: asyncio.Condition
|
||||
|
||||
def __init__(self):
|
||||
self.query_id_counter = 0
|
||||
self.pool_lock = asyncio.Lock()
|
||||
self.queries = []
|
||||
self.condition = asyncio.Condition(self.pool_lock)
|
||||
|
||||
async def add_query(
|
||||
self,
|
||||
launcher_type: entities.LauncherTypes,
|
||||
launcher_id: int,
|
||||
sender_id: int,
|
||||
message_event: mirai.MessageEvent,
|
||||
message_chain: mirai.MessageChain
|
||||
) -> entities.Query:
|
||||
async with self.condition:
|
||||
query = entities.Query(
|
||||
query_id=self.query_id_counter,
|
||||
launcher_type=launcher_type,
|
||||
launcher_id=launcher_id,
|
||||
sender_id=sender_id,
|
||||
message_event=message_event,
|
||||
message_chain=message_chain
|
||||
)
|
||||
self.queries.append(query)
|
||||
self.query_id_counter += 1
|
||||
self.condition.notify_all()
|
||||
|
||||
async def __aenter__(self):
|
||||
await self.pool_lock.acquire()
|
||||
return self
|
||||
|
||||
async def __aexit__(self, exc_type, exc_val, exc_tb):
|
||||
self.pool_lock.release()
|
||||
@@ -1,134 +0,0 @@
|
||||
# 多情景预设值管理
|
||||
import json
|
||||
import logging
|
||||
import os
|
||||
|
||||
from ..utils import context
|
||||
|
||||
# __current__ = "default"
|
||||
# """当前默认使用的情景预设的名称
|
||||
|
||||
# 由管理员使用`!default <名称>`命令切换
|
||||
# """
|
||||
|
||||
# __prompts_from_files__ = {}
|
||||
# """从文件中读取的情景预设值"""
|
||||
|
||||
# __scenario_from_files__ = {}
|
||||
|
||||
|
||||
class ScenarioMode:
|
||||
"""情景预设模式抽象类"""
|
||||
|
||||
using_prompt_name = "default"
|
||||
"""新session创建时使用的prompt名称"""
|
||||
|
||||
prompts: dict[str, list] = {}
|
||||
|
||||
def __init__(self):
|
||||
logging.debug("prompts: {}".format(self.prompts))
|
||||
|
||||
def list(self) -> dict[str, list]:
|
||||
"""获取所有情景预设的名称及内容"""
|
||||
return self.prompts
|
||||
|
||||
def get_prompt(self, name: str) -> tuple[list, str]:
|
||||
"""获取指定情景预设的名称及内容"""
|
||||
for key in self.prompts:
|
||||
if key.startswith(name):
|
||||
return self.prompts[key], key
|
||||
raise Exception("没有找到情景预设: {}".format(name))
|
||||
|
||||
def set_using_name(self, name: str) -> str:
|
||||
"""设置默认情景预设"""
|
||||
for key in self.prompts:
|
||||
if key.startswith(name):
|
||||
self.using_prompt_name = key
|
||||
return key
|
||||
raise Exception("没有找到情景预设: {}".format(name))
|
||||
|
||||
def get_full_name(self, name: str) -> str:
|
||||
"""获取完整的情景预设名称"""
|
||||
for key in self.prompts:
|
||||
if key.startswith(name):
|
||||
return key
|
||||
raise Exception("没有找到情景预设: {}".format(name))
|
||||
|
||||
def get_using_name(self) -> str:
|
||||
"""获取默认情景预设"""
|
||||
return self.using_prompt_name
|
||||
|
||||
|
||||
class NormalScenarioMode(ScenarioMode):
|
||||
"""普通情景预设模式"""
|
||||
|
||||
def __init__(self):
|
||||
config = context.get_config_manager().data
|
||||
|
||||
# 加载config中的default_prompt值
|
||||
if type(config['default_prompt']) == str:
|
||||
self.using_prompt_name = "default"
|
||||
self.prompts = {"default": [
|
||||
{
|
||||
"role": "system",
|
||||
"content": config['default_prompt']
|
||||
}
|
||||
]}
|
||||
|
||||
elif type(config['default_prompt']) == dict:
|
||||
for key in config['default_prompt']:
|
||||
self.prompts[key] = [
|
||||
{
|
||||
"role": "system",
|
||||
"content": config['default_prompt'][key]
|
||||
}
|
||||
]
|
||||
|
||||
# 从prompts/目录下的文件中载入
|
||||
# 遍历文件
|
||||
for file in os.listdir("prompts"):
|
||||
with open(os.path.join("prompts", file), encoding="utf-8") as f:
|
||||
self.prompts[file] = [
|
||||
{
|
||||
"role": "system",
|
||||
"content": f.read()
|
||||
}
|
||||
]
|
||||
|
||||
|
||||
class FullScenarioMode(ScenarioMode):
|
||||
"""完整情景预设模式"""
|
||||
|
||||
def __init__(self):
|
||||
"""从json读取所有"""
|
||||
# 遍历scenario/目录下的所有文件,以文件名为键,文件内容中的prompt为值
|
||||
for file in os.listdir("scenario"):
|
||||
if file == "default-template.json":
|
||||
continue
|
||||
with open(os.path.join("scenario", file), encoding="utf-8") as f:
|
||||
self.prompts[file] = json.load(f)["prompt"]
|
||||
|
||||
super().__init__()
|
||||
|
||||
|
||||
scenario_mode_mapping = {}
|
||||
"""情景预设模式名称与对象的映射"""
|
||||
|
||||
|
||||
def register_all():
|
||||
"""注册所有情景预设模式,不使用装饰器,因为装饰器的方式不支持热重载"""
|
||||
global scenario_mode_mapping
|
||||
scenario_mode_mapping = {
|
||||
"normal": NormalScenarioMode(),
|
||||
"full_scenario": FullScenarioMode()
|
||||
}
|
||||
|
||||
|
||||
def mode_inst() -> ScenarioMode:
|
||||
"""获取指定名称的情景预设模式对象"""
|
||||
config = context.get_config_manager().data
|
||||
|
||||
if config['preset_mode'] == "default":
|
||||
config['preset_mode'] = "normal"
|
||||
|
||||
return scenario_mode_mapping[config['preset_mode']]
|
||||
33
pkg/openai/entities.py
Normal file
33
pkg/openai/entities.py
Normal file
@@ -0,0 +1,33 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import typing
|
||||
import enum
|
||||
import pydantic
|
||||
|
||||
|
||||
class FunctionCall(pydantic.BaseModel):
|
||||
name: str
|
||||
|
||||
arguments: str
|
||||
|
||||
|
||||
class ToolCall(pydantic.BaseModel):
|
||||
id: str
|
||||
|
||||
type: str
|
||||
|
||||
function: FunctionCall
|
||||
|
||||
|
||||
class Message(pydantic.BaseModel):
|
||||
role: str
|
||||
|
||||
name: typing.Optional[str] = None
|
||||
|
||||
content: typing.Optional[str] = None
|
||||
|
||||
function_call: typing.Optional[FunctionCall] = None
|
||||
|
||||
tool_calls: typing.Optional[list[ToolCall]] = None
|
||||
|
||||
tool_call_id: typing.Optional[str] = None
|
||||
@@ -1,46 +0,0 @@
|
||||
# 封装了function calling的一些支持函数
|
||||
import logging
|
||||
|
||||
from ..plugin import host
|
||||
|
||||
|
||||
class ContentFunctionNotFoundError(Exception):
|
||||
pass
|
||||
|
||||
|
||||
def get_func_schema_list() -> list:
|
||||
"""从plugin包中的函数结构中获取并处理成受GPT支持的格式"""
|
||||
if not host.__enable_content_functions__:
|
||||
return []
|
||||
|
||||
schemas = []
|
||||
|
||||
for func in host.__callable_functions__:
|
||||
if func['enabled']:
|
||||
fun_cp = func.copy()
|
||||
|
||||
del fun_cp['enabled']
|
||||
|
||||
schemas.append(fun_cp)
|
||||
|
||||
return schemas
|
||||
|
||||
def get_func(name: str) -> callable:
|
||||
if name not in host.__function_inst_map__:
|
||||
raise ContentFunctionNotFoundError("没有找到内容函数: {}".format(name))
|
||||
|
||||
return host.__function_inst_map__[name]
|
||||
|
||||
def get_func_schema(name: str) -> dict:
|
||||
for func in host.__callable_functions__:
|
||||
if func['name'] == name:
|
||||
return func
|
||||
raise ContentFunctionNotFoundError("没有找到内容函数: {}".format(name))
|
||||
|
||||
def execute_function(name: str, kwargs: dict) -> any:
|
||||
"""执行函数调用"""
|
||||
|
||||
logging.debug("executing function: name='{}', kwargs={}".format(name, kwargs))
|
||||
|
||||
func = get_func(name)
|
||||
return func(**kwargs)
|
||||
@@ -1,103 +0,0 @@
|
||||
# 此模块提供了维护api-key的各种功能
|
||||
import hashlib
|
||||
import logging
|
||||
|
||||
from ..plugin import host as plugin_host
|
||||
from ..plugin import models as plugin_models
|
||||
|
||||
|
||||
class KeysManager:
|
||||
api_key = {}
|
||||
"""所有api-key"""
|
||||
|
||||
using_key = ""
|
||||
"""当前使用的api-key"""
|
||||
|
||||
alerted = []
|
||||
"""已提示过超额的key
|
||||
|
||||
记录在此以避免重复提示
|
||||
"""
|
||||
|
||||
exceeded = []
|
||||
"""已超额的key
|
||||
|
||||
供自动切换功能识别
|
||||
"""
|
||||
|
||||
def get_using_key(self):
|
||||
return self.using_key
|
||||
|
||||
def get_using_key_md5(self):
|
||||
return hashlib.md5(self.using_key.encode('utf-8')).hexdigest()
|
||||
|
||||
def __init__(self, api_key):
|
||||
|
||||
assert type(api_key) == dict
|
||||
self.api_key = api_key
|
||||
# 从usage中删除未加载的api-key的记录
|
||||
# 不删了,也许会运行时添加曾经有记录的api-key
|
||||
|
||||
self.auto_switch()
|
||||
|
||||
def auto_switch(self) -> tuple[bool, str]:
|
||||
"""尝试切换api-key
|
||||
|
||||
Returns:
|
||||
是否切换成功, 切换后的api-key的别名
|
||||
"""
|
||||
|
||||
index = 0
|
||||
|
||||
for key_name in self.api_key:
|
||||
if self.api_key[key_name] == self.using_key:
|
||||
break
|
||||
|
||||
index += 1
|
||||
|
||||
# 从当前key开始向后轮询
|
||||
start_index = index
|
||||
index += 1
|
||||
if index >= len(self.api_key):
|
||||
index = 0
|
||||
|
||||
while index != start_index:
|
||||
|
||||
key_name = list(self.api_key.keys())[index]
|
||||
|
||||
if self.api_key[key_name] not in self.exceeded:
|
||||
self.using_key = self.api_key[key_name]
|
||||
|
||||
logging.debug("使用api-key:" + key_name)
|
||||
|
||||
# 触发插件事件
|
||||
args = {
|
||||
"key_name": key_name,
|
||||
"key_list": self.api_key.keys()
|
||||
}
|
||||
_ = plugin_host.emit(plugin_models.KeySwitched, **args)
|
||||
|
||||
return True, key_name
|
||||
|
||||
index += 1
|
||||
if index >= len(self.api_key):
|
||||
index = 0
|
||||
|
||||
self.using_key = list(self.api_key.values())[start_index]
|
||||
logging.debug("使用api-key:" + list(self.api_key.keys())[start_index])
|
||||
|
||||
return False, list(self.api_key.keys())[start_index]
|
||||
|
||||
def add(self, key_name, key):
|
||||
self.api_key[key_name] = key
|
||||
|
||||
def set_current_exceeded(self):
|
||||
"""设置当前使用的api-key使用量超限"""
|
||||
self.exceeded.append(self.using_key)
|
||||
|
||||
def get_key_name(self, api_key):
|
||||
"""根据api-key获取其别名"""
|
||||
for key_name in self.api_key:
|
||||
if self.api_key[key_name] == api_key:
|
||||
return key_name
|
||||
return ""
|
||||
@@ -1,108 +0,0 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
|
||||
import openai
|
||||
from openai.types import images_response
|
||||
|
||||
from ..openai import keymgr
|
||||
from ..utils import context
|
||||
from ..audit import gatherer
|
||||
from ..openai import modelmgr
|
||||
from ..openai.api import model as api_model
|
||||
from ..boot import app
|
||||
|
||||
|
||||
class OpenAIInteract:
|
||||
"""OpenAI 接口封装
|
||||
|
||||
将文字接口和图片接口封装供调用方使用
|
||||
"""
|
||||
|
||||
key_mgr: keymgr.KeysManager = None
|
||||
|
||||
audit_mgr: gatherer.DataGatherer = None
|
||||
|
||||
default_image_api_params = {
|
||||
"size": "256x256",
|
||||
}
|
||||
|
||||
client: openai.Client = None
|
||||
|
||||
def __init__(self, ap: app.Application):
|
||||
|
||||
cfg= ap.cfg_mgr.data
|
||||
api_key = cfg['openai_config']['api_key']
|
||||
|
||||
self.key_mgr = keymgr.KeysManager(api_key)
|
||||
self.audit_mgr = gatherer.DataGatherer()
|
||||
|
||||
# 配置OpenAI proxy
|
||||
openai.proxies = None # 先重置,因为重载后可能需要清除proxy
|
||||
if "http_proxy" in cfg['openai_config'] and cfg['openai_config']["http_proxy"] is not None:
|
||||
openai.proxies = {
|
||||
"http": cfg['openai_config']["http_proxy"],
|
||||
"https": cfg['openai_config']["http_proxy"]
|
||||
}
|
||||
|
||||
# 配置openai api_base
|
||||
if "reverse_proxy" in cfg['openai_config'] and cfg['openai_config']["reverse_proxy"] is not None:
|
||||
logging.debug("设置反向代理: "+cfg['openai_config']['reverse_proxy'])
|
||||
openai.base_url = cfg['openai_config']["reverse_proxy"]
|
||||
|
||||
|
||||
self.client = openai.Client(
|
||||
api_key=self.key_mgr.get_using_key(),
|
||||
base_url=openai.base_url
|
||||
)
|
||||
|
||||
context.set_openai_manager(self)
|
||||
|
||||
def request_completion(self, messages: list):
|
||||
"""请求补全接口回复=
|
||||
"""
|
||||
# 选择接口请求类
|
||||
config = context.get_config_manager().data
|
||||
|
||||
request: api_model.RequestBase
|
||||
|
||||
model: str = config['completion_api_params']['model']
|
||||
|
||||
cp_parmas = config['completion_api_params'].copy()
|
||||
del cp_parmas['model']
|
||||
|
||||
request = modelmgr.select_request_cls(self.client, model, messages, cp_parmas)
|
||||
|
||||
# 请求接口
|
||||
for resp in request:
|
||||
|
||||
if resp['usage']['total_tokens'] > 0:
|
||||
self.audit_mgr.report_text_model_usage(
|
||||
model,
|
||||
resp['usage']['total_tokens']
|
||||
)
|
||||
|
||||
yield resp
|
||||
|
||||
def request_image(self, prompt) -> images_response.ImagesResponse:
|
||||
"""请求图片接口回复
|
||||
|
||||
Parameters:
|
||||
prompt (str): 提示语
|
||||
|
||||
Returns:
|
||||
dict: 响应
|
||||
"""
|
||||
config = context.get_config_manager().data
|
||||
params = config['image_api_params']
|
||||
|
||||
response = self.client.images.generate(
|
||||
prompt=prompt,
|
||||
n=1,
|
||||
**params
|
||||
)
|
||||
|
||||
self.audit_mgr.report_image_model_usage(params['size'])
|
||||
|
||||
return response
|
||||
|
||||
31
pkg/openai/requester/api.py
Normal file
31
pkg/openai/requester/api.py
Normal file
@@ -0,0 +1,31 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import abc
|
||||
import typing
|
||||
|
||||
from ...core import app
|
||||
from ...core import entities as core_entities
|
||||
from .. import entities as llm_entities
|
||||
from ..session import entities as session_entities
|
||||
|
||||
class LLMAPIRequester(metaclass=abc.ABCMeta):
|
||||
"""LLM API请求器
|
||||
"""
|
||||
|
||||
ap: app.Application
|
||||
|
||||
def __init__(self, ap: app.Application):
|
||||
self.ap = ap
|
||||
|
||||
async def initialize(self):
|
||||
pass
|
||||
|
||||
@abc.abstractmethod
|
||||
async def request(
|
||||
self,
|
||||
query: core_entities.Query,
|
||||
conversation: session_entities.Conversation,
|
||||
) -> typing.AsyncGenerator[llm_entities.Message, None]:
|
||||
"""请求
|
||||
"""
|
||||
raise NotImplementedError
|
||||
140
pkg/openai/requester/apis/chatcmpl.py
Normal file
140
pkg/openai/requester/apis/chatcmpl.py
Normal file
@@ -0,0 +1,140 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import typing
|
||||
import json
|
||||
|
||||
import openai
|
||||
import openai.types.chat.chat_completion as chat_completion
|
||||
|
||||
from .. import api
|
||||
from ....core import entities as core_entities
|
||||
from ... import entities as llm_entities
|
||||
from ...session import entities as session_entities
|
||||
|
||||
|
||||
class OpenAIChatCompletion(api.LLMAPIRequester):
|
||||
client: openai.AsyncClient
|
||||
|
||||
async def initialize(self):
|
||||
self.client = openai.AsyncClient(
|
||||
api_key="",
|
||||
base_url=self.ap.cfg_mgr.data["openai_config"]["reverse_proxy"],
|
||||
timeout=self.ap.cfg_mgr.data["process_message_timeout"],
|
||||
)
|
||||
|
||||
async def _req(
|
||||
self,
|
||||
args: dict,
|
||||
) -> chat_completion.ChatCompletion:
|
||||
self.ap.logger.debug(f"req chat_completion with args {args}")
|
||||
return await self.client.chat.completions.create(**args)
|
||||
|
||||
async def _make_msg(
|
||||
self,
|
||||
chat_completion: chat_completion.ChatCompletion,
|
||||
) -> llm_entities.Message:
|
||||
chatcmpl_message = chat_completion.choices[0].message.dict()
|
||||
|
||||
message = llm_entities.Message(**chatcmpl_message)
|
||||
|
||||
return message
|
||||
|
||||
async def _closure(
|
||||
self,
|
||||
req_messages: list[dict],
|
||||
conversation: session_entities.Conversation,
|
||||
user_text: str = None,
|
||||
function_ret: str = None,
|
||||
) -> llm_entities.Message:
|
||||
self.client.api_key = conversation.use_model.token_mgr.get_token()
|
||||
|
||||
args = self.ap.cfg_mgr.data["completion_api_params"].copy()
|
||||
args["model"] = conversation.use_model.name
|
||||
|
||||
tools = await self.ap.tool_mgr.generate_tools_for_openai(conversation)
|
||||
# tools = [
|
||||
# {
|
||||
# "type": "function",
|
||||
# "function": {
|
||||
# "name": "get_current_weather",
|
||||
# "description": "Get the current weather in a given location",
|
||||
# "parameters": {
|
||||
# "type": "object",
|
||||
# "properties": {
|
||||
# "location": {
|
||||
# "type": "string",
|
||||
# "description": "The city and state, e.g. San Francisco, CA",
|
||||
# },
|
||||
# "unit": {
|
||||
# "type": "string",
|
||||
# "enum": ["celsius", "fahrenheit"],
|
||||
# },
|
||||
# },
|
||||
# "required": ["location"],
|
||||
# },
|
||||
# },
|
||||
# }
|
||||
# ]
|
||||
if tools:
|
||||
args["tools"] = tools
|
||||
|
||||
# 设置此次请求中的messages
|
||||
messages = req_messages
|
||||
args["messages"] = messages
|
||||
|
||||
# 发送请求
|
||||
resp = await self._req(args)
|
||||
|
||||
# 处理请求结果
|
||||
message = await self._make_msg(resp)
|
||||
|
||||
return message
|
||||
|
||||
async def request(
|
||||
self, query: core_entities.Query, conversation: session_entities.Conversation
|
||||
) -> typing.AsyncGenerator[llm_entities.Message, None]:
|
||||
"""请求"""
|
||||
|
||||
pending_tool_calls = []
|
||||
|
||||
req_messages = [
|
||||
m.dict(exclude_none=True) for m in conversation.prompt.messages
|
||||
] + [m.dict(exclude_none=True) for m in conversation.messages]
|
||||
|
||||
# req_messages.append({"role": "user", "content": str(query.message_chain)})
|
||||
|
||||
msg = await self._closure(req_messages, conversation)
|
||||
|
||||
yield msg
|
||||
|
||||
pending_tool_calls = msg.tool_calls
|
||||
|
||||
req_messages.append(msg.dict(exclude_none=True))
|
||||
|
||||
while pending_tool_calls:
|
||||
for tool_call in pending_tool_calls:
|
||||
func = tool_call.function
|
||||
|
||||
parameters = json.loads(func.arguments)
|
||||
|
||||
func_ret = await self.ap.tool_mgr.execute_func_call(
|
||||
query, func.name, parameters
|
||||
)
|
||||
|
||||
msg = llm_entities.Message(
|
||||
role="tool", content=json.dumps(func_ret, ensure_ascii=False), tool_call_id=tool_call.id
|
||||
)
|
||||
|
||||
yield msg
|
||||
|
||||
req_messages.append(msg.dict(exclude_none=True))
|
||||
|
||||
# 处理完所有调用,继续请求
|
||||
msg = await self._closure(req_messages, conversation)
|
||||
|
||||
yield msg
|
||||
|
||||
pending_tool_calls = msg.tool_calls
|
||||
|
||||
req_messages.append(msg.dict(exclude_none=True))
|
||||
23
pkg/openai/requester/entities.py
Normal file
23
pkg/openai/requester/entities.py
Normal file
@@ -0,0 +1,23 @@
|
||||
import typing
|
||||
|
||||
import pydantic
|
||||
|
||||
from . import api
|
||||
from . import token
|
||||
|
||||
|
||||
class LLMModelInfo(pydantic.BaseModel):
|
||||
"""模型"""
|
||||
|
||||
name: str
|
||||
|
||||
provider: str
|
||||
|
||||
token_mgr: token.TokenManager
|
||||
|
||||
requester: api.LLMAPIRequester
|
||||
|
||||
function_call_supported: typing.Optional[bool] = False
|
||||
|
||||
class Config:
|
||||
arbitrary_types_allowed = True
|
||||
41
pkg/openai/requester/modelmgr.py
Normal file
41
pkg/openai/requester/modelmgr.py
Normal file
@@ -0,0 +1,41 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from . import entities
|
||||
from ...core import app
|
||||
|
||||
from .apis import chatcmpl
|
||||
from . import token
|
||||
|
||||
|
||||
class ModelManager:
|
||||
|
||||
ap: app.Application
|
||||
|
||||
model_list: list[entities.LLMModelInfo]
|
||||
|
||||
def __init__(self, ap: app.Application):
|
||||
self.ap = ap
|
||||
self.model_list = []
|
||||
|
||||
async def initialize(self):
|
||||
openai_chat_completion = chatcmpl.OpenAIChatCompletion(self.ap)
|
||||
await openai_chat_completion.initialize()
|
||||
openai_token_mgr = token.TokenManager(self.ap, list(self.ap.cfg_mgr.data['openai_config']['api_key'].values()))
|
||||
|
||||
self.model_list.append(
|
||||
entities.LLMModelInfo(
|
||||
name="gpt-3.5-turbo",
|
||||
provider="openai",
|
||||
token_mgr=openai_token_mgr,
|
||||
requester=openai_chat_completion,
|
||||
function_call_supported=True
|
||||
)
|
||||
)
|
||||
|
||||
async def get_model_by_name(self, name: str) -> entities.LLMModelInfo:
|
||||
"""通过名称获取模型
|
||||
"""
|
||||
for model in self.model_list:
|
||||
if model.name == name:
|
||||
return model
|
||||
raise ValueError(f"Model {name} not found")
|
||||
25
pkg/openai/requester/token.py
Normal file
25
pkg/openai/requester/token.py
Normal file
@@ -0,0 +1,25 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import typing
|
||||
|
||||
import pydantic
|
||||
|
||||
|
||||
class TokenManager():
|
||||
|
||||
provider: str
|
||||
|
||||
tokens: list[str]
|
||||
|
||||
using_token_index: typing.Optional[int] = 0
|
||||
|
||||
def __init__(self, provider: str, tokens: list[str]):
|
||||
self.provider = provider
|
||||
self.tokens = tokens
|
||||
self.using_token_index = 0
|
||||
|
||||
def get_token(self) -> str:
|
||||
return self.tokens[self.using_token_index]
|
||||
|
||||
def next_token(self):
|
||||
self.using_token_index = (self.using_token_index + 1) % len(self.tokens)
|
||||
@@ -1,504 +0,0 @@
|
||||
"""主线使用的会话管理模块
|
||||
|
||||
每个人、每个群单独一个session,session内部保留了对话的上下文,
|
||||
"""
|
||||
|
||||
import logging
|
||||
import threading
|
||||
import time
|
||||
import json
|
||||
|
||||
from ..openai import manager as openai_manager
|
||||
from ..openai import modelmgr as openai_modelmgr
|
||||
from ..database import manager as database_manager
|
||||
from ..utils import context as context
|
||||
|
||||
from ..plugin import host as plugin_host
|
||||
from ..plugin import models as plugin_models
|
||||
|
||||
# 运行时保存的所有session
|
||||
sessions = {}
|
||||
|
||||
|
||||
class SessionOfflineStatus:
|
||||
ON_GOING = 'on_going'
|
||||
EXPLICITLY_CLOSED = 'explicitly_closed'
|
||||
|
||||
|
||||
# 从数据加载session
|
||||
def load_sessions():
|
||||
"""从数据库加载sessions"""
|
||||
|
||||
global sessions
|
||||
|
||||
db_inst = context.get_database_manager()
|
||||
|
||||
session_data = db_inst.load_valid_sessions()
|
||||
|
||||
for session_name in session_data:
|
||||
logging.debug('加载session: {}'.format(session_name))
|
||||
|
||||
temp_session = Session(session_name)
|
||||
temp_session.name = session_name
|
||||
temp_session.create_timestamp = session_data[session_name]['create_timestamp']
|
||||
temp_session.last_interact_timestamp = session_data[session_name]['last_interact_timestamp']
|
||||
|
||||
temp_session.prompt = json.loads(session_data[session_name]['prompt'])
|
||||
temp_session.token_counts = json.loads(session_data[session_name]['token_counts'])
|
||||
|
||||
temp_session.default_prompt = json.loads(session_data[session_name]['default_prompt']) if \
|
||||
session_data[session_name]['default_prompt'] else []
|
||||
|
||||
sessions[session_name] = temp_session
|
||||
|
||||
|
||||
# 获取指定名称的session,如果不存在则创建一个新的
|
||||
def get_session(session_name: str) -> 'Session':
|
||||
global sessions
|
||||
if session_name not in sessions:
|
||||
sessions[session_name] = Session(session_name)
|
||||
return sessions[session_name]
|
||||
|
||||
|
||||
def dump_session(session_name: str):
|
||||
global sessions
|
||||
if session_name in sessions:
|
||||
assert isinstance(sessions[session_name], Session)
|
||||
sessions[session_name].persistence()
|
||||
del sessions[session_name]
|
||||
|
||||
|
||||
# 通用的OpenAI API交互session
|
||||
# session内部保留了对话的上下文,
|
||||
# 收到用户消息后,将上下文提交给OpenAI API生成回复
|
||||
class Session:
|
||||
name = ''
|
||||
|
||||
prompt = []
|
||||
"""使用list来保存会话中的回合"""
|
||||
|
||||
default_prompt = []
|
||||
"""本session的默认prompt"""
|
||||
|
||||
create_timestamp = 0
|
||||
"""会话创建时间"""
|
||||
|
||||
last_interact_timestamp = 0
|
||||
"""上次交互(产生回复)时间"""
|
||||
|
||||
just_switched_to_exist_session = False
|
||||
|
||||
response_lock = None
|
||||
|
||||
# 加锁
|
||||
def acquire_response_lock(self):
|
||||
logging.debug('{},lock acquire,{}'.format(self.name, self.response_lock))
|
||||
self.response_lock.acquire()
|
||||
logging.debug('{},lock acquire successfully,{}'.format(self.name, self.response_lock))
|
||||
|
||||
# 释放锁
|
||||
def release_response_lock(self):
|
||||
if self.response_lock.locked():
|
||||
logging.debug('{},lock release,{}'.format(self.name, self.response_lock))
|
||||
self.response_lock.release()
|
||||
logging.debug('{},lock release successfully,{}'.format(self.name, self.response_lock))
|
||||
|
||||
# 从配置文件获取会话预设信息
|
||||
def get_default_prompt(self, use_default: str = None):
|
||||
import pkg.openai.dprompt as dprompt
|
||||
|
||||
if use_default is None:
|
||||
use_default = dprompt.mode_inst().get_using_name()
|
||||
|
||||
current_default_prompt, _ = dprompt.mode_inst().get_prompt(use_default)
|
||||
return current_default_prompt
|
||||
|
||||
def __init__(self, name: str):
|
||||
self.name = name
|
||||
self.create_timestamp = int(time.time())
|
||||
self.last_interact_timestamp = int(time.time())
|
||||
self.prompt = []
|
||||
self.token_counts = []
|
||||
self.schedule()
|
||||
|
||||
self.response_lock = threading.Lock()
|
||||
|
||||
self.default_prompt = self.get_default_prompt()
|
||||
logging.debug("prompt is: {}".format(self.default_prompt))
|
||||
|
||||
# 设定检查session最后一次对话是否超过过期时间的计时器
|
||||
def schedule(self):
|
||||
threading.Thread(target=self.expire_check_timer_loop, args=(self.create_timestamp,)).start()
|
||||
|
||||
# 检查session是否已经过期
|
||||
def expire_check_timer_loop(self, create_timestamp: int):
|
||||
global sessions
|
||||
while True:
|
||||
time.sleep(60)
|
||||
|
||||
# 不是此session已更换,退出
|
||||
if self.create_timestamp != create_timestamp or self not in sessions.values():
|
||||
return
|
||||
|
||||
config = context.get_config_manager().data
|
||||
if int(time.time()) - self.last_interact_timestamp > config['session_expire_time']:
|
||||
logging.info('session {} 已过期'.format(self.name))
|
||||
|
||||
# 触发插件事件
|
||||
args = {
|
||||
'session_name': self.name,
|
||||
'session': self,
|
||||
'session_expire_time': config['session_expire_time']
|
||||
}
|
||||
event = plugin_host.emit(plugin_models.SessionExpired, **args)
|
||||
if event.is_prevented_default():
|
||||
return
|
||||
|
||||
self.reset(expired=True, schedule_new=False)
|
||||
|
||||
# 删除此session
|
||||
del sessions[self.name]
|
||||
return
|
||||
|
||||
# 请求回复
|
||||
# 这个函数是阻塞的
|
||||
def query(self, text: str=None) -> tuple[str, str, list[str]]:
|
||||
"""向session中添加一条消息,返回接口回复
|
||||
|
||||
Args:
|
||||
text (str): 用户消息
|
||||
|
||||
Returns:
|
||||
tuple[str, str]: (接口回复, finish_reason, 已调用的函数列表)
|
||||
"""
|
||||
|
||||
self.last_interact_timestamp = int(time.time())
|
||||
|
||||
# 触发插件事件
|
||||
if not self.prompt:
|
||||
args = {
|
||||
'session_name': self.name,
|
||||
'session': self,
|
||||
'default_prompt': self.default_prompt,
|
||||
}
|
||||
|
||||
event = plugin_host.emit(plugin_models.SessionFirstMessageReceived, **args)
|
||||
if event.is_prevented_default():
|
||||
return None, None, None
|
||||
|
||||
config = context.get_config_manager().data
|
||||
max_length = config['prompt_submit_length']
|
||||
|
||||
local_default_prompt = self.default_prompt.copy()
|
||||
local_prompt = self.prompt.copy()
|
||||
|
||||
# 触发PromptPreProcessing事件
|
||||
args = {
|
||||
'session_name': self.name,
|
||||
'default_prompt': self.default_prompt,
|
||||
'prompt': self.prompt,
|
||||
'text_message': text,
|
||||
}
|
||||
|
||||
event = plugin_host.emit(plugin_models.PromptPreProcessing, **args)
|
||||
|
||||
if event.get_return_value('default_prompt') is not None:
|
||||
local_default_prompt = event.get_return_value('default_prompt')
|
||||
|
||||
if event.get_return_value('prompt') is not None:
|
||||
local_prompt = event.get_return_value('prompt')
|
||||
|
||||
if event.get_return_value('text_message') is not None:
|
||||
text = event.get_return_value('text_message')
|
||||
|
||||
# 裁剪messages到合适长度
|
||||
prompts, _ = self.cut_out(text, max_length, local_default_prompt, local_prompt)
|
||||
|
||||
res_text = ""
|
||||
|
||||
pending_msgs = []
|
||||
|
||||
total_tokens = 0
|
||||
|
||||
finish_reason: str = ""
|
||||
|
||||
funcs = []
|
||||
|
||||
trace_func_calls = config['trace_function_calls']
|
||||
botmgr = context.get_qqbot_manager()
|
||||
|
||||
session_name_spt: list[str] = self.name.split("_")
|
||||
|
||||
pending_res_text = ""
|
||||
|
||||
start_time = time.time()
|
||||
|
||||
# TODO 对不起,我知道这样非常非常屎山,但我之后会重构的
|
||||
for resp in context.get_openai_manager().request_completion(prompts):
|
||||
|
||||
if pending_res_text != "":
|
||||
botmgr.adapter.send_message(
|
||||
session_name_spt[0],
|
||||
session_name_spt[1],
|
||||
pending_res_text
|
||||
)
|
||||
pending_res_text = ""
|
||||
|
||||
finish_reason = resp['choices'][0]['finish_reason']
|
||||
|
||||
if resp['choices'][0]['message']['role'] == "assistant" and resp['choices'][0]['message']['content'] != None: # 包含纯文本响应
|
||||
|
||||
if not trace_func_calls:
|
||||
res_text += resp['choices'][0]['message']['content']
|
||||
else:
|
||||
res_text = resp['choices'][0]['message']['content']
|
||||
pending_res_text = resp['choices'][0]['message']['content']
|
||||
|
||||
total_tokens += resp['usage']['total_tokens']
|
||||
|
||||
msg = {
|
||||
"role": "assistant",
|
||||
"content": resp['choices'][0]['message']['content']
|
||||
}
|
||||
|
||||
if 'function_call' in resp['choices'][0]['message']:
|
||||
msg['function_call'] = json.dumps(resp['choices'][0]['message']['function_call'])
|
||||
|
||||
pending_msgs.append(msg)
|
||||
|
||||
if resp['choices'][0]['message']['type'] == 'function_call':
|
||||
# self.prompt.append(
|
||||
# {
|
||||
# "role": "assistant",
|
||||
# "content": "function call: "+json.dumps(resp['choices'][0]['message']['function_call'])
|
||||
# }
|
||||
# )
|
||||
if trace_func_calls:
|
||||
botmgr.adapter.send_message(
|
||||
session_name_spt[0],
|
||||
session_name_spt[1],
|
||||
"调用函数 "+resp['choices'][0]['message']['function_call']['name'] + "..."
|
||||
)
|
||||
|
||||
total_tokens += resp['usage']['total_tokens']
|
||||
elif resp['choices'][0]['message']['type'] == 'function_return':
|
||||
# self.prompt.append(
|
||||
# {
|
||||
# "role": "function",
|
||||
# "name": resp['choices'][0]['message']['function_name'],
|
||||
# "content": json.dumps(resp['choices'][0]['message']['content'])
|
||||
# }
|
||||
# )
|
||||
|
||||
# total_tokens += resp['usage']['total_tokens']
|
||||
funcs.append(
|
||||
resp['choices'][0]['message']['function_name']
|
||||
)
|
||||
pass
|
||||
|
||||
# 向API请求补全
|
||||
# message, total_token = pkg.utils.context.get_openai_manager().request_completion(
|
||||
# prompts,
|
||||
# )
|
||||
|
||||
# 成功获取,处理回复
|
||||
# res_test = message
|
||||
res_ans = res_text.strip()
|
||||
|
||||
# 将此次对话的双方内容加入到prompt中
|
||||
# self.prompt.append({'role': 'user', 'content': text})
|
||||
# self.prompt.append({'role': 'assistant', 'content': res_ans})
|
||||
if text:
|
||||
self.prompt.append({'role': 'user', 'content': text})
|
||||
# 添加pending_msgs
|
||||
self.prompt += pending_msgs
|
||||
|
||||
# 向token_counts中添加本回合的token数量
|
||||
# self.token_counts.append(total_tokens-total_token_before_query)
|
||||
# logging.debug("本回合使用token: {}, session counts: {}".format(total_tokens-total_token_before_query, self.token_counts))
|
||||
|
||||
if self.just_switched_to_exist_session:
|
||||
self.just_switched_to_exist_session = False
|
||||
self.set_ongoing()
|
||||
|
||||
# 上报使用量数据
|
||||
session_type = session_name_spt[0]
|
||||
session_id = session_name_spt[1]
|
||||
|
||||
ability_provider = "QChatGPT.Text"
|
||||
usage = total_tokens
|
||||
model_name = context.get_config_manager().data['completion_api_params']['model']
|
||||
response_seconds = int(time.time() - start_time)
|
||||
retry_times = -1 # 暂不记录
|
||||
|
||||
context.get_center_v2_api().usage.post_query_record(
|
||||
session_type=session_type,
|
||||
session_id=session_id,
|
||||
query_ability_provider=ability_provider,
|
||||
usage=usage,
|
||||
model_name=model_name,
|
||||
response_seconds=response_seconds,
|
||||
retry_times=retry_times
|
||||
)
|
||||
|
||||
return res_ans if res_ans[0] != '\n' else res_ans[1:], finish_reason, funcs
|
||||
|
||||
# 删除上一回合并返回上一回合的问题
|
||||
def undo(self) -> str:
|
||||
self.last_interact_timestamp = int(time.time())
|
||||
|
||||
# 删除最后两个消息
|
||||
if len(self.prompt) < 2:
|
||||
raise Exception('之前无对话,无法撤销')
|
||||
|
||||
question = self.prompt[-2]['content']
|
||||
self.prompt = self.prompt[:-2]
|
||||
self.token_counts = self.token_counts[:-1]
|
||||
|
||||
# 返回上一回合的问题
|
||||
return question
|
||||
|
||||
# 构建对话体
|
||||
def cut_out(self, msg: str, max_tokens: int, default_prompt: list, prompt: list) -> tuple[list, list]:
|
||||
"""将现有prompt进行切割处理,使得新的prompt长度不超过max_tokens
|
||||
|
||||
:return: (新的prompt, 新的token_counts)
|
||||
"""
|
||||
|
||||
# 最终由三个部分组成
|
||||
# - default_prompt 情景预设固定值
|
||||
# - changable_prompts 可变部分, 此会话中的历史对话回合
|
||||
# - current_question 当前问题
|
||||
|
||||
# 包装目前的对话回合内容
|
||||
changable_prompts = []
|
||||
|
||||
use_model = context.get_config_manager().data['completion_api_params']['model']
|
||||
|
||||
ptr = len(prompt) - 1
|
||||
|
||||
# 直接从后向前扫描拼接,不管是否是整回合
|
||||
while ptr >= 0:
|
||||
if openai_modelmgr.count_tokens(prompt[ptr:ptr+1]+changable_prompts, use_model) > max_tokens:
|
||||
break
|
||||
|
||||
changable_prompts.insert(0, prompt[ptr])
|
||||
|
||||
ptr -= 1
|
||||
|
||||
# 将default_prompt和changable_prompts合并
|
||||
result_prompt = default_prompt + changable_prompts
|
||||
|
||||
# 添加当前问题
|
||||
if msg:
|
||||
result_prompt.append(
|
||||
{
|
||||
'role': 'user',
|
||||
'content': msg
|
||||
}
|
||||
)
|
||||
|
||||
logging.debug("cut_out: {}".format(json.dumps(result_prompt, ensure_ascii=False, indent=4)))
|
||||
|
||||
return result_prompt, openai_modelmgr.count_tokens(changable_prompts, use_model)
|
||||
|
||||
# 持久化session
|
||||
def persistence(self):
|
||||
if self.prompt == self.get_default_prompt():
|
||||
return
|
||||
|
||||
db_inst = context.get_database_manager()
|
||||
|
||||
name_spt = self.name.split('_')
|
||||
|
||||
subject_type = name_spt[0]
|
||||
subject_number = int(name_spt[1])
|
||||
|
||||
db_inst.persistence_session(subject_type, subject_number, self.create_timestamp, self.last_interact_timestamp,
|
||||
json.dumps(self.prompt), json.dumps(self.default_prompt), json.dumps(self.token_counts))
|
||||
|
||||
# 重置session
|
||||
def reset(self, explicit: bool = False, expired: bool = False, schedule_new: bool = True, use_prompt: str = None, persist: bool = False):
|
||||
if self.prompt:
|
||||
self.persistence()
|
||||
if explicit:
|
||||
# 触发插件事件
|
||||
args = {
|
||||
'session_name': self.name,
|
||||
'session': self
|
||||
}
|
||||
|
||||
# 此事件不支持阻止默认行为
|
||||
_ = plugin_host.emit(plugin_models.SessionExplicitReset, **args)
|
||||
|
||||
context.get_database_manager().explicit_close_session(self.name, self.create_timestamp)
|
||||
|
||||
if expired:
|
||||
context.get_database_manager().set_session_expired(self.name, self.create_timestamp)
|
||||
|
||||
if not persist: # 不要求保持default prompt
|
||||
self.default_prompt = self.get_default_prompt(use_prompt)
|
||||
self.prompt = []
|
||||
self.token_counts = []
|
||||
self.create_timestamp = int(time.time())
|
||||
self.last_interact_timestamp = int(time.time())
|
||||
self.just_switched_to_exist_session = False
|
||||
|
||||
# self.response_lock = threading.Lock()
|
||||
|
||||
if schedule_new:
|
||||
self.schedule()
|
||||
|
||||
# 将本session的数据库状态设置为on_going
|
||||
def set_ongoing(self):
|
||||
context.get_database_manager().set_session_ongoing(self.name, self.create_timestamp)
|
||||
|
||||
# 切换到上一个session
|
||||
def last_session(self):
|
||||
last_one = context.get_database_manager().last_session(self.name, self.last_interact_timestamp)
|
||||
if last_one is None:
|
||||
return None
|
||||
else:
|
||||
self.persistence()
|
||||
|
||||
self.create_timestamp = last_one['create_timestamp']
|
||||
self.last_interact_timestamp = last_one['last_interact_timestamp']
|
||||
|
||||
self.prompt = json.loads(last_one['prompt'])
|
||||
self.token_counts = json.loads(last_one['token_counts'])
|
||||
|
||||
self.default_prompt = json.loads(last_one['default_prompt']) if last_one['default_prompt'] else []
|
||||
|
||||
self.just_switched_to_exist_session = True
|
||||
return self
|
||||
|
||||
# 切换到下一个session
|
||||
def next_session(self):
|
||||
next_one = context.get_database_manager().next_session(self.name, self.last_interact_timestamp)
|
||||
if next_one is None:
|
||||
return None
|
||||
else:
|
||||
self.persistence()
|
||||
|
||||
self.create_timestamp = next_one['create_timestamp']
|
||||
self.last_interact_timestamp = next_one['last_interact_timestamp']
|
||||
|
||||
self.prompt = json.loads(next_one['prompt'])
|
||||
self.token_counts = json.loads(next_one['token_counts'])
|
||||
|
||||
self.default_prompt = json.loads(next_one['default_prompt']) if next_one['default_prompt'] else []
|
||||
|
||||
self.just_switched_to_exist_session = True
|
||||
return self
|
||||
|
||||
def list_history(self, capacity: int = 10, page: int = 0):
|
||||
return context.get_database_manager().list_history(self.name, capacity, page)
|
||||
|
||||
def delete_history(self, index: int) -> bool:
|
||||
return context.get_database_manager().delete_history(self.name, index)
|
||||
|
||||
def delete_all_history(self) -> bool:
|
||||
return context.get_database_manager().delete_all_history(self.name)
|
||||
|
||||
def draw_image(self, prompt: str):
|
||||
return context.get_openai_manager().request_image(prompt)
|
||||
53
pkg/openai/session/entities.py
Normal file
53
pkg/openai/session/entities.py
Normal file
@@ -0,0 +1,53 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import datetime
|
||||
import asyncio
|
||||
import typing
|
||||
|
||||
import pydantic
|
||||
|
||||
from ..sysprompt import entities as sysprompt_entities
|
||||
from .. import entities as llm_entities
|
||||
from ..requester import entities
|
||||
from ...core import entities as core_entities
|
||||
from ..tools import entities as tools_entities
|
||||
|
||||
|
||||
class Conversation(pydantic.BaseModel):
|
||||
"""对话"""
|
||||
|
||||
prompt: sysprompt_entities.Prompt
|
||||
|
||||
messages: list[llm_entities.Message]
|
||||
|
||||
create_time: typing.Optional[datetime.datetime] = pydantic.Field(default_factory=datetime.datetime.now)
|
||||
|
||||
update_time: typing.Optional[datetime.datetime] = pydantic.Field(default_factory=datetime.datetime.now)
|
||||
|
||||
use_model: entities.LLMModelInfo
|
||||
|
||||
use_funcs: typing.Optional[list[tools_entities.LLMFunction]]
|
||||
|
||||
|
||||
class Session(pydantic.BaseModel):
|
||||
"""会话"""
|
||||
launcher_type: core_entities.LauncherTypes
|
||||
|
||||
launcher_id: int
|
||||
|
||||
sender_id: typing.Optional[int] = 0
|
||||
|
||||
use_prompt_name: typing.Optional[str] = 'default'
|
||||
|
||||
using_conversation: typing.Optional[Conversation] = None
|
||||
|
||||
conversations: typing.Optional[list[Conversation]] = []
|
||||
|
||||
create_time: typing.Optional[datetime.datetime] = pydantic.Field(default_factory=datetime.datetime.now)
|
||||
|
||||
update_time: typing.Optional[datetime.datetime] = pydantic.Field(default_factory=datetime.datetime.now)
|
||||
|
||||
semaphore: typing.Optional[asyncio.Semaphore] = None
|
||||
|
||||
class Config:
|
||||
arbitrary_types_allowed = True
|
||||
51
pkg/openai/session/sessionmgr.py
Normal file
51
pkg/openai/session/sessionmgr.py
Normal file
@@ -0,0 +1,51 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
|
||||
from ...core import app, entities as core_entities
|
||||
from . import entities
|
||||
|
||||
|
||||
class SessionManager:
|
||||
|
||||
ap: app.Application
|
||||
|
||||
session_list: list[entities.Session]
|
||||
|
||||
def __init__(self, ap: app.Application):
|
||||
self.ap = ap
|
||||
self.session_list = []
|
||||
|
||||
async def initialize(self):
|
||||
pass
|
||||
|
||||
async def get_session(self, query: core_entities.Query) -> entities.Session:
|
||||
"""获取会话
|
||||
"""
|
||||
for session in self.session_list:
|
||||
if query.launcher_type == session.launcher_type and query.launcher_id == session.launcher_id:
|
||||
return session
|
||||
|
||||
session = entities.Session(
|
||||
launcher_type=query.launcher_type,
|
||||
launcher_id=query.launcher_id,
|
||||
semaphore=asyncio.Semaphore(1) if self.ap.cfg_mgr.data['wait_last_done'] else asyncio.Semaphore(10000),
|
||||
)
|
||||
self.session_list.append(session)
|
||||
return session
|
||||
|
||||
async def get_conversation(self, session: entities.Session) -> entities.Conversation:
|
||||
if not session.conversations:
|
||||
session.conversations = []
|
||||
|
||||
if session.using_conversation is None:
|
||||
conversation = entities.Conversation(
|
||||
prompt=await self.ap.prompt_mgr.get_prompt(session.use_prompt_name),
|
||||
messages=[],
|
||||
use_model=await self.ap.model_mgr.get_model_by_name(self.ap.cfg_mgr.data['completion_api_params']['model']),
|
||||
use_funcs=await self.ap.tool_mgr.get_all_functions(),
|
||||
)
|
||||
session.conversations.append(conversation)
|
||||
session.using_conversation = conversation
|
||||
|
||||
return session.using_conversation
|
||||
14
pkg/openai/sysprompt/entities.py
Normal file
14
pkg/openai/sysprompt/entities.py
Normal file
@@ -0,0 +1,14 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import typing
|
||||
import pydantic
|
||||
|
||||
from ...openai import entities
|
||||
|
||||
|
||||
class Prompt(pydantic.BaseModel):
|
||||
"""供AI使用的Prompt"""
|
||||
|
||||
name: str
|
||||
|
||||
messages: list[entities.Message]
|
||||
32
pkg/openai/sysprompt/loader.py
Normal file
32
pkg/openai/sysprompt/loader.py
Normal file
@@ -0,0 +1,32 @@
|
||||
from __future__ import annotations
|
||||
import abc
|
||||
|
||||
from ...core import app
|
||||
from . import entities
|
||||
|
||||
|
||||
class PromptLoader(metaclass=abc.ABCMeta):
|
||||
"""Prompt加载器抽象类
|
||||
"""
|
||||
|
||||
ap: app.Application
|
||||
|
||||
prompts: list[entities.Prompt]
|
||||
|
||||
def __init__(self, ap: app.Application):
|
||||
self.ap = ap
|
||||
self.prompts = []
|
||||
|
||||
async def initialize(self):
|
||||
pass
|
||||
|
||||
@abc.abstractmethod
|
||||
async def load(self):
|
||||
"""加载Prompt
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
def get_prompts(self) -> list[entities.Prompt]:
|
||||
"""获取Prompt列表
|
||||
"""
|
||||
return self.prompts
|
||||
38
pkg/openai/sysprompt/loaders/scenario.py
Normal file
38
pkg/openai/sysprompt/loaders/scenario.py
Normal file
@@ -0,0 +1,38 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
import os
|
||||
|
||||
from .. import loader
|
||||
from .. import entities
|
||||
from ....openai import entities as llm_entities
|
||||
|
||||
|
||||
class ScenarioPromptLoader(loader.PromptLoader):
|
||||
"""加载scenario目录下的json"""
|
||||
|
||||
async def load(self):
|
||||
"""加载Prompt
|
||||
"""
|
||||
for file in os.listdir("scenarios"):
|
||||
with open("scenarios/{}".format(file), "r", encoding="utf-8") as f:
|
||||
file_str = f.read()
|
||||
file_name = file.split(".")[0]
|
||||
file_json = json.loads(file_str)
|
||||
messages = []
|
||||
for msg in file_json["prompt"]:
|
||||
role = 'system'
|
||||
if "role" in msg:
|
||||
role = msg['role']
|
||||
messages.append(
|
||||
llm_entities.Message(
|
||||
role=role,
|
||||
content=msg['content'],
|
||||
)
|
||||
)
|
||||
prompt = entities.Prompt(
|
||||
name=file_name,
|
||||
messages=messages
|
||||
)
|
||||
self.prompts.append(prompt)
|
||||
|
||||
42
pkg/openai/sysprompt/loaders/single.py
Normal file
42
pkg/openai/sysprompt/loaders/single.py
Normal file
@@ -0,0 +1,42 @@
|
||||
from __future__ import annotations
|
||||
import os
|
||||
|
||||
from .. import loader
|
||||
from .. import entities
|
||||
from ....openai import entities as llm_entities
|
||||
|
||||
|
||||
class SingleSystemPromptLoader(loader.PromptLoader):
|
||||
"""配置文件中的单条system prompt的prompt加载器
|
||||
"""
|
||||
|
||||
async def load(self):
|
||||
"""加载Prompt
|
||||
"""
|
||||
|
||||
for name, cnt in self.ap.cfg_mgr.data['default_prompt'].items():
|
||||
prompt = entities.Prompt(
|
||||
name=name,
|
||||
messages=[
|
||||
llm_entities.Message(
|
||||
role='system',
|
||||
content=cnt
|
||||
)
|
||||
]
|
||||
)
|
||||
self.prompts.append(prompt)
|
||||
|
||||
for file in os.listdir("prompts"):
|
||||
with open("prompts/{}".format(file), "r", encoding="utf-8") as f:
|
||||
file_str = f.read()
|
||||
file_name = file.split(".")[0]
|
||||
prompt = entities.Prompt(
|
||||
name=file_name,
|
||||
messages=[
|
||||
llm_entities.Message(
|
||||
role='system',
|
||||
content=file_str
|
||||
)
|
||||
]
|
||||
)
|
||||
self.prompts.append(prompt)
|
||||
50
pkg/openai/sysprompt/sysprompt.py
Normal file
50
pkg/openai/sysprompt/sysprompt.py
Normal file
@@ -0,0 +1,50 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from ...core import app
|
||||
from . import loader
|
||||
from .loaders import single, scenario
|
||||
|
||||
|
||||
class PromptManager:
|
||||
|
||||
ap: app.Application
|
||||
|
||||
loader_inst: loader.PromptLoader
|
||||
|
||||
default_prompt: str = 'default'
|
||||
|
||||
def __init__(self, ap: app.Application):
|
||||
self.ap = ap
|
||||
|
||||
async def initialize(self):
|
||||
|
||||
loader_map = {
|
||||
"normal": single.SingleSystemPromptLoader,
|
||||
"full_scenario": scenario.ScenarioPromptLoader
|
||||
}
|
||||
|
||||
loader_cls = loader_map[self.ap.cfg_mgr.data['preset_mode']]
|
||||
|
||||
self.loader_inst: loader.PromptLoader = loader_cls(self.ap)
|
||||
|
||||
await self.loader_inst.initialize()
|
||||
await self.loader_inst.load()
|
||||
|
||||
def get_all_prompts(self) -> list[loader.entities.Prompt]:
|
||||
"""获取所有Prompt
|
||||
"""
|
||||
return self.loader_inst.get_prompts()
|
||||
|
||||
async def get_prompt(self, name: str) -> loader.entities.Prompt:
|
||||
"""获取Prompt
|
||||
"""
|
||||
for prompt in self.get_all_prompts():
|
||||
if prompt.name == name:
|
||||
return prompt
|
||||
|
||||
async def get_prompt_by_prefix(self, prefix: str) -> loader.entities.Prompt:
|
||||
"""通过前缀获取Prompt
|
||||
"""
|
||||
for prompt in self.get_all_prompts():
|
||||
if prompt.name.startswith(prefix):
|
||||
return prompt
|
||||
35
pkg/openai/tools/entities.py
Normal file
35
pkg/openai/tools/entities.py
Normal file
@@ -0,0 +1,35 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import abc
|
||||
import typing
|
||||
import asyncio
|
||||
|
||||
import pydantic
|
||||
|
||||
|
||||
class LLMFunction(pydantic.BaseModel):
|
||||
"""函数"""
|
||||
|
||||
name: str
|
||||
"""函数名"""
|
||||
|
||||
human_desc: str
|
||||
|
||||
description: str
|
||||
"""给LLM识别的函数描述"""
|
||||
|
||||
enable: typing.Optional[bool] = True
|
||||
|
||||
parameters: dict
|
||||
|
||||
func: typing.Callable
|
||||
"""供调用的python异步方法
|
||||
|
||||
此异步方法第一个参数接收当前请求的query对象,可以从其中取出session等信息。
|
||||
query参数不在parameters中,但在调用时会自动传入。
|
||||
但在当前版本中,插件提供的内容函数都是同步的,且均为请求无关的,故在此版本的实现(以及考虑了向后兼容性的版本)中,
|
||||
对插件的内容函数进行封装并存到这里来。
|
||||
"""
|
||||
|
||||
class Config:
|
||||
arbitrary_types_allowed = True
|
||||
99
pkg/openai/tools/toolmgr.py
Normal file
99
pkg/openai/tools/toolmgr.py
Normal file
@@ -0,0 +1,99 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import typing
|
||||
|
||||
from ...core import app, entities as core_entities
|
||||
from . import entities
|
||||
from ..session import entities as session_entities
|
||||
|
||||
|
||||
class ToolManager:
|
||||
"""LLM工具管理器
|
||||
"""
|
||||
|
||||
ap: app.Application
|
||||
|
||||
all_functions: list[entities.LLMFunction]
|
||||
|
||||
def __init__(self, ap: app.Application):
|
||||
self.ap = ap
|
||||
self.all_functions = []
|
||||
|
||||
async def initialize(self):
|
||||
pass
|
||||
|
||||
def register_legacy_function(self, name: str, description: str, parameters: dict, func: callable):
|
||||
"""注册函数
|
||||
"""
|
||||
async def wrapper(query, **kwargs):
|
||||
return func(**kwargs)
|
||||
function = entities.LLMFunction(
|
||||
name=name,
|
||||
description=description,
|
||||
human_desc='',
|
||||
enable=True,
|
||||
parameters=parameters,
|
||||
func=wrapper
|
||||
)
|
||||
self.all_functions.append(function)
|
||||
|
||||
async def register_function(self, function: entities.LLMFunction):
|
||||
"""添加函数
|
||||
"""
|
||||
self.all_functions.append(function)
|
||||
|
||||
async def get_function(self, name: str) -> entities.LLMFunction:
|
||||
"""获取函数
|
||||
"""
|
||||
for function in self.all_functions:
|
||||
if function.name == name:
|
||||
return function
|
||||
return None
|
||||
|
||||
async def get_all_functions(self) -> list[entities.LLMFunction]:
|
||||
"""获取所有函数
|
||||
"""
|
||||
return self.all_functions
|
||||
|
||||
async def generate_tools_for_openai(self, conversation: session_entities.Conversation) -> str:
|
||||
"""生成函数列表
|
||||
"""
|
||||
tools = []
|
||||
|
||||
for function in conversation.use_funcs:
|
||||
if function.enable:
|
||||
function_schema = {
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": function.name,
|
||||
"description": function.description,
|
||||
"parameters": function.parameters
|
||||
}
|
||||
}
|
||||
tools.append(function_schema)
|
||||
|
||||
return tools
|
||||
|
||||
async def execute_func_call(
|
||||
self,
|
||||
query: core_entities.Query,
|
||||
name: str,
|
||||
parameters: dict
|
||||
) -> typing.Any:
|
||||
"""执行函数调用
|
||||
"""
|
||||
|
||||
# return "i'm not sure for the args "+str(parameters)
|
||||
|
||||
function = await self.get_function(name)
|
||||
if function is None:
|
||||
return None
|
||||
|
||||
parameters = parameters.copy()
|
||||
|
||||
parameters = {
|
||||
"query": query,
|
||||
**parameters
|
||||
}
|
||||
|
||||
return await function.func(**parameters)
|
||||
@@ -1,70 +1,76 @@
|
||||
# 处理对会话的禁用配置
|
||||
# 过去的 banlist
|
||||
from __future__ import annotations
|
||||
import re
|
||||
|
||||
from ...boot import app
|
||||
from .. import stage, entities, stagemgr
|
||||
from ...core import entities as core_entities
|
||||
from ...config import manager as cfg_mgr
|
||||
|
||||
|
||||
class SessionBanManager:
|
||||
|
||||
ap: app.Application = None
|
||||
@stage.stage_class('BanSessionCheckStage')
|
||||
class BanSessionCheckStage(stage.PipelineStage):
|
||||
|
||||
banlist_mgr: cfg_mgr.ConfigManager
|
||||
|
||||
def __init__(self, ap: app.Application):
|
||||
self.ap = ap
|
||||
|
||||
async def initialize(self):
|
||||
self.banlist_mgr = await cfg_mgr.load_python_module_config(
|
||||
"banlist.py",
|
||||
"res/templates/banlist-template.py"
|
||||
)
|
||||
|
||||
async def is_banned(
|
||||
self, launcher_type: str, launcher_id: int, sender_id: int
|
||||
) -> bool:
|
||||
async def process(
|
||||
self,
|
||||
query: core_entities.Query,
|
||||
stage_inst_name: str
|
||||
) -> entities.StageProcessResult:
|
||||
|
||||
if not self.banlist_mgr.data['enable']:
|
||||
return False
|
||||
return entities.StageProcessResult(
|
||||
result_type=entities.ResultType.CONTINUE,
|
||||
new_query=query
|
||||
)
|
||||
|
||||
result = False
|
||||
|
||||
if launcher_type == 'group':
|
||||
if query.launcher_type == 'group':
|
||||
if not self.banlist_mgr.data['enable_group']: # 未启用群聊响应
|
||||
result = True
|
||||
# 检查是否显式声明发起人QQ要被person忽略
|
||||
elif sender_id in self.banlist_mgr.data['person']:
|
||||
elif query.sender_id in self.banlist_mgr.data['person']:
|
||||
result = True
|
||||
else:
|
||||
for group_rule in self.banlist_mgr.data['group']:
|
||||
if type(group_rule) == int:
|
||||
if group_rule == launcher_id:
|
||||
if group_rule == query.launcher_id:
|
||||
result = True
|
||||
elif type(group_rule) == str:
|
||||
if group_rule.startswith('!'):
|
||||
reg_str = group_rule[1:]
|
||||
if re.match(reg_str, str(launcher_id)):
|
||||
if re.match(reg_str, str(query.launcher_id)):
|
||||
result = False
|
||||
break
|
||||
else:
|
||||
if re.match(group_rule, str(launcher_id)):
|
||||
if re.match(group_rule, str(query.launcher_id)):
|
||||
result = True
|
||||
elif launcher_type == 'person':
|
||||
elif query.launcher_type == 'person':
|
||||
if not self.banlist_mgr.data['enable_private']:
|
||||
result = True
|
||||
else:
|
||||
for person_rule in self.banlist_mgr.data['person']:
|
||||
if type(person_rule) == int:
|
||||
if person_rule == launcher_id:
|
||||
if person_rule == query.launcher_id:
|
||||
result = True
|
||||
elif type(person_rule) == str:
|
||||
if person_rule.startswith('!'):
|
||||
reg_str = person_rule[1:]
|
||||
if re.match(reg_str, str(launcher_id)):
|
||||
if re.match(reg_str, str(query.launcher_id)):
|
||||
result = False
|
||||
break
|
||||
else:
|
||||
if re.match(person_rule, str(launcher_id)):
|
||||
if re.match(person_rule, str(query.launcher_id)):
|
||||
result = True
|
||||
return result
|
||||
|
||||
return entities.StageProcessResult(
|
||||
result_type=entities.ResultType.CONTINUE if not result else entities.ResultType.INTERRUPT,
|
||||
new_query=query,
|
||||
debug_notice=f'根据禁用列表忽略消息: {query.launcher_type}_{query.launcher_id}' if result else ''
|
||||
)
|
||||
128
pkg/pipeline/cntfilter/cntfilter.py
Normal file
128
pkg/pipeline/cntfilter/cntfilter.py
Normal file
@@ -0,0 +1,128 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import mirai
|
||||
|
||||
from ...core import app
|
||||
|
||||
from .. import stage, entities, stagemgr
|
||||
from ...core import entities as core_entities
|
||||
from ...config import manager as cfg_mgr
|
||||
from . import filter, entities as filter_entities
|
||||
from .filters import cntignore, banwords, baiduexamine
|
||||
|
||||
|
||||
@stage.stage_class('PostContentFilterStage')
|
||||
@stage.stage_class('PreContentFilterStage')
|
||||
class ContentFilterStage(stage.PipelineStage):
|
||||
|
||||
filter_chain: list[filter.ContentFilter]
|
||||
|
||||
def __init__(self, ap: app.Application):
|
||||
self.filter_chain = []
|
||||
super().__init__(ap)
|
||||
|
||||
async def initialize(self):
|
||||
self.filter_chain.append(cntignore.ContentIgnore(self.ap))
|
||||
|
||||
if self.ap.cfg_mgr.data['sensitive_word_filter']:
|
||||
self.filter_chain.append(banwords.BanWordFilter(self.ap))
|
||||
|
||||
if self.ap.cfg_mgr.data['baidu_check']:
|
||||
self.filter_chain.append(baiduexamine.BaiduCloudExamine(self.ap))
|
||||
|
||||
for filter in self.filter_chain:
|
||||
await filter.initialize()
|
||||
|
||||
async def _pre_process(
|
||||
self,
|
||||
message: str,
|
||||
query: core_entities.Query,
|
||||
) -> entities.StageProcessResult:
|
||||
"""请求llm前处理消息
|
||||
只要有一个不通过就不放行,只放行 PASS 的消息
|
||||
"""
|
||||
if not self.ap.cfg_mgr.data['income_msg_check']:
|
||||
return entities.StageProcessResult(
|
||||
result_type=entities.ResultType.CONTINUE,
|
||||
new_query=query
|
||||
)
|
||||
else:
|
||||
for filter in self.filter_chain:
|
||||
if filter_entities.EnableStage.PRE in filter.enable_stages:
|
||||
result = await filter.process(message)
|
||||
|
||||
if result.level in [
|
||||
filter_entities.ResultLevel.BLOCK,
|
||||
filter_entities.ResultLevel.MASKED
|
||||
]:
|
||||
return entities.StageProcessResult(
|
||||
result_type=entities.ResultType.INTERRUPT,
|
||||
new_query=query,
|
||||
user_notice=result.user_notice,
|
||||
console_notice=result.console_notice
|
||||
)
|
||||
elif result.level == filter_entities.ResultLevel.PASS: # 传到下一个
|
||||
message = result.replacement
|
||||
|
||||
query.message_chain = mirai.MessageChain(
|
||||
mirai.Plain(message)
|
||||
)
|
||||
|
||||
return entities.StageProcessResult(
|
||||
result_type=entities.ResultType.CONTINUE,
|
||||
new_query=query
|
||||
)
|
||||
|
||||
async def _post_process(
|
||||
self,
|
||||
message: str,
|
||||
query: core_entities.Query,
|
||||
) -> entities.StageProcessResult:
|
||||
"""请求llm后处理响应
|
||||
只要是 PASS 或者 MASKED 的就通过此 filter,将其 replacement 设置为message,进入下一个 filter
|
||||
"""
|
||||
for filter in self.filter_chain:
|
||||
if filter_entities.EnableStage.POST in filter.enable_stages:
|
||||
result = await filter.process(message)
|
||||
|
||||
if result.level == filter_entities.ResultLevel.BLOCK:
|
||||
return entities.StageProcessResult(
|
||||
result_type=entities.ResultType.INTERRUPT,
|
||||
new_query=query,
|
||||
user_notice=result.user_notice,
|
||||
console_notice=result.console_notice
|
||||
)
|
||||
elif result.level in [
|
||||
filter_entities.ResultLevel.PASS,
|
||||
filter_entities.ResultLevel.MASKED
|
||||
]:
|
||||
message = result.replacement
|
||||
|
||||
query.message_chain = mirai.MessageChain(
|
||||
mirai.Plain(message)
|
||||
)
|
||||
|
||||
return entities.StageProcessResult(
|
||||
result_type=entities.ResultType.CONTINUE,
|
||||
new_query=query
|
||||
)
|
||||
|
||||
async def process(
|
||||
self,
|
||||
query: core_entities.Query,
|
||||
stage_inst_name: str
|
||||
) -> entities.StageProcessResult:
|
||||
"""处理
|
||||
"""
|
||||
if stage_inst_name == 'PreContentFilterStage':
|
||||
return await self._pre_process(
|
||||
str(query.message_chain).strip(),
|
||||
query
|
||||
)
|
||||
elif stage_inst_name == 'PostContentFilterStage':
|
||||
return await self._post_process(
|
||||
str(query.message_chain).strip(),
|
||||
query
|
||||
)
|
||||
else:
|
||||
raise ValueError(f'未知的 stage_inst_name: {stage_inst_name}')
|
||||
@@ -2,7 +2,7 @@
|
||||
from __future__ import annotations
|
||||
import abc
|
||||
|
||||
from ...boot import app
|
||||
from ...core import app
|
||||
from . import entities
|
||||
|
||||
|
||||
0
pkg/pipeline/cntfilter/filters/__init__.py
Normal file
0
pkg/pipeline/cntfilter/filters/__init__.py
Normal file
38
pkg/pipeline/entities.py
Normal file
38
pkg/pipeline/entities.py
Normal file
@@ -0,0 +1,38 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import enum
|
||||
import typing
|
||||
|
||||
import pydantic
|
||||
import mirai
|
||||
import mirai.models.message as mirai_message
|
||||
|
||||
from ..core import entities
|
||||
|
||||
|
||||
class ResultType(enum.Enum):
|
||||
|
||||
CONTINUE = enum.auto()
|
||||
"""继续流水线"""
|
||||
|
||||
INTERRUPT = enum.auto()
|
||||
"""中断流水线"""
|
||||
|
||||
|
||||
class StageProcessResult(pydantic.BaseModel):
|
||||
|
||||
result_type: ResultType
|
||||
|
||||
new_query: entities.Query
|
||||
|
||||
user_notice: typing.Optional[typing.Union[str, list[mirai_message.MessageComponent], mirai.MessageChain, None]] = []
|
||||
"""只要设置了就会发送给用户"""
|
||||
|
||||
admin_notice: typing.Optional[typing.Union[str, list[mirai_message.MessageComponent], mirai.MessageChain, None]] = []
|
||||
"""只要设置了就会发送给管理员"""
|
||||
|
||||
console_notice: typing.Optional[str] = ''
|
||||
"""只要设置了就会输出到控制台"""
|
||||
|
||||
debug_notice: typing.Optional[str] = ''
|
||||
|
||||
0
pkg/pipeline/longtext/__init__.py
Normal file
0
pkg/pipeline/longtext/__init__.py
Normal file
@@ -3,22 +3,21 @@ import os
|
||||
import traceback
|
||||
|
||||
from PIL import Image, ImageDraw, ImageFont
|
||||
from mirai.models.message import MessageComponent, Plain
|
||||
from mirai.models.message import MessageComponent, Plain, MessageChain
|
||||
|
||||
from ...boot import app
|
||||
from ...core import app
|
||||
from . import strategy
|
||||
from .strategies import image, forward
|
||||
from .. import stage, entities, stagemgr
|
||||
from ...core import entities as core_entities
|
||||
from ...config import manager as cfg_mgr
|
||||
|
||||
|
||||
class LongTextProcessor:
|
||||
|
||||
ap: app.Application
|
||||
@stage.stage_class("LongTextProcessStage")
|
||||
class LongTextProcessStage(stage.PipelineStage):
|
||||
|
||||
strategy_impl: strategy.LongTextStrategy
|
||||
|
||||
def __init__(self, ap: app.Application):
|
||||
self.ap = ap
|
||||
|
||||
async def initialize(self):
|
||||
config = self.ap.cfg_mgr.data
|
||||
if self.ap.cfg_mgr.data['blob_message_strategy'] == 'image':
|
||||
@@ -48,9 +47,11 @@ class LongTextProcessor:
|
||||
elif self.ap.cfg_mgr.data['blob_message_strategy'] == 'forward':
|
||||
self.strategy_impl = forward.ForwardComponentStrategy(self.ap)
|
||||
await self.strategy_impl.initialize()
|
||||
|
||||
async def check_and_process(self, message: str) -> list[MessageComponent]:
|
||||
if len(message) > self.ap.cfg_mgr.data['blob_message_threshold']:
|
||||
return await self.strategy_impl.process(message)
|
||||
else:
|
||||
return [Plain(message)]
|
||||
|
||||
async def process(self, query: core_entities.Query, stage_inst_name: str) -> entities.StageProcessResult:
|
||||
if len(str(query.resp_message_chain)) > self.ap.cfg_mgr.data['blob_message_threshold']:
|
||||
query.resp_message_chain = MessageChain(await self.strategy_impl.process(str(query.resp_message_chain)))
|
||||
return entities.StageProcessResult(
|
||||
result_type=entities.ResultType.CONTINUE,
|
||||
new_query=query
|
||||
)
|
||||
0
pkg/pipeline/longtext/strategies/__init__.py
Normal file
0
pkg/pipeline/longtext/strategies/__init__.py
Normal file
@@ -5,7 +5,7 @@ import typing
|
||||
import mirai
|
||||
from mirai.models.message import MessageComponent
|
||||
|
||||
from ...boot import app
|
||||
from ...core import app
|
||||
|
||||
|
||||
class LongTextStrategy(metaclass=abc.ABCMeta):
|
||||
0
pkg/pipeline/process/__init__.py
Normal file
0
pkg/pipeline/process/__init__.py
Normal file
25
pkg/pipeline/process/handler.py
Normal file
25
pkg/pipeline/process/handler.py
Normal file
@@ -0,0 +1,25 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import abc
|
||||
|
||||
from ...core import app
|
||||
from ...core import entities as core_entities
|
||||
from .. import entities
|
||||
|
||||
|
||||
class MessageHandler(metaclass=abc.ABCMeta):
|
||||
|
||||
ap: app.Application
|
||||
|
||||
def __init__(self, ap: app.Application):
|
||||
self.ap = ap
|
||||
|
||||
async def initialize(self):
|
||||
pass
|
||||
|
||||
@abc.abstractmethod
|
||||
async def handle(
|
||||
self,
|
||||
query: core_entities.Query,
|
||||
) -> entities.StageProcessResult:
|
||||
raise NotImplementedError
|
||||
0
pkg/pipeline/process/handlers/__init__.py
Normal file
0
pkg/pipeline/process/handlers/__init__.py
Normal file
44
pkg/pipeline/process/handlers/chat.py
Normal file
44
pkg/pipeline/process/handlers/chat.py
Normal file
@@ -0,0 +1,44 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import typing
|
||||
|
||||
import mirai
|
||||
|
||||
from .. import handler
|
||||
from ... import entities
|
||||
from ....core import entities as core_entities
|
||||
from ....openai import entities as llm_entities
|
||||
|
||||
|
||||
class ChatMessageHandler(handler.MessageHandler):
|
||||
|
||||
async def handle(
|
||||
self,
|
||||
query: core_entities.Query,
|
||||
) -> typing.AsyncGenerator[entities.StageProcessResult, None]:
|
||||
"""处理
|
||||
"""
|
||||
# 取session
|
||||
# 取conversation
|
||||
# 调API
|
||||
# 生成器
|
||||
session = await self.ap.sess_mgr.get_session(query)
|
||||
|
||||
conversation = await self.ap.sess_mgr.get_conversation(session)
|
||||
|
||||
conversation.messages.append(
|
||||
llm_entities.Message(
|
||||
role="user",
|
||||
content=str(query.message_chain)
|
||||
)
|
||||
)
|
||||
|
||||
async for result in conversation.use_model.requester.request(query, conversation):
|
||||
conversation.messages.append(result)
|
||||
|
||||
query.resp_message_chain = mirai.MessageChain([mirai.Plain(str(result))])
|
||||
|
||||
yield entities.StageProcessResult(
|
||||
result_type=entities.ResultType.CONTINUE,
|
||||
new_query=query
|
||||
)
|
||||
50
pkg/pipeline/process/handlers/command.py
Normal file
50
pkg/pipeline/process/handlers/command.py
Normal file
@@ -0,0 +1,50 @@
|
||||
from __future__ import annotations
|
||||
import typing
|
||||
|
||||
import mirai
|
||||
|
||||
from .. import handler
|
||||
from ... import entities
|
||||
from ....core import entities as core_entities
|
||||
|
||||
|
||||
class CommandHandler(handler.MessageHandler):
|
||||
|
||||
async def handle(
|
||||
self,
|
||||
query: core_entities.Query,
|
||||
) -> typing.AsyncGenerator[entities.StageProcessResult, None]:
|
||||
"""处理
|
||||
"""
|
||||
session = await self.ap.sess_mgr.get_session(query)
|
||||
|
||||
command_text = str(query.message_chain).strip()[1:]
|
||||
|
||||
async for ret in self.ap.cmd_mgr.execute(
|
||||
command_text=command_text,
|
||||
query=query,
|
||||
session=session
|
||||
):
|
||||
if ret.error is not None:
|
||||
query.resp_message_chain = mirai.MessageChain([
|
||||
mirai.Plain(str(ret.error))
|
||||
])
|
||||
|
||||
yield entities.StageProcessResult(
|
||||
result_type=entities.ResultType.CONTINUE,
|
||||
new_query=query
|
||||
)
|
||||
elif ret.text is not None:
|
||||
query.resp_message_chain = mirai.MessageChain([
|
||||
mirai.Plain(ret.text)
|
||||
])
|
||||
|
||||
yield entities.StageProcessResult(
|
||||
result_type=entities.ResultType.CONTINUE,
|
||||
new_query=query
|
||||
)
|
||||
else:
|
||||
yield entities.StageProcessResult(
|
||||
result_type=entities.ResultType.INTERRUPT,
|
||||
new_query=query
|
||||
)
|
||||
38
pkg/pipeline/process/process.py
Normal file
38
pkg/pipeline/process/process.py
Normal file
@@ -0,0 +1,38 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from ...core import app, entities as core_entities
|
||||
from . import handler
|
||||
from .handlers import chat, command
|
||||
from .. import entities
|
||||
from .. import stage, entities, stagemgr
|
||||
from ...core import entities as core_entities
|
||||
from ...config import manager as cfg_mgr
|
||||
|
||||
|
||||
@stage.stage_class("MessageProcessor")
|
||||
class Processor(stage.PipelineStage):
|
||||
|
||||
cmd_handler: handler.MessageHandler
|
||||
|
||||
chat_handler: handler.MessageHandler
|
||||
|
||||
async def initialize(self):
|
||||
self.cmd_handler = command.CommandHandler(self.ap)
|
||||
self.chat_handler = chat.ChatMessageHandler(self.ap)
|
||||
|
||||
await self.cmd_handler.initialize()
|
||||
await self.chat_handler.initialize()
|
||||
|
||||
async def process(
|
||||
self,
|
||||
query: core_entities.Query,
|
||||
stage_inst_name: str,
|
||||
) -> entities.StageProcessResult:
|
||||
"""处理
|
||||
"""
|
||||
message_text = str(query.message_chain).strip()
|
||||
|
||||
if message_text.startswith('!') or message_text.startswith('!'):
|
||||
return self.cmd_handler.handle(query)
|
||||
else:
|
||||
return self.chat_handler.handle(query)
|
||||
0
pkg/pipeline/respback/__init__.py
Normal file
0
pkg/pipeline/respback/__init__.py
Normal file
29
pkg/pipeline/respback/respback.py
Normal file
29
pkg/pipeline/respback/respback.py
Normal file
@@ -0,0 +1,29 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import mirai
|
||||
|
||||
from ...core import app
|
||||
|
||||
from .. import stage, entities, stagemgr
|
||||
from ...core import entities as core_entities
|
||||
from ...config import manager as cfg_mgr
|
||||
|
||||
|
||||
@stage.stage_class("SendResponseBackStage")
|
||||
class SendResponseBackStage(stage.PipelineStage):
|
||||
"""发送响应消息
|
||||
"""
|
||||
|
||||
async def process(self, query: core_entities.Query, stage_inst_name: str) -> entities.StageProcessResult:
|
||||
"""处理
|
||||
"""
|
||||
|
||||
await self.ap.im_mgr.send(
|
||||
query.message_event,
|
||||
query.resp_message_chain
|
||||
)
|
||||
|
||||
return entities.StageProcessResult(
|
||||
result_type=entities.ResultType.CONTINUE,
|
||||
new_query=query
|
||||
)
|
||||
0
pkg/pipeline/resprule/__init__.py
Normal file
0
pkg/pipeline/resprule/__init__.py
Normal file
62
pkg/pipeline/resprule/resprule.py
Normal file
62
pkg/pipeline/resprule/resprule.py
Normal file
@@ -0,0 +1,62 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import mirai
|
||||
|
||||
from ...core import app
|
||||
from . import entities as rule_entities, rule
|
||||
from .rules import atbot, prefix, regexp, random
|
||||
|
||||
from .. import stage, entities, stagemgr
|
||||
from ...core import entities as core_entities
|
||||
from ...config import manager as cfg_mgr
|
||||
|
||||
|
||||
@stage.stage_class("GroupRespondRuleCheckStage")
|
||||
class GroupRespondRuleCheckStage(stage.PipelineStage):
|
||||
"""群组响应规则检查器
|
||||
"""
|
||||
|
||||
rule_matchers: list[rule.GroupRespondRule]
|
||||
|
||||
async def initialize(self):
|
||||
"""初始化检查器
|
||||
"""
|
||||
self.rule_matchers = [
|
||||
atbot.AtBotRule(self.ap),
|
||||
prefix.PrefixRule(self.ap),
|
||||
regexp.RegExpRule(self.ap),
|
||||
random.RandomRespRule(self.ap),
|
||||
]
|
||||
|
||||
for rule_matcher in self.rule_matchers:
|
||||
await rule_matcher.initialize()
|
||||
|
||||
async def process(self, query: core_entities.Query, stage_inst_name: str) -> entities.StageProcessResult:
|
||||
|
||||
if query.launcher_type != 'group':
|
||||
return entities.StageProcessResult(
|
||||
result_type=entities.ResultType.CONTINUE,
|
||||
new_query=query
|
||||
)
|
||||
|
||||
rules = self.ap.cfg_mgr.data['response_rules']
|
||||
|
||||
use_rule = rules['default']
|
||||
|
||||
if str(query.launcher_id) in use_rule:
|
||||
use_rule = use_rule[str(query.launcher_id)]
|
||||
|
||||
for rule_matcher in self.rule_matchers: # 任意一个匹配就放行
|
||||
res = await rule_matcher.match(str(query.message_chain), query.message_chain, use_rule)
|
||||
if res.matching:
|
||||
query.message_chain = res.replacement
|
||||
|
||||
return entities.StageProcessResult(
|
||||
result_type=entities.ResultType.CONTINUE,
|
||||
new_query=query,
|
||||
)
|
||||
|
||||
return entities.StageProcessResult(
|
||||
result_type=entities.ResultType.INTERRUPT,
|
||||
new_query=query
|
||||
)
|
||||
@@ -3,7 +3,7 @@ import abc
|
||||
|
||||
import mirai
|
||||
|
||||
from ...boot import app
|
||||
from ...core import app
|
||||
from . import entities
|
||||
|
||||
|
||||
0
pkg/pipeline/resprule/rules/__init__.py
Normal file
0
pkg/pipeline/resprule/rules/__init__.py
Normal file
47
pkg/pipeline/stage.py
Normal file
47
pkg/pipeline/stage.py
Normal file
@@ -0,0 +1,47 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import abc
|
||||
import typing
|
||||
|
||||
from ..core import app, entities as core_entities
|
||||
from . import entities
|
||||
|
||||
|
||||
_stage_classes: dict[str, PipelineStage] = {}
|
||||
|
||||
|
||||
def stage_class(name: str):
|
||||
|
||||
def decorator(cls):
|
||||
_stage_classes[name] = cls
|
||||
return cls
|
||||
|
||||
return decorator
|
||||
|
||||
|
||||
class PipelineStage(metaclass=abc.ABCMeta):
|
||||
"""流水线阶段
|
||||
"""
|
||||
|
||||
ap: app.Application
|
||||
|
||||
def __init__(self, ap: app.Application):
|
||||
self.ap = ap
|
||||
|
||||
async def initialize(self):
|
||||
"""初始化
|
||||
"""
|
||||
pass
|
||||
|
||||
@abc.abstractmethod
|
||||
async def process(
|
||||
self,
|
||||
query: core_entities.Query,
|
||||
stage_inst_name: str,
|
||||
) -> typing.Union[
|
||||
entities.StageProcessResult,
|
||||
typing.AsyncGenerator[entities.StageProcessResult, None],
|
||||
]:
|
||||
"""处理
|
||||
"""
|
||||
raise NotImplementedError
|
||||
63
pkg/pipeline/stagemgr.py
Normal file
63
pkg/pipeline/stagemgr.py
Normal file
@@ -0,0 +1,63 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import pydantic
|
||||
|
||||
from ..core import app
|
||||
from . import stage
|
||||
from .resprule import resprule
|
||||
from .bansess import bansess
|
||||
from .cntfilter import cntfilter
|
||||
from .process import process
|
||||
from .longtext import longtext
|
||||
from .respback import respback
|
||||
|
||||
|
||||
stage_order = [
|
||||
"GroupRespondRuleCheckStage",
|
||||
"BanSessionCheckStage",
|
||||
"PreContentFilterStage",
|
||||
"MessageProcessor",
|
||||
"PostContentFilterStage",
|
||||
"LongTextProcessStage",
|
||||
"SendResponseBackStage",
|
||||
]
|
||||
|
||||
|
||||
class StageInstContainer():
|
||||
"""阶段实例容器
|
||||
"""
|
||||
|
||||
inst_name: str
|
||||
|
||||
inst: stage.PipelineStage
|
||||
|
||||
def __init__(self, inst_name: str, inst: stage.PipelineStage):
|
||||
self.inst_name = inst_name
|
||||
self.inst = inst
|
||||
|
||||
|
||||
class StageManager:
|
||||
ap: app.Application
|
||||
|
||||
stage_containers: list[StageInstContainer]
|
||||
|
||||
def __init__(self, ap: app.Application):
|
||||
self.ap = ap
|
||||
|
||||
self.stage_containers = []
|
||||
|
||||
async def initialize(self):
|
||||
"""初始化
|
||||
"""
|
||||
|
||||
for name, cls in stage._stage_classes.items():
|
||||
self.stage_containers.append(StageInstContainer(
|
||||
inst_name=name,
|
||||
inst=cls(self.ap)
|
||||
))
|
||||
|
||||
for stage_containers in self.stage_containers:
|
||||
await stage_containers.inst.initialize()
|
||||
|
||||
# 按照 stage_order 排序
|
||||
self.stage_containers.sort(key=lambda x: stage_order.index(x.inst_name))
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user