Compare commits

...

54 Commits

Author SHA1 Message Date
RockChinQ
1fb69311b0 chore: release v3.1.0 2024-03-22 17:17:16 +08:00
Junyan Qin
995d1f61d2 Merge pull request #735 from RockChinQ/feat/plugin-api
Feat: 插件异步 API
2024-03-22 17:10:06 +08:00
RockChinQ
80258e9182 perf: 修改platform_mgr名称 2024-03-22 17:09:43 +08:00
RockChinQ
bd6a32e08e doc: 为可扩展组件添加注释 2024-03-22 16:41:46 +08:00
RockChinQ
5f138de75b doc: 完善query对象的注释 2024-03-22 11:05:58 +08:00
RockChinQ
d0b0f2209a fix: chat处理过程的插件返回值目标错误 2024-03-20 23:32:28 +08:00
RockChinQ
0752698c1d chore: 完善plugin对外对象的注释 2024-03-20 18:43:52 +08:00
RockChinQ
9855c6b8f5 feat: 新的引入路径 2024-03-20 15:48:11 +08:00
RockChinQ
52a7c25540 feat: 异步风格插件方法注册器 2024-03-20 15:09:47 +08:00
RockChinQ
fa823de6b0 perf: 初始化config对象时支持传递dict作为模板 2024-03-20 14:20:56 +08:00
RockChinQ
f53070d8b6 feat: 插件加载阶段前置 (#681) 2024-03-19 22:48:02 +08:00
Junyan Qin
7677672691 Merge pull request #734 from RockChinQ/feat/moonshot
Feat: 添加对 moonshot 模型的支持
2024-03-19 22:41:40 +08:00
RockChinQ
dead8fa168 feat: 添加对 moonshot 模型的支持 2024-03-19 22:39:45 +08:00
RockChinQ
c6347bea45 fix: full-scenario 命名和目录名错误问题 (#731) 2024-03-18 21:05:54 +08:00
RockChinQ
32bd194bfc chore: anthropic 的配置补全迁移 2024-03-18 21:04:09 +08:00
Junyan Qin
cca48a394d Merge pull request #732 from RockChinQ/feat/claude-3
Feat: 接入 claude 3 系列模型
2024-03-18 11:27:22 +08:00
RockChinQ
a723c8ce37 perf: claude 的接口异常处理 2024-03-17 23:22:26 -04:00
RockChinQ
327b2509f6 perf: 忽略用户空消息 2024-03-17 23:06:40 -04:00
RockChinQ
1dae7bd655 feat: 对 claude api 的基本支持 2024-03-17 12:44:45 -04:00
RockChinQ
550a131685 deps: 添加 anthropic 依赖库 2024-03-17 12:03:25 -04:00
RockChinQ
0cfb8bb29f fix: 获取模型列表时未传递version参数 2024-03-16 22:23:02 +08:00
Junyan Qin
9c32420a95 Merge pull request #730 from RockChinQ/feat/customized-model
Feat: 允许自定义模型信息
2024-03-16 22:19:27 +08:00
RockChinQ
867093cc88 chore: 更改 provider.json 格式 2024-03-16 22:12:13 +08:00
RockChinQ
82763f8ec5 chore: 删除默认prompt 2024-03-16 21:43:45 +08:00
RockChinQ
97449065df feat: 通过元数据生成模型列表 2024-03-16 21:43:09 +08:00
Junyan Qin
9489783846 Merge pull request #729 from RockChinQ/feat/migration-stage
Feat: 配置文件迁移功能
2024-03-16 20:34:29 +08:00
RockChinQ
f91c9015bc feat: 添加配置文件迁移阶段 2024-03-16 20:27:17 +08:00
RockChinQ
302d86056d refactor: 所有的 json 加载统一到启动阶段中 2024-03-16 15:41:59 +08:00
Junyan Qin
98bebfddaa Merge pull request #728 from RockChinQ/feat/active-message
Feat: aiocqhttp 和 qq-botpy 适配器的主动消息发送接口
2024-03-16 15:18:27 +08:00
RockChinQ
dab20e3187 feat: aiocqhttp和qq-botpy的主动消息发送接口 2024-03-16 15:16:46 +08:00
RockChinQ
09e72f7c5f chore: 删除注释的代码 2024-03-14 17:24:36 +08:00
Junyan Qin
2028d85f84 Merge pull request #726 from RockChinQ/feat/qq-botpy-cache
Feat: qq-botpy 适配器对 member 和 group 的 openid 进行静态缓存
2024-03-14 16:05:14 +08:00
RockChinQ
ed3c0d9014 feat: qq-botpy 适配器对 member 和 group 的 openid 进行静态缓存 2024-03-14 16:00:22 +08:00
RockChinQ
be06150990 chore: aiocqhttp添加默认access-token参数 2024-03-13 16:53:30 +08:00
Junyan Qin
afb3fb4a31 Merge pull request #725 from RockChinQ/feat/aiocqhttp-access-token
Feat: aiocqhttp支持access-token
2024-03-13 16:49:56 +08:00
RockChinQ
d66577e6c3 feat: aiocqhttp支持access-token 2024-03-13 16:49:11 +08:00
Junyan Qin
6a4ea5446a Merge pull request #724 from RockChinQ/fix/at-resp
Fix: 回复并at机器人时会多一个at组件
2024-03-13 16:31:54 +08:00
RockChinQ
74e84c744a fix: 回复并at机器人时会多一个at组件 2024-03-13 16:31:06 +08:00
Junyan Qin
5ad2446cf3 Update bug-report.yml 2024-03-13 16:13:14 +08:00
Junyan Qin
63303bb5c0 Merge pull request #712 from RockChinQ/feat/component-extensibility
Feat: 更多组件的可扩展性
2024-03-13 00:32:26 +08:00
Junyan Qin
13393b6624 feat: 限速算法的扩展性 2024-03-12 16:31:54 +00:00
Junyan Qin
b9fa11c0c3 feat: prompt 加载器的扩展性 2024-03-12 16:22:07 +00:00
RockChinQ
8c6ce1f030 feat: 群响应规则的扩展性 2024-03-12 23:34:13 +08:00
RockChinQ
1d963d0f0c feat: 不再预先计算前文token数而是在报错时提醒用户重置 2024-03-12 16:04:11 +08:00
Junyan Qin
0ee383be27 Update announcement.json 2024-03-08 22:35:17 +08:00
RockChinQ
53d09129b4 fix: 命令事件的command参数处理错误 (#713) 2024-03-08 21:10:43 +08:00
RockChinQ
a398c6f311 feat: 消息平台适配器可扩展性 2024-03-08 20:40:54 +08:00
RockChinQ
4347ddd42a feat: 长消息处理策略可扩展性 2024-03-08 20:31:22 +08:00
RockChinQ
22cb8a6a06 feat: 内容过滤器的可扩展性 2024-03-08 20:22:06 +08:00
RockChinQ
7f554fd862 feat: command支持扩展命令类 2024-03-08 19:56:57 +08:00
Junyan Qin
a82bfa8a56 perf: 为命令装饰器添加断言 2024-03-08 11:38:26 +00:00
RockChinQ
95784debbf perf: 支持识别docker环境 2024-03-07 15:55:02 +08:00
Junyan Qin
2471c5bf0f Merge pull request #709 from RockChinQ/doc/comments
Doc: 补全部分注释
2024-03-03 16:35:31 +08:00
RockChinQ
2fe6d731b8 doc: 补全部分注释 2024-03-03 16:34:59 +08:00
106 changed files with 1739 additions and 773 deletions

View File

@@ -16,11 +16,13 @@ body:
required: true required: true
- type: dropdown - type: dropdown
attributes: attributes:
label: 登录框架 label: 消息平台适配器
description: "连接QQ使用的框架" description: "连接QQ使用的框架"
options: options:
- Mirai - yiri-miraiMirai
- go-cqhttp - Nakurugo-cqhttp
- aiocqhttp使用 OneBot 协议接入的)
- qq-botpyQQ官方API
validations: validations:
required: false required: false
- type: input - type: input

View File

@@ -5,6 +5,7 @@ COPY . .
RUN apt update \ RUN apt update \
&& apt install gcc -y \ && apt install gcc -y \
&& python -m pip install -r requirements.txt && python -m pip install -r requirements.txt \
&& touch /.dockerenv
CMD [ "python", "main.py" ] CMD [ "python", "main.py" ]

View File

@@ -32,7 +32,7 @@ async def main_entry():
sys.exit(0) sys.exit(0)
# 检查配置文件 # 检查配置文件
from pkg.core.bootutils import files from pkg.core.bootutils import files
generated_files = await files.generate_files() generated_files = await files.generate_files()

View File

@@ -34,6 +34,9 @@ class APIGroup(metaclass=abc.ABCMeta):
headers: dict = {}, headers: dict = {},
**kwargs **kwargs
): ):
"""
执行请求
"""
self._runtime_info['account_id'] = "-1" self._runtime_info['account_id'] = "-1"
url = self.prefix + path url = self.prefix + path

View File

@@ -1,3 +1,5 @@
# 实例 识别码 控制
import os import os
import uuid import uuid
import json import json

View File

@@ -7,6 +7,7 @@ from ..provider import entities as llm_entities
from . import entities, operator, errors from . import entities, operator, errors
from ..config import manager as cfg_mgr from ..config import manager as cfg_mgr
# 引入所有算子以便注册
from .operators import func, plugin, default, reset, list as list_cmd, last, next, delc, resend, prompt, cmd, help, version, update from .operators import func, plugin, default, reset, list as list_cmd, last, next, delc, resend, prompt, cmd, help, version, update
@@ -17,6 +18,9 @@ class CommandManager:
ap: app.Application ap: app.Application
cmd_list: list[operator.CommandOperator] cmd_list: list[operator.CommandOperator]
"""
运行时命令列表,扁平存储,各个对象包含对应的子节点引用
"""
def __init__(self, ap: app.Application): def __init__(self, ap: app.Application):
self.ap = ap self.ap = ap
@@ -60,7 +64,7 @@ class CommandManager:
""" """
found = False found = False
if len(context.crt_params) > 0: if len(context.crt_params) > 0: # 查找下一个参数是否对应此节点的某个子节点名
for oper in operator_list: for oper in operator_list:
if (context.crt_params[0] == oper.name \ if (context.crt_params[0] == oper.name \
or context.crt_params[0] in oper.alias) \ or context.crt_params[0] in oper.alias) \
@@ -78,7 +82,7 @@ class CommandManager:
yield ret yield ret
break break
if not found: if not found: # 如果下一个参数未在此节点的子节点中找到,则执行此节点或者报错
if operator is None: if operator is None:
yield entities.CommandReturn( yield entities.CommandReturn(
error=errors.CommandNotFoundError(context.crt_params[0]) error=errors.CommandNotFoundError(context.crt_params[0])

View File

@@ -10,6 +10,8 @@ from . import errors, operator
class CommandReturn(pydantic.BaseModel): class CommandReturn(pydantic.BaseModel):
"""命令返回值
"""
text: typing.Optional[str] text: typing.Optional[str]
"""文本 """文本
@@ -18,25 +20,52 @@ class CommandReturn(pydantic.BaseModel):
image: typing.Optional[mirai.Image] image: typing.Optional[mirai.Image]
error: typing.Optional[errors.CommandError]= None error: typing.Optional[errors.CommandError]= None
"""错误
"""
class Config: class Config:
arbitrary_types_allowed = True arbitrary_types_allowed = True
class ExecuteContext(pydantic.BaseModel): class ExecuteContext(pydantic.BaseModel):
"""单次命令执行上下文
"""
query: core_entities.Query query: core_entities.Query
"""本次消息的请求对象"""
session: core_entities.Session session: core_entities.Session
"""本次消息所属的会话对象"""
command_text: str command_text: str
"""命令完整文本"""
command: str command: str
"""命令名称"""
crt_command: str crt_command: str
"""当前命令
多级命令中crt_command为当前命令command为根命令。
例如:!plugin on Webwlkr
处理到plugin时command为plugincrt_command为plugin
处理到on时command为plugincrt_command为on
"""
params: list[str] params: list[str]
"""命令参数
整个命令以空格分割后的参数列表
"""
crt_params: list[str] crt_params: list[str]
"""当前命令参数
多级命令中crt_params为当前命令参数params为根命令参数。
例如:!plugin on Webwlkr
处理到plugin时params为['on', 'Webwlkr']crt_params为['on', 'Webwlkr']
处理到on时params为['on', 'Webwlkr']crt_params为['Webwlkr']
"""
privilege: int privilege: int
"""发起人权限"""

View File

@@ -8,17 +8,34 @@ from . import entities
preregistered_operators: list[typing.Type[CommandOperator]] = [] preregistered_operators: list[typing.Type[CommandOperator]] = []
"""预注册命令算子列表。在初始化时,所有算子类会被注册到此列表中。"""
def operator_class( def operator_class(
name: str, name: str,
help: str, help: str = "",
usage: str = None, usage: str = None,
alias: list[str] = [], alias: list[str] = [],
privilege: int=1, # 1为普通用户2为管理员 privilege: int=1, # 1为普通用户2为管理员
parent_class: typing.Type[CommandOperator] = None parent_class: typing.Type[CommandOperator] = None
) -> typing.Callable[[typing.Type[CommandOperator]], typing.Type[CommandOperator]]: ) -> typing.Callable[[typing.Type[CommandOperator]], typing.Type[CommandOperator]]:
"""命令类装饰器
Args:
name (str): 名称
help (str, optional): 帮助信息. Defaults to "".
usage (str, optional): 使用说明. Defaults to None.
alias (list[str], optional): 别名. Defaults to [].
privilege (int, optional): 权限1为普通用户可用2为仅管理员可用. Defaults to 1.
parent_class (typing.Type[CommandOperator], optional): 父节点若为None则为顶级命令. Defaults to None.
Returns:
typing.Callable[[typing.Type[CommandOperator]], typing.Type[CommandOperator]]: 装饰器
"""
def decorator(cls: typing.Type[CommandOperator]) -> typing.Type[CommandOperator]: def decorator(cls: typing.Type[CommandOperator]) -> typing.Type[CommandOperator]:
assert issubclass(cls, CommandOperator)
cls.name = name cls.name = name
cls.alias = alias cls.alias = alias
cls.help = help cls.help = help
@@ -34,7 +51,12 @@ def operator_class(
class CommandOperator(metaclass=abc.ABCMeta): class CommandOperator(metaclass=abc.ABCMeta):
"""命令算子 """命令算子抽象类
以下的参数均不需要在子类中设置,只需要在使用装饰器注册类时作为参数传递即可。
命令支持级联,即一个命令可以有多个子命令,子命令可以有子命令,以此类推。
处理命令时,若有子命令,会以当前参数列表的第一个参数去匹配子命令,若匹配成功,则转移到子命令中执行。
若没有匹配成功或没有子命令,则执行当前命令。
""" """
ap: app.Application ap: app.Application
@@ -43,7 +65,8 @@ class CommandOperator(metaclass=abc.ABCMeta):
"""名称,搜索到时若符合则使用""" """名称,搜索到时若符合则使用"""
path: str path: str
"""路径所有父节点的name的连接用于定义命令权限""" """路径所有父节点的name的连接用于定义命令权限,由管理器在初始化时自动设置。
"""
alias: list[str] alias: list[str]
"""同name""" """同name"""
@@ -52,6 +75,7 @@ class CommandOperator(metaclass=abc.ABCMeta):
"""此节点的帮助信息""" """此节点的帮助信息"""
usage: str = None usage: str = None
"""用法"""
parent_class: typing.Union[typing.Type[CommandOperator], None] = None parent_class: typing.Union[typing.Type[CommandOperator], None] = None
"""父节点类。标记以供管理器在初始化时编织父子关系。""" """父节点类。标记以供管理器在初始化时编织父子关系。"""
@@ -75,4 +99,15 @@ class CommandOperator(metaclass=abc.ABCMeta):
self, self,
context: entities.ExecuteContext context: entities.ExecuteContext
) -> typing.AsyncGenerator[entities.CommandReturn, None]: ) -> typing.AsyncGenerator[entities.CommandReturn, None]:
"""实现此方法以执行命令
支持多次yield以返回多个结果。
例如:一个安装插件的命令,可能会有下载、解压、安装等多个步骤,每个步骤都可以返回一个结果。
Args:
context (entities.ExecuteContext): 命令执行上下文
Yields:
entities.CommandReturn: 命令返回封装
"""
pass pass

View File

@@ -8,15 +8,12 @@ from .. import model as file_model
class JSONConfigFile(file_model.ConfigFile): class JSONConfigFile(file_model.ConfigFile):
"""JSON配置文件""" """JSON配置文件"""
config_file_name: str = None def __init__(
"""配置文件名""" self, config_file_name: str, template_file_name: str = None, template_data: dict = None
) -> None:
template_file_name: str = None
"""模板文件名"""
def __init__(self, config_file_name: str, template_file_name: str) -> None:
self.config_file_name = config_file_name self.config_file_name = config_file_name
self.template_file_name = template_file_name self.template_file_name = template_file_name
self.template_data = template_data
def exists(self) -> bool: def exists(self) -> bool:
return os.path.exists(self.config_file_name) return os.path.exists(self.config_file_name)
@@ -29,19 +26,24 @@ class JSONConfigFile(file_model.ConfigFile):
if not self.exists(): if not self.exists():
await self.create() await self.create()
with open(self.config_file_name, 'r', encoding='utf-8') as f: if self.template_file_name is not None:
cfg = json.load(f) with open(self.config_file_name, "r", encoding="utf-8") as f:
cfg = json.load(f)
# 从模板文件中进行补全 # 从模板文件中进行补全
with open(self.template_file_name, 'r', encoding='utf-8') as f: with open(self.template_file_name, "r", encoding="utf-8") as f:
template_cfg = json.load(f) self.template_data = json.load(f)
for key in template_cfg: for key in self.template_data:
if key not in cfg: if key not in cfg:
cfg[key] = template_cfg[key] cfg[key] = self.template_data[key]
return cfg return cfg
async def save(self, cfg: dict): async def save(self, cfg: dict):
with open(self.config_file_name, 'w', encoding='utf-8') as f: with open(self.config_file_name, "w", encoding="utf-8") as f:
json.dump(cfg, f, indent=4, ensure_ascii=False) json.dump(cfg, f, indent=4, ensure_ascii=False)
def save_sync(self, cfg: dict):
with open(self.config_file_name, "w", encoding="utf-8") as f:
json.dump(cfg, f, indent=4, ensure_ascii=False)

View File

@@ -60,3 +60,6 @@ class PythonModuleConfigFile(file_model.ConfigFile):
async def save(self, data: dict): async def save(self, data: dict):
logging.warning('Python模块配置文件不支持保存') logging.warning('Python模块配置文件不支持保存')
def save_sync(self, data: dict):
logging.warning('Python模块配置文件不支持保存')

View File

@@ -26,6 +26,9 @@ class ConfigManager:
async def dump_config(self): async def dump_config(self):
await self.file.save(self.data) await self.file.save(self.data)
def dump_config_sync(self):
self.file.save_sync(self.data)
async def load_python_module_config(config_name: str, template_name: str) -> ConfigManager: async def load_python_module_config(config_name: str, template_name: str) -> ConfigManager:
"""加载Python模块配置文件""" """加载Python模块配置文件"""
@@ -40,11 +43,12 @@ async def load_python_module_config(config_name: str, template_name: str) -> Con
return cfg_mgr return cfg_mgr
async def load_json_config(config_name: str, template_name: str) -> ConfigManager: async def load_json_config(config_name: str, template_name: str=None, template_data: dict=None) -> ConfigManager:
"""加载JSON配置文件""" """加载JSON配置文件"""
cfg_inst = json_file.JSONConfigFile( cfg_inst = json_file.JSONConfigFile(
config_name, config_name,
template_name template_name,
template_data
) )
cfg_mgr = ConfigManager(cfg_inst) cfg_mgr = ConfigManager(cfg_inst)

47
pkg/config/migration.py Normal file
View File

@@ -0,0 +1,47 @@
from __future__ import annotations
import abc
import typing
from ..core import app
preregistered_migrations: list[typing.Type[Migration]] = []
"""当前阶段暂不支持扩展"""
def migration_class(name: str, number: int):
"""注册一个迁移
"""
def decorator(cls: typing.Type[Migration]) -> typing.Type[Migration]:
cls.name = name
cls.number = number
preregistered_migrations.append(cls)
return cls
return decorator
class Migration(abc.ABC):
"""一个版本的迁移
"""
name: str
number: int
ap: app.Application
def __init__(self, ap: app.Application):
self.ap = ap
@abc.abstractmethod
async def need_migrate(self) -> bool:
"""判断当前环境是否需要运行此迁移
"""
pass
@abc.abstractmethod
async def run(self):
"""执行迁移
"""
pass

View File

@@ -0,0 +1,26 @@
from __future__ import annotations
import os
import sys
from .. import migration
@migration.migration_class("sensitive-word-migration", 1)
class SensitiveWordMigration(migration.Migration):
"""敏感词迁移
"""
async def need_migrate(self) -> bool:
"""判断当前环境是否需要运行此迁移
"""
return os.path.exists("data/config/sensitive-words.json")
async def run(self):
"""执行迁移
"""
# 移动文件
os.rename("data/config/sensitive-words.json", "data/metadata/sensitive-words.json")
# 重新加载配置
await self.ap.sensitive_meta.load_config()

View File

@@ -0,0 +1,47 @@
from __future__ import annotations
from .. import migration
@migration.migration_class("openai-config-migration", 2)
class OpenAIConfigMigration(migration.Migration):
"""OpenAI配置迁移
"""
async def need_migrate(self) -> bool:
"""判断当前环境是否需要运行此迁移
"""
return 'openai-config' in self.ap.provider_cfg.data
async def run(self):
"""执行迁移
"""
old_openai_config = self.ap.provider_cfg.data['openai-config'].copy()
if 'keys' not in self.ap.provider_cfg.data:
self.ap.provider_cfg.data['keys'] = {}
if 'openai' not in self.ap.provider_cfg.data['keys']:
self.ap.provider_cfg.data['keys']['openai'] = []
self.ap.provider_cfg.data['keys']['openai'] = old_openai_config['api-keys']
self.ap.provider_cfg.data['model'] = old_openai_config['chat-completions-params']['model']
del old_openai_config['chat-completions-params']['model']
if 'requester' not in self.ap.provider_cfg.data:
self.ap.provider_cfg.data['requester'] = {}
if 'openai-chat-completions' not in self.ap.provider_cfg.data['requester']:
self.ap.provider_cfg.data['requester']['openai-chat-completions'] = {}
self.ap.provider_cfg.data['requester']['openai-chat-completions'] = {
'base-url': old_openai_config['base_url'],
'args': old_openai_config['chat-completions-params'],
'timeout': old_openai_config['request-timeout'],
}
del self.ap.provider_cfg.data['openai-config']
await self.ap.provider_cfg.dump_config()

View File

@@ -0,0 +1,32 @@
from __future__ import annotations
from .. import migration
@migration.migration_class("anthropic-requester-config-completion", 3)
class AnthropicRequesterConfigCompletionMigration(migration.Migration):
"""OpenAI配置迁移
"""
async def need_migrate(self) -> bool:
"""判断当前环境是否需要运行此迁移
"""
return 'anthropic-messages' not in self.ap.provider_cfg.data['requester'] \
or 'anthropic' not in self.ap.provider_cfg.data['keys']
async def run(self):
"""执行迁移
"""
if 'anthropic-messages' not in self.ap.provider_cfg.data['requester']:
self.ap.provider_cfg.data['requester']['anthropic-messages'] = {
'base-url': 'https://api.anthropic.com',
'args': {
'max_tokens': 1024
},
'timeout': 120,
}
if 'anthropic' not in self.ap.provider_cfg.data['keys']:
self.ap.provider_cfg.data['keys']['anthropic'] = []
await self.ap.provider_cfg.dump_config()

View File

@@ -0,0 +1,30 @@
from __future__ import annotations
from .. import migration
@migration.migration_class("moonshot-config-completion", 4)
class MoonshotConfigCompletionMigration(migration.Migration):
"""OpenAI配置迁移
"""
async def need_migrate(self) -> bool:
"""判断当前环境是否需要运行此迁移
"""
return 'moonshot-chat-completions' not in self.ap.provider_cfg.data['requester'] \
or 'moonshot' not in self.ap.provider_cfg.data['keys']
async def run(self):
"""执行迁移
"""
if 'moonshot-chat-completions' not in self.ap.provider_cfg.data['requester']:
self.ap.provider_cfg.data['requester']['moonshot-chat-completions'] = {
'base-url': 'https://api.moonshot.cn/v1',
'args': {},
'timeout': 120,
}
if 'moonshot' not in self.ap.provider_cfg.data['keys']:
self.ap.provider_cfg.data['keys']['moonshot'] = []
await self.ap.provider_cfg.dump_config()

View File

@@ -10,6 +10,9 @@ class ConfigFile(metaclass=abc.ABCMeta):
template_file_name: str = None template_file_name: str = None
"""模板文件名""" """模板文件名"""
template_data: dict = None
"""模板数据"""
@abc.abstractmethod @abc.abstractmethod
def exists(self) -> bool: def exists(self) -> bool:
pass pass
@@ -25,3 +28,7 @@ class ConfigFile(metaclass=abc.ABCMeta):
@abc.abstractmethod @abc.abstractmethod
async def save(self, data: dict): async def save(self, data: dict):
pass pass
@abc.abstractmethod
def save_sync(self, data: dict):
pass

View File

@@ -6,7 +6,7 @@ import traceback
from ..platform import manager as im_mgr from ..platform import manager as im_mgr
from ..provider.session import sessionmgr as llm_session_mgr from ..provider.session import sessionmgr as llm_session_mgr
from ..provider.requester import modelmgr as llm_model_mgr from ..provider.modelmgr import modelmgr as llm_model_mgr
from ..provider.sysprompt import sysprompt as llm_prompt_mgr from ..provider.sysprompt import sysprompt as llm_prompt_mgr
from ..provider.tools import toolmgr as llm_tool_mgr from ..provider.tools import toolmgr as llm_tool_mgr
from ..config import manager as config_mgr from ..config import manager as config_mgr
@@ -19,7 +19,9 @@ from ..utils import version as version_mgr, proxy as proxy_mgr
class Application: class Application:
im_mgr: im_mgr.PlatformManager = None """运行时应用对象和上下文"""
platform_mgr: im_mgr.PlatformManager = None
cmd_mgr: cmdmgr.CommandManager = None cmd_mgr: cmdmgr.CommandManager = None
@@ -31,6 +33,8 @@ class Application:
tool_mgr: llm_tool_mgr.ToolManager = None tool_mgr: llm_tool_mgr.ToolManager = None
# ======= 配置管理器 =======
command_cfg: config_mgr.ConfigManager = None command_cfg: config_mgr.ConfigManager = None
pipeline_cfg: config_mgr.ConfigManager = None pipeline_cfg: config_mgr.ConfigManager = None
@@ -41,6 +45,18 @@ class Application:
system_cfg: config_mgr.ConfigManager = None system_cfg: config_mgr.ConfigManager = None
# ======= 元数据配置管理器 =======
sensitive_meta: config_mgr.ConfigManager = None
adapter_qq_botpy_meta: config_mgr.ConfigManager = None
plugin_setting_meta: config_mgr.ConfigManager = None
llm_models_meta: config_mgr.ConfigManager = None
# =========================
ctr_mgr: center_mgr.V2CenterAPI = None ctr_mgr: center_mgr.V2CenterAPI = None
plugin_mgr: plugin_mgr.PluginManager = None plugin_mgr: plugin_mgr.PluginManager = None
@@ -64,27 +80,18 @@ class Application:
pass pass
async def run(self): async def run(self):
await self.plugin_mgr.load_plugins()
await self.plugin_mgr.initialize_plugins() await self.plugin_mgr.initialize_plugins()
tasks = [] tasks = []
try: try:
tasks = [ tasks = [
asyncio.create_task(self.im_mgr.run()), asyncio.create_task(self.platform_mgr.run()),
asyncio.create_task(self.ctrl.run()) asyncio.create_task(self.ctrl.run())
] ]
# async def interrupt(tasks): # 挂信号处理
# await asyncio.sleep(1.5)
# while await aioconsole.ainput("使用 ctrl+c 或 'exit' 退出程序 > ") != 'exit':
# pass
# for task in tasks:
# task.cancel()
# await interrupt(tasks)
import signal import signal

View File

@@ -3,11 +3,14 @@ from __future__ import print_function
from . import app from . import app
from ..audit import identifier from ..audit import identifier
from . import stage from . import stage
from .stages import load_config, setup_logger, build_app
# 引入启动阶段实现以便注册
from .stages import load_config, setup_logger, build_app, migrate
stage_order = [ stage_order = [
"LoadConfigStage", "LoadConfigStage",
"MigrationStage",
"SetupLoggerStage", "SetupLoggerStage",
"BuildAppStage" "BuildAppStage"
] ]
@@ -20,6 +23,7 @@ async def make_app() -> app.Application:
ap = app.Application() ap = app.Application()
# 执行启动阶段
for stage_name in stage_order: for stage_name in stage_order:
stage_cls = stage.preregistered_stages[stage_name] stage_cls = stage.preregistered_stages[stage_name]
stage_inst = stage_cls() stage_inst = stage_cls()

View File

@@ -3,13 +3,13 @@ import pip
required_deps = { required_deps = {
"requests": "requests", "requests": "requests",
"openai": "openai", "openai": "openai",
"anthropic": "anthropic",
"colorlog": "colorlog", "colorlog": "colorlog",
"mirai": "yiri-mirai-rc", "mirai": "yiri-mirai-rc",
"aiocqhttp": "aiocqhttp", "aiocqhttp": "aiocqhttp",
"botpy": "qq-botpy", "botpy": "qq-botpy",
"PIL": "pillow", "PIL": "pillow",
"nakuru": "nakuru-project-idk", "nakuru": "nakuru-project-idk",
"CallingGPT": "CallingGPT",
"tiktoken": "tiktoken", "tiktoken": "tiktoken",
"yaml": "pyyaml", "yaml": "pyyaml",
"aiohttp": "aiohttp", "aiohttp": "aiohttp",

View File

@@ -13,13 +13,13 @@ required_files = {
"data/config/platform.json": "templates/platform.json", "data/config/platform.json": "templates/platform.json",
"data/config/provider.json": "templates/provider.json", "data/config/provider.json": "templates/provider.json",
"data/config/system.json": "templates/system.json", "data/config/system.json": "templates/system.json",
"data/config/sensitive-words.json": "templates/sensitive-words.json",
"data/scenario/default.json": "templates/scenario-template.json", "data/scenario/default.json": "templates/scenario-template.json",
} }
required_paths = [ required_paths = [
"temp", "temp",
"data", "data",
"data/metadata",
"data/prompts", "data/prompts",
"data/scenario", "data/scenario",
"data/logs", "data/logs",

View File

@@ -9,13 +9,14 @@ import pydantic
import mirai import mirai
from ..provider import entities as llm_entities from ..provider import entities as llm_entities
from ..provider.requester import entities from ..provider.modelmgr import entities
from ..provider.sysprompt import entities as sysprompt_entities from ..provider.sysprompt import entities as sysprompt_entities
from ..provider.tools import entities as tools_entities from ..provider.tools import entities as tools_entities
from ..platform import adapter as msadapter from ..platform import adapter as msadapter
class LauncherTypes(enum.Enum): class LauncherTypes(enum.Enum):
"""一个请求的发起者类型"""
PERSON = 'person' PERSON = 'person'
"""私聊""" """私聊"""
@@ -31,43 +32,43 @@ class Query(pydantic.BaseModel):
"""请求ID添加进请求池时生成""" """请求ID添加进请求池时生成"""
launcher_type: LauncherTypes launcher_type: LauncherTypes
"""会话类型platform设置""" """会话类型platform处理阶段设置"""
launcher_id: int launcher_id: int
"""会话IDplatform设置""" """会话IDplatform处理阶段设置"""
sender_id: int sender_id: int
"""发送者IDplatform设置""" """发送者IDplatform处理阶段设置"""
message_event: mirai.MessageEvent message_event: mirai.MessageEvent
"""事件platform收到的事件""" """事件platform收到的原始事件"""
message_chain: mirai.MessageChain message_chain: mirai.MessageChain
"""消息链platform收到的消息链""" """消息链platform收到的原始消息链"""
adapter: msadapter.MessageSourceAdapter adapter: msadapter.MessageSourceAdapter
"""适配器对象""" """消息平台适配器对象单个app中可能启用了多个消息平台适配器此对象表明发起此query的适配器"""
session: typing.Optional[Session] = None session: typing.Optional[Session] = None
"""会话对象,由前置处理器设置""" """会话对象,由前置处理器阶段设置"""
messages: typing.Optional[list[llm_entities.Message]] = [] messages: typing.Optional[list[llm_entities.Message]] = []
"""历史消息列表,由前置处理器设置""" """历史消息列表,由前置处理器阶段设置"""
prompt: typing.Optional[sysprompt_entities.Prompt] = None prompt: typing.Optional[sysprompt_entities.Prompt] = None
"""情景预设内容,由前置处理器设置""" """情景预设内容,由前置处理器阶段设置"""
user_message: typing.Optional[llm_entities.Message] = None user_message: typing.Optional[llm_entities.Message] = None
"""此次请求的用户消息对象,由前置处理器设置""" """此次请求的用户消息对象,由前置处理器阶段设置"""
use_model: typing.Optional[entities.LLMModelInfo] = None use_model: typing.Optional[entities.LLMModelInfo] = None
"""使用的模型,由前置处理器设置""" """使用的模型,由前置处理器阶段设置"""
use_funcs: typing.Optional[list[tools_entities.LLMFunction]] = None use_funcs: typing.Optional[list[tools_entities.LLMFunction]] = None
"""使用的函数,由前置处理器设置""" """使用的函数,由前置处理器阶段设置"""
resp_messages: typing.Optional[list[llm_entities.Message]] = [] resp_messages: typing.Optional[list[llm_entities.Message]] = []
"""provider生成的回复消息对象列表""" """Process阶段生成的回复消息对象列表"""
resp_message_chain: typing.Optional[mirai.MessageChain] = None resp_message_chain: typing.Optional[mirai.MessageChain] = None
"""回复消息链从resp_messages包装而得""" """回复消息链从resp_messages包装而得"""
@@ -77,7 +78,7 @@ class Query(pydantic.BaseModel):
class Conversation(pydantic.BaseModel): class Conversation(pydantic.BaseModel):
"""对话""" """对话,包含于 Session 中,一个 Session 可以有多个历史 Conversation但只有一个当前使用的 Conversation"""
prompt: sysprompt_entities.Prompt prompt: sysprompt_entities.Prompt
@@ -93,7 +94,7 @@ class Conversation(pydantic.BaseModel):
class Session(pydantic.BaseModel): class Session(pydantic.BaseModel):
"""会话""" """会话,一个 Session 对应一个 {launcher_type}_{launcher_id}"""
launcher_type: LauncherTypes launcher_type: LauncherTypes
launcher_id: int launcher_id: int
@@ -111,6 +112,7 @@ class Session(pydantic.BaseModel):
update_time: typing.Optional[datetime.datetime] = pydantic.Field(default_factory=datetime.datetime.now) update_time: typing.Optional[datetime.datetime] = pydantic.Field(default_factory=datetime.datetime.now)
semaphore: typing.Optional[asyncio.Semaphore] = None semaphore: typing.Optional[asyncio.Semaphore] = None
"""当前会话的信号量,用于限制并发"""
class Config: class Config:
arbitrary_types_allowed = True arbitrary_types_allowed = True

View File

@@ -7,6 +7,10 @@ from . import app
preregistered_stages: dict[str, typing.Type[BootingStage]] = {} preregistered_stages: dict[str, typing.Type[BootingStage]] = {}
"""预注册的请求处理阶段。在初始化时,所有请求处理阶段类会被注册到此字典中。
当前阶段暂不支持扩展
"""
def stage_class( def stage_class(
name: str name: str

View File

@@ -3,14 +3,14 @@ from __future__ import annotations
import sys import sys
from .. import stage, app from .. import stage, app
from ...utils import version, proxy, announce from ...utils import version, proxy, announce, platform
from ...audit.center import v2 as center_v2 from ...audit.center import v2 as center_v2
from ...audit import identifier from ...audit import identifier
from ...pipeline import pool, controller, stagemgr from ...pipeline import pool, controller, stagemgr
from ...plugin import manager as plugin_mgr from ...plugin import manager as plugin_mgr
from ...command import cmdmgr from ...command import cmdmgr
from ...provider.session import sessionmgr as llm_session_mgr from ...provider.session import sessionmgr as llm_session_mgr
from ...provider.requester import modelmgr as llm_model_mgr from ...provider.modelmgr import modelmgr as llm_model_mgr
from ...provider.sysprompt import sysprompt as llm_prompt_mgr from ...provider.sysprompt import sysprompt as llm_prompt_mgr
from ...provider.tools import toolmgr as llm_tool_mgr from ...provider.tools import toolmgr as llm_tool_mgr
from ...platform import manager as im_mgr from ...platform import manager as im_mgr
@@ -22,7 +22,7 @@ class BuildAppStage(stage.BootingStage):
""" """
async def run(self, ap: app.Application): async def run(self, ap: app.Application):
"""启动 """构建app对象的各个组件对象并初始化
""" """
proxy_mgr = proxy.ProxyManager(ap) proxy_mgr = proxy.ProxyManager(ap)
@@ -39,7 +39,7 @@ class BuildAppStage(stage.BootingStage):
"host_id": identifier.identifier["host_id"], "host_id": identifier.identifier["host_id"],
"instance_id": identifier.identifier["instance_id"], "instance_id": identifier.identifier["instance_id"],
"semantic_version": ver_mgr.get_current_version(), "semantic_version": ver_mgr.get_current_version(),
"platform": sys.platform, "platform": platform.get_platform(),
}, },
runtime_info={ runtime_info={
"admin_id": "{}".format(ap.system_cfg.data["admin-sessions"]), "admin_id": "{}".format(ap.system_cfg.data["admin-sessions"]),
@@ -62,6 +62,7 @@ class BuildAppStage(stage.BootingStage):
plugin_mgr_inst = plugin_mgr.PluginManager(ap) plugin_mgr_inst = plugin_mgr.PluginManager(ap)
await plugin_mgr_inst.initialize() await plugin_mgr_inst.initialize()
ap.plugin_mgr = plugin_mgr_inst ap.plugin_mgr = plugin_mgr_inst
await plugin_mgr_inst.load_plugins()
cmd_mgr_inst = cmdmgr.CommandManager(ap) cmd_mgr_inst = cmdmgr.CommandManager(ap)
await cmd_mgr_inst.initialize() await cmd_mgr_inst.initialize()
@@ -85,7 +86,7 @@ class BuildAppStage(stage.BootingStage):
im_mgr_inst = im_mgr.PlatformManager(ap=ap) im_mgr_inst = im_mgr.PlatformManager(ap=ap)
await im_mgr_inst.initialize() await im_mgr_inst.initialize()
ap.im_mgr = im_mgr_inst ap.platform_mgr = im_mgr_inst
stage_mgr = stagemgr.StageManager(ap) stage_mgr = stagemgr.StageManager(ap)
await stage_mgr.initialize() await stage_mgr.initialize()

View File

@@ -17,3 +17,15 @@ class LoadConfigStage(stage.BootingStage):
ap.platform_cfg = await config.load_json_config("data/config/platform.json", "templates/platform.json") ap.platform_cfg = await config.load_json_config("data/config/platform.json", "templates/platform.json")
ap.provider_cfg = await config.load_json_config("data/config/provider.json", "templates/provider.json") ap.provider_cfg = await config.load_json_config("data/config/provider.json", "templates/provider.json")
ap.system_cfg = await config.load_json_config("data/config/system.json", "templates/system.json") ap.system_cfg = await config.load_json_config("data/config/system.json", "templates/system.json")
ap.plugin_setting_meta = await config.load_json_config("plugins/plugins.json", "templates/plugin-settings.json")
await ap.plugin_setting_meta.dump_config()
ap.sensitive_meta = await config.load_json_config("data/metadata/sensitive-words.json", "templates/metadata/sensitive-words.json")
await ap.sensitive_meta.dump_config()
ap.adapter_qq_botpy_meta = await config.load_json_config("data/metadata/adapter-qq-botpy.json", "templates/metadata/adapter-qq-botpy.json")
await ap.adapter_qq_botpy_meta.dump_config()
ap.llm_models_meta = await config.load_json_config("data/metadata/llm-models.json", "templates/metadata/llm-models.json")
await ap.llm_models_meta.dump_config()

View File

@@ -0,0 +1,28 @@
from __future__ import annotations
import importlib
from .. import stage, app
from ...config import migration
from ...config.migrations import m001_sensitive_word_migration, m002_openai_config_migration, m003_anthropic_requester_cfg_completion, m004_moonshot_cfg_completion
@stage.stage_class("MigrationStage")
class MigrationStage(stage.BootingStage):
"""迁移阶段
"""
async def run(self, ap: app.Application):
"""启动
"""
migrations = migration.preregistered_migrations
# 按照迁移号排序
migrations.sort(key=lambda x: x.number)
for migration_cls in migrations:
migration_instance = migration_cls(ap)
if await migration_instance.need_migrate():
await migration_instance.run()

View File

@@ -8,6 +8,7 @@ from ...config import manager as cfg_mgr
@stage.stage_class('BanSessionCheckStage') @stage.stage_class('BanSessionCheckStage')
class BanSessionCheckStage(stage.PipelineStage): class BanSessionCheckStage(stage.PipelineStage):
"""访问控制处理阶段"""
async def initialize(self): async def initialize(self):
pass pass

View File

@@ -7,28 +7,38 @@ from ...core import app
from .. import stage, entities, stagemgr from .. import stage, entities, stagemgr
from ...core import entities as core_entities from ...core import entities as core_entities
from ...config import manager as cfg_mgr from ...config import manager as cfg_mgr
from . import filter, entities as filter_entities from . import filter as filter_model, entities as filter_entities
from .filters import cntignore, banwords, baiduexamine from .filters import cntignore, banwords, baiduexamine
@stage.stage_class('PostContentFilterStage') @stage.stage_class('PostContentFilterStage')
@stage.stage_class('PreContentFilterStage') @stage.stage_class('PreContentFilterStage')
class ContentFilterStage(stage.PipelineStage): class ContentFilterStage(stage.PipelineStage):
"""内容过滤阶段"""
filter_chain: list[filter.ContentFilter] filter_chain: list[filter_model.ContentFilter]
def __init__(self, ap: app.Application): def __init__(self, ap: app.Application):
self.filter_chain = [] self.filter_chain = []
super().__init__(ap) super().__init__(ap)
async def initialize(self): async def initialize(self):
self.filter_chain.append(cntignore.ContentIgnore(self.ap))
filters_required = [
"content-filter"
]
if self.ap.pipeline_cfg.data['check-sensitive-words']: if self.ap.pipeline_cfg.data['check-sensitive-words']:
self.filter_chain.append(banwords.BanWordFilter(self.ap)) filters_required.append("ban-word-filter")
if self.ap.pipeline_cfg.data['baidu-cloud-examine']['enable']: if self.ap.pipeline_cfg.data['baidu-cloud-examine']['enable']:
self.filter_chain.append(baiduexamine.BaiduCloudExamine(self.ap)) filters_required.append("baidu-cloud-examine")
for filter in filter_model.preregistered_filters:
if filter.name in filters_required:
self.filter_chain.append(
filter(self.ap)
)
for filter in self.filter_chain: for filter in self.filter_chain:
await filter.initialize() await filter.initialize()

View File

@@ -31,15 +31,24 @@ class EnableStage(enum.Enum):
class FilterResult(pydantic.BaseModel): class FilterResult(pydantic.BaseModel):
level: ResultLevel level: ResultLevel
"""结果等级
对于前置处理阶段,只要有任意一个返回 非PASS 的内容过滤器结果,就会中断处理。
对于后置处理阶段,当且内容过滤器返回 BLOCK 时,会中断处理。
"""
replacement: str replacement: str
"""替换后的消息""" """替换后的消息
内容过滤器可以进行一些遮掩处理,然后把遮掩后的消息返回。
若没有修改内容,也需要返回原消息。
"""
user_notice: str user_notice: str
"""不通过时,用户提示消息""" """不通过时,若此值不为空,将对用户提示消息"""
console_notice: str console_notice: str
"""不通过时,控制台提示消息""" """不通过时,若此值不为空,将在控制台提示消息"""
class ManagerResultLevel(enum.Enum): class ManagerResultLevel(enum.Enum):

View File

@@ -1,12 +1,42 @@
# 内容过滤器的抽象类 # 内容过滤器的抽象类
from __future__ import annotations from __future__ import annotations
import abc import abc
import typing
from ...core import app from ...core import app
from . import entities from . import entities
preregistered_filters: list[typing.Type[ContentFilter]] = []
def filter_class(
name: str
) -> typing.Callable[[typing.Type[ContentFilter]], typing.Type[ContentFilter]]:
"""内容过滤器类装饰器
Args:
name (str): 过滤器名称
Returns:
typing.Callable[[typing.Type[ContentFilter]], typing.Type[ContentFilter]]: 装饰器
"""
def decorator(cls: typing.Type[ContentFilter]) -> typing.Type[ContentFilter]:
assert issubclass(cls, ContentFilter)
cls.name = name
preregistered_filters.append(cls)
return cls
return decorator
class ContentFilter(metaclass=abc.ABCMeta): class ContentFilter(metaclass=abc.ABCMeta):
"""内容过滤器抽象类"""
name: str
ap: app.Application ap: app.Application
@@ -16,6 +46,11 @@ class ContentFilter(metaclass=abc.ABCMeta):
@property @property
def enable_stages(self): def enable_stages(self):
"""启用的阶段 """启用的阶段
默认为消息请求AI前后的两个阶段。
entity.EnableStage.PRE: 消息请求AI前此时需要检查的内容是用户的输入消息。
entity.EnableStage.POST: 消息请求AI后此时需要检查的内容是AI的回复消息。
""" """
return [ return [
entities.EnableStage.PRE, entities.EnableStage.PRE,
@@ -30,5 +65,14 @@ class ContentFilter(metaclass=abc.ABCMeta):
@abc.abstractmethod @abc.abstractmethod
async def process(self, message: str) -> entities.FilterResult: async def process(self, message: str) -> entities.FilterResult:
"""处理消息 """处理消息
分为前后阶段,具体取决于 enable_stages 的值。
对于内容过滤器来说,不需要考虑消息所处的阶段,只需要检查消息内容即可。
Args:
message (str): 需要检查的内容
Returns:
entities.FilterResult: 过滤结果,具体内容请查看 entities.FilterResult 类的文档
""" """
raise NotImplementedError raise NotImplementedError

View File

@@ -10,6 +10,7 @@ BAIDU_EXAMINE_URL = "https://aip.baidubce.com/rest/2.0/solution/v1/text_censor/v
BAIDU_EXAMINE_TOKEN_URL = "https://aip.baidubce.com/oauth/2.0/token" BAIDU_EXAMINE_TOKEN_URL = "https://aip.baidubce.com/oauth/2.0/token"
@filter_model.filter_class("baidu-cloud-examine")
class BaiduCloudExamine(filter_model.ContentFilter): class BaiduCloudExamine(filter_model.ContentFilter):
"""百度云内容审核""" """百度云内容审核"""

View File

@@ -6,34 +6,30 @@ from .. import entities
from ....config import manager as cfg_mgr from ....config import manager as cfg_mgr
@filter_model.filter_class("ban-word-filter")
class BanWordFilter(filter_model.ContentFilter): class BanWordFilter(filter_model.ContentFilter):
"""根据内容禁言""" """根据内容禁言"""
sensitive: cfg_mgr.ConfigManager
async def initialize(self): async def initialize(self):
self.sensitive = await cfg_mgr.load_json_config( pass
"data/config/sensitive-words.json",
"templates/sensitive-words.json"
)
async def process(self, message: str) -> entities.FilterResult: async def process(self, message: str) -> entities.FilterResult:
found = False found = False
for word in self.sensitive.data['words']: for word in self.ap.sensitive_meta.data['words']:
match = re.findall(word, message) match = re.findall(word, message)
if len(match) > 0: if len(match) > 0:
found = True found = True
for i in range(len(match)): for i in range(len(match)):
if self.sensitive.data['mask_word'] == "": if self.ap.sensitive_meta.data['mask_word'] == "":
message = message.replace( message = message.replace(
match[i], self.sensitive.data['mask'] * len(match[i]) match[i], self.ap.sensitive_meta.data['mask'] * len(match[i])
) )
else: else:
message = message.replace( message = message.replace(
match[i], self.sensitive.data['mask_word'] match[i], self.ap.sensitive_meta.data['mask_word']
) )
return entities.FilterResult( return entities.FilterResult(

View File

@@ -5,6 +5,7 @@ from .. import entities
from .. import filter as filter_model from .. import filter as filter_model
@filter_model.filter_class("content-ignore")
class ContentIgnore(filter_model.ContentFilter): class ContentIgnore(filter_model.ContentFilter):
"""根据内容忽略消息""" """根据内容忽略消息"""

View File

@@ -68,7 +68,7 @@ class Controller:
"""检查输出 """检查输出
""" """
if result.user_notice: if result.user_notice:
await self.ap.im_mgr.send( await self.ap.platform_mgr.send(
query.message_event, query.message_event,
result.user_notice, result.user_notice,
query.adapter query.adapter
@@ -85,7 +85,7 @@ class Controller:
stage_index: int, stage_index: int,
query: entities.Query, query: entities.Query,
): ):
"""从指定阶段开始执行 """从指定阶段开始执行,实现了责任链模式和基于生成器的阶段分叉功能。
如何看懂这里为什么这么写? 如何看懂这里为什么这么写?
去问 GPT-4: 去问 GPT-4:

View File

@@ -15,6 +15,8 @@ from ...config import manager as cfg_mgr
@stage.stage_class("LongTextProcessStage") @stage.stage_class("LongTextProcessStage")
class LongTextProcessStage(stage.PipelineStage): class LongTextProcessStage(stage.PipelineStage):
"""长消息处理阶段
"""
strategy_impl: strategy.LongTextStrategy strategy_impl: strategy.LongTextStrategy
@@ -43,11 +45,14 @@ class LongTextProcessStage(stage.PipelineStage):
self.ap.logger.error("加载字体文件失败({})更换为转发消息组件以发送长消息您可以在config.py中调整相关设置。".format(use_font)) self.ap.logger.error("加载字体文件失败({})更换为转发消息组件以发送长消息您可以在config.py中调整相关设置。".format(use_font))
self.ap.platform_cfg.data['long-text-process']['strategy'] = "forward" self.ap.platform_cfg.data['long-text-process']['strategy'] = "forward"
if config['strategy'] == 'image': for strategy_cls in strategy.preregistered_strategies:
self.strategy_impl = image.Text2ImageStrategy(self.ap) if strategy_cls.name == config['strategy']:
elif config['strategy'] == 'forward': self.strategy_impl = strategy_cls(self.ap)
self.strategy_impl = forward.ForwardComponentStrategy(self.ap) break
else:
raise ValueError(f"未找到名为 {config['strategy']} 的长消息处理策略")
await self.strategy_impl.initialize() await self.strategy_impl.initialize()
async def process(self, query: core_entities.Query, stage_inst_name: str) -> entities.StageProcessResult: async def process(self, query: core_entities.Query, stage_inst_name: str) -> entities.StageProcessResult:

View File

@@ -36,6 +36,7 @@ class Forward(MessageComponent):
return '[聊天记录]' return '[聊天记录]'
@strategy_model.strategy_class("forward")
class ForwardComponentStrategy(strategy_model.LongTextStrategy): class ForwardComponentStrategy(strategy_model.LongTextStrategy):
async def process(self, message: str, query: core_entities.Query) -> list[MessageComponent]: async def process(self, message: str, query: core_entities.Query) -> list[MessageComponent]:

View File

@@ -15,6 +15,7 @@ from .. import strategy as strategy_model
from ....core import entities as core_entities from ....core import entities as core_entities
@strategy_model.strategy_class("image")
class Text2ImageStrategy(strategy_model.LongTextStrategy): class Text2ImageStrategy(strategy_model.LongTextStrategy):
text_render_font: ImageFont.FreeTypeFont text_render_font: ImageFont.FreeTypeFont

View File

@@ -9,7 +9,39 @@ from ...core import app
from ...core import entities as core_entities from ...core import entities as core_entities
preregistered_strategies: list[typing.Type[LongTextStrategy]] = []
def strategy_class(
name: str
) -> typing.Callable[[typing.Type[LongTextStrategy]], typing.Type[LongTextStrategy]]:
"""长文本处理策略类装饰器
Args:
name (str): 策略名称
Returns:
typing.Callable[[typing.Type[LongTextStrategy]], typing.Type[LongTextStrategy]]: 装饰器
"""
def decorator(cls: typing.Type[LongTextStrategy]) -> typing.Type[LongTextStrategy]:
assert issubclass(cls, LongTextStrategy)
cls.name = name
preregistered_strategies.append(cls)
return cls
return decorator
class LongTextStrategy(metaclass=abc.ABCMeta): class LongTextStrategy(metaclass=abc.ABCMeta):
"""长文本处理策略抽象类
"""
name: str
ap: app.Application ap: app.Application
def __init__(self, ap: app.Application): def __init__(self, ap: app.Application):
@@ -20,4 +52,15 @@ class LongTextStrategy(metaclass=abc.ABCMeta):
@abc.abstractmethod @abc.abstractmethod
async def process(self, message: str, query: core_entities.Query) -> list[MessageComponent]: async def process(self, message: str, query: core_entities.Query) -> list[MessageComponent]:
"""处理长文本
在 platform.json 中配置 long-text-process 字段,只要 文本长度超过了 threshold 就会调用此方法
Args:
message (str): 消息
query (core_entities.Query): 此次请求的上下文对象
Returns:
list[mirai.models.messages.MessageComponent]: 转换后的 YiriMirai 消息组件列表
"""
return [] return []

View File

@@ -9,6 +9,7 @@ from ..platform import adapter as msadapter
class QueryPool: class QueryPool:
"""请求池请求获得调度进入pipeline之前保存在这里"""
query_id_counter: int = 0 query_id_counter: int = 0

View File

@@ -8,7 +8,7 @@ from ...plugin import events
@stage.stage_class("PreProcessor") @stage.stage_class("PreProcessor")
class PreProcessor(stage.PipelineStage): class PreProcessor(stage.PipelineStage):
"""预处理 """请求预处理阶段
""" """
async def process( async def process(
@@ -51,28 +51,6 @@ class PreProcessor(stage.PipelineStage):
query.prompt.messages = event_ctx.event.default_prompt query.prompt.messages = event_ctx.event.default_prompt
query.messages = event_ctx.event.prompt query.messages = event_ctx.event.prompt
# 根据模型max_tokens剪裁
max_tokens = min(query.use_model.max_tokens, self.ap.pipeline_cfg.data['submit-messages-tokens'])
test_messages = query.prompt.messages + query.messages + [query.user_message]
while await query.use_model.tokenizer.count_token(test_messages, query.use_model) > max_tokens:
# 前文都pop完了还是大于max_tokens由于prompt和user_messages不能删减报错
if len(query.prompt.messages) == 0:
return entities.StageProcessResult(
result_type=entities.ResultType.INTERRUPT,
new_query=query,
user_notice='输入内容过长,请减少情景预设或者输入内容长度',
console_notice='输入内容过长,请减少情景预设或者输入内容长度,或者增大配置文件中的 submit-messages-tokens 项但不能超过所用模型最大tokens数'
)
query.messages.pop(0) # pop第一个肯定是role=user的
# 继续pop到第二个role=user前一个
while len(query.messages) > 0 and query.messages[0].role != 'user':
query.messages.pop(0)
test_messages = query.prompt.messages + query.messages + [query.user_message]
return entities.StageProcessResult( return entities.StageProcessResult(
result_type=entities.ResultType.CONTINUE, result_type=entities.ResultType.CONTINUE,
new_query=query new_query=query

View File

@@ -21,8 +21,6 @@ class ChatMessageHandler(handler.MessageHandler):
) -> typing.AsyncGenerator[entities.StageProcessResult, None]: ) -> typing.AsyncGenerator[entities.StageProcessResult, None]:
"""处理 """处理
""" """
# 取session
# 取conversation
# 调API # 调API
# 生成器 # 生成器
@@ -41,7 +39,14 @@ class ChatMessageHandler(handler.MessageHandler):
if event_ctx.is_prevented_default(): if event_ctx.is_prevented_default():
if event_ctx.event.reply is not None: if event_ctx.event.reply is not None:
query.resp_message_chain = mirai.MessageChain(event_ctx.event.reply) mc = mirai.MessageChain(event_ctx.event.reply)
query.resp_messages.append(
llm_entities.Message(
role='plugin',
content=str(mc),
)
)
yield entities.StageProcessResult( yield entities.StageProcessResult(
result_type=entities.ResultType.CONTINUE, result_type=entities.ResultType.CONTINUE,

View File

@@ -19,15 +19,16 @@ class CommandHandler(handler.MessageHandler):
"""处理 """处理
""" """
event_class = events.PersonCommandSent if query.launcher_type == core_entities.LauncherTypes.PERSON else events.GroupCommandSent command_text = str(query.message_chain).strip()[1:]
privilege = 1 privilege = 1
if f'{query.launcher_type.value}_{query.launcher_id}' in self.ap.system_cfg.data['admin-sessions']: if f'{query.launcher_type.value}_{query.launcher_id}' in self.ap.system_cfg.data['admin-sessions']:
privilege = 2 privilege = 2
spt = str(query.message_chain).strip().split(' ') spt = command_text.split(' ')
event_class = events.PersonCommandSent if query.launcher_type == core_entities.LauncherTypes.PERSON else events.GroupCommandSent
event_ctx = await self.ap.plugin_mgr.emit_event( event_ctx = await self.ap.plugin_mgr.emit_event(
event=event_class( event=event_class(
@@ -73,8 +74,6 @@ class CommandHandler(handler.MessageHandler):
session = await self.ap.sess_mgr.get_session(query) session = await self.ap.sess_mgr.get_session(query)
command_text = str(query.message_chain).strip()[1:]
async for ret in self.ap.cmd_mgr.execute( async for ret in self.ap.cmd_mgr.execute(
command_text=command_text, command_text=command_text,
query=query, query=query,

View File

@@ -11,6 +11,7 @@ from ...config import manager as cfg_mgr
@stage.stage_class("MessageProcessor") @stage.stage_class("MessageProcessor")
class Processor(stage.PipelineStage): class Processor(stage.PipelineStage):
"""请求实际处理阶段"""
cmd_handler: handler.MessageHandler cmd_handler: handler.MessageHandler

View File

@@ -1,11 +1,27 @@
from __future__ import annotations from __future__ import annotations
import abc import abc
import typing
from ...core import app from ...core import app
class ReteLimitAlgo(metaclass=abc.ABCMeta): preregistered_algos: list[typing.Type[ReteLimitAlgo]] = []
def algo_class(name: str):
def decorator(cls: typing.Type[ReteLimitAlgo]) -> typing.Type[ReteLimitAlgo]:
cls.name = name
preregistered_algos.append(cls)
return cls
return decorator
class ReteLimitAlgo(metaclass=abc.ABCMeta):
"""限流算法抽象类"""
name: str = None
ap: app.Application ap: app.Application
def __init__(self, ap: app.Application): def __init__(self, ap: app.Application):
@@ -16,9 +32,27 @@ class ReteLimitAlgo(metaclass=abc.ABCMeta):
@abc.abstractmethod @abc.abstractmethod
async def require_access(self, launcher_type: str, launcher_id: int) -> bool: async def require_access(self, launcher_type: str, launcher_id: int) -> bool:
"""进入处理流程
这个方法对等待是友好的,意味着算法可以实现在这里等待一段时间以控制速率。
Args:
launcher_type (str): 请求者类型 群聊为 group 私聊为 person
launcher_id (int): 请求者ID
Returns:
bool: 是否允许进入处理流程若返回false则直接丢弃该请求
"""
raise NotImplementedError raise NotImplementedError
@abc.abstractmethod @abc.abstractmethod
async def release_access(self, launcher_type: str, launcher_id: int): async def release_access(self, launcher_type: str, launcher_id: int):
"""退出处理流程
Args:
launcher_type (str): 请求者类型 群聊为 group 私聊为 person
launcher_id (int): 请求者ID
"""
raise NotImplementedError raise NotImplementedError

View File

@@ -19,6 +19,7 @@ class SessionContainer:
self.records = {} self.records = {}
@algo.algo_class("fixwin")
class FixedWindowAlgo(algo.ReteLimitAlgo): class FixedWindowAlgo(algo.ReteLimitAlgo):
containers_lock: asyncio.Lock containers_lock: asyncio.Lock

View File

@@ -11,11 +11,24 @@ from ...core import entities as core_entities
@stage.stage_class("RequireRateLimitOccupancy") @stage.stage_class("RequireRateLimitOccupancy")
@stage.stage_class("ReleaseRateLimitOccupancy") @stage.stage_class("ReleaseRateLimitOccupancy")
class RateLimit(stage.PipelineStage): class RateLimit(stage.PipelineStage):
"""限速器控制阶段"""
algo: algo.ReteLimitAlgo algo: algo.ReteLimitAlgo
async def initialize(self): async def initialize(self):
self.algo = fixedwin.FixedWindowAlgo(self.ap)
algo_name = self.ap.pipeline_cfg.data['rate-limit']['algo']
algo_class = None
for algo_cls in algo.preregistered_algos:
if algo_cls.name == algo_name:
algo_class = algo_cls
break
else:
raise ValueError(f'未知的限速算法: {algo_name}')
self.algo = algo_class(self.ap)
await self.algo.initialize() await self.algo.initialize()
async def process( async def process(

View File

@@ -29,7 +29,7 @@ class SendResponseBackStage(stage.PipelineStage):
await asyncio.sleep(random_delay) await asyncio.sleep(random_delay)
await self.ap.im_mgr.send( await self.ap.platform_mgr.send(
query.message_event, query.message_event,
query.resp_message_chain, query.resp_message_chain,
adapter=query.adapter adapter=query.adapter

View File

@@ -21,15 +21,13 @@ class GroupRespondRuleCheckStage(stage.PipelineStage):
async def initialize(self): async def initialize(self):
"""初始化检查器 """初始化检查器
""" """
self.rule_matchers = [
atbot.AtBotRule(self.ap),
prefix.PrefixRule(self.ap),
regexp.RegExpRule(self.ap),
random.RandomRespRule(self.ap),
]
for rule_matcher in self.rule_matchers: self.rule_matchers = []
await rule_matcher.initialize()
for rule_matcher in rule.preregisetered_rules:
rule_inst = rule_matcher(self.ap)
await rule_inst.initialize()
self.rule_matchers.append(rule_inst)
async def process(self, query: core_entities.Query, stage_inst_name: str) -> entities.StageProcessResult: async def process(self, query: core_entities.Query, stage_inst_name: str) -> entities.StageProcessResult:

View File

@@ -1,5 +1,6 @@
from __future__ import annotations from __future__ import annotations
import abc import abc
import typing
import mirai import mirai
@@ -7,9 +8,20 @@ from ...core import app, entities as core_entities
from . import entities from . import entities
preregisetered_rules: list[typing.Type[GroupRespondRule]] = []
def rule_class(name: str):
def decorator(cls: typing.Type[GroupRespondRule]) -> typing.Type[GroupRespondRule]:
cls.name = name
preregisetered_rules.append(cls)
return cls
return decorator
class GroupRespondRule(metaclass=abc.ABCMeta): class GroupRespondRule(metaclass=abc.ABCMeta):
"""群组响应规则的抽象类 """群组响应规则的抽象类
""" """
name: str
ap: app.Application ap: app.Application

View File

@@ -7,6 +7,7 @@ from .. import entities
from ....core import entities as core_entities from ....core import entities as core_entities
@rule_model.rule_class("at-bot")
class AtBotRule(rule_model.GroupRespondRule): class AtBotRule(rule_model.GroupRespondRule):
async def match( async def match(
@@ -19,6 +20,10 @@ class AtBotRule(rule_model.GroupRespondRule):
if message_chain.has(mirai.At(query.adapter.bot_account_id)) and rule_dict['at']: if message_chain.has(mirai.At(query.adapter.bot_account_id)) and rule_dict['at']:
message_chain.remove(mirai.At(query.adapter.bot_account_id)) message_chain.remove(mirai.At(query.adapter.bot_account_id))
if message_chain.has(mirai.At(query.adapter.bot_account_id)): # 回复消息时会at两次检查并删除重复的
message_chain.remove(mirai.At(query.adapter.bot_account_id))
return entities.RuleJudgeResult( return entities.RuleJudgeResult(
matching=True, matching=True,
replacement=message_chain, replacement=message_chain,

View File

@@ -5,6 +5,7 @@ from .. import entities
from ....core import entities as core_entities from ....core import entities as core_entities
@rule_model.rule_class("prefix")
class PrefixRule(rule_model.GroupRespondRule): class PrefixRule(rule_model.GroupRespondRule):
async def match( async def match(

View File

@@ -7,6 +7,7 @@ from .. import entities
from ....core import entities as core_entities from ....core import entities as core_entities
@rule_model.rule_class("random")
class RandomRespRule(rule_model.GroupRespondRule): class RandomRespRule(rule_model.GroupRespondRule):
async def match( async def match(

View File

@@ -7,6 +7,7 @@ from .. import entities
from ....core import entities as core_entities from ....core import entities as core_entities
@rule_model.rule_class("regexp")
class RegExpRule(rule_model.GroupRespondRule): class RegExpRule(rule_model.GroupRespondRule):
async def match( async def match(

View File

@@ -15,6 +15,7 @@ from .preproc import preproc
from .ratelimit import ratelimit from .ratelimit import ratelimit
# 请求处理阶段顺序
stage_order = [ stage_order = [
"GroupRespondRuleCheckStage", "GroupRespondRuleCheckStage",
"BanSessionCheckStage", "BanSessionCheckStage",

View File

@@ -29,6 +29,13 @@ class ResponseWrapper(stage.PipelineStage):
if query.resp_messages[-1].role == 'command': if query.resp_messages[-1].role == 'command':
query.resp_message_chain = mirai.MessageChain("[bot] "+query.resp_messages[-1].content) query.resp_message_chain = mirai.MessageChain("[bot] "+query.resp_messages[-1].content)
yield entities.StageProcessResult(
result_type=entities.ResultType.CONTINUE,
new_query=query
)
elif query.resp_messages[-1].role == 'plugin':
query.resp_message_chain = mirai.MessageChain(query.resp_messages[-1].content)
yield entities.StageProcessResult( yield entities.StageProcessResult(
result_type=entities.ResultType.CONTINUE, result_type=entities.ResultType.CONTINUE,
new_query=query new_query=query

View File

@@ -14,6 +14,14 @@ preregistered_adapters: list[typing.Type[MessageSourceAdapter]] = []
def adapter_class( def adapter_class(
name: str name: str
): ):
"""消息平台适配器类装饰器
Args:
name (str): 适配器名称
Returns:
typing.Callable[[typing.Type[MessageSourceAdapter]], typing.Type[MessageSourceAdapter]]: 装饰器
"""
def decorator(cls: typing.Type[MessageSourceAdapter]) -> typing.Type[MessageSourceAdapter]: def decorator(cls: typing.Type[MessageSourceAdapter]) -> typing.Type[MessageSourceAdapter]:
cls.name = name cls.name = name
preregistered_adapters.append(cls) preregistered_adapters.append(cls)
@@ -22,15 +30,24 @@ def adapter_class(
class MessageSourceAdapter(metaclass=abc.ABCMeta): class MessageSourceAdapter(metaclass=abc.ABCMeta):
"""消息平台适配器基类"""
name: str name: str
bot_account_id: int bot_account_id: int
"""机器人账号ID需要在初始化时设置"""
config: dict config: dict
ap: app.Application ap: app.Application
def __init__(self, config: dict, ap: app.Application): def __init__(self, config: dict, ap: app.Application):
"""初始化适配器
Args:
config (dict): 对应的配置
ap (app.Application): 应用上下文
"""
self.config = config self.config = config
self.ap = ap self.ap = ap
@@ -40,7 +57,7 @@ class MessageSourceAdapter(metaclass=abc.ABCMeta):
target_id: str, target_id: str,
message: mirai.MessageChain message: mirai.MessageChain
): ):
"""发送消息 """主动发送消息
Args: Args:
target_type (str): 目标类型,`person`或`group` target_type (str): 目标类型,`person`或`group`

View File

@@ -146,7 +146,7 @@ class PlatformManager:
if len(self.adapters) == 0: if len(self.adapters) == 0:
self.ap.logger.warning('未运行平台适配器,请根据文档配置并启用平台适配器。') self.ap.logger.warning('未运行平台适配器,请根据文档配置并启用平台适配器。')
async def send(self, event, msg, adapter: msadapter.MessageSourceAdapter, check_quote=True, check_at_sender=True): async def send(self, event: mirai.MessageEvent, msg: mirai.MessageChain, adapter: msadapter.MessageSourceAdapter, check_quote=True, check_at_sender=True):
if check_at_sender and self.ap.platform_cfg.data['at-sender'] and isinstance(event, GroupMessage): if check_at_sender and self.ap.platform_cfg.data['at-sender'] and isinstance(event, GroupMessage):
@@ -163,25 +163,6 @@ class PlatformManager:
quote_origin=True if self.ap.platform_cfg.data['quote-origin'] and check_quote else False quote_origin=True if self.ap.platform_cfg.data['quote-origin'] and check_quote else False
) )
# 通知系统管理员
# TODO delete
# async def notify_admin(self, message: str):
# await self.notify_admin_message_chain(MessageChain([Plain("[bot]{}".format(message))]))
# async def notify_admin_message_chain(self, message: mirai.MessageChain):
# if self.ap.system_cfg.data['admin-sessions'] != []:
# admin_list = []
# for admin in self.ap.system_cfg.data['admin-sessions']:
# admin_list.append(admin)
# for adm in admin_list:
# self.adapter.send_message(
# adm.split("_")[0],
# adm.split("_")[1],
# message
# )
async def run(self): async def run(self):
try: try:
tasks = [] tasks = []

View File

@@ -40,7 +40,6 @@ class AiocqhttpMessageConverter(adapter.MessageConverter):
elif type(msg) is mirai.Voice: elif type(msg) is mirai.Voice:
msg_list.append(aiocqhttp.MessageSegment.record(msg.path)) msg_list.append(aiocqhttp.MessageSegment.record(msg.path))
elif type(msg) is forward.Forward: elif type(msg) is forward.Forward:
# print("aiocqhttp 暂不支持转发消息组件的转换,使用普通消息链发送")
for node in msg.node_list: for node in msg.node_list:
msg_list.extend(AiocqhttpMessageConverter.yiri2target(node.message_chain)[0]) msg_list.extend(AiocqhttpMessageConverter.yiri2target(node.message_chain)[0])
@@ -216,21 +215,28 @@ class AiocqhttpAdapter(adapter.MessageSourceAdapter):
self.ap = ap self.ap = ap
self.bot = aiocqhttp.CQHttp() if "access-token" in config:
self.bot = aiocqhttp.CQHttp(access_token=config["access-token"])
del self.config["access-token"]
else:
self.bot = aiocqhttp.CQHttp()
async def send_message( async def send_message(
self, target_type: str, target_id: str, message: mirai.MessageChain self, target_type: str, target_id: str, message: mirai.MessageChain
): ):
# TODO 实现发送消息 aiocq_msg = AiocqhttpMessageConverter.yiri2target(message)[0]
return super().send_message(target_type, target_id, message)
if target_type == "group":
await self.bot.send_group_msg(group_id=int(target_id), message=aiocq_msg)
elif target_type == "person":
await self.bot.send_private_msg(user_id=int(target_id), message=aiocq_msg)
async def reply_message( async def reply_message(
self, self,
message_source: mirai.MessageEvent, message_source: mirai.MessageEvent,
message: mirai.MessageChain, message: mirai.MessageChain,
quote_origin: bool = False, quote_origin: bool = False,
): ):
aiocq_event = AiocqhttpEventConverter.yiri2target(message_source, self.bot_account_id) aiocq_event = AiocqhttpEventConverter.yiri2target(message_source, self.bot_account_id)
aiocq_msg = AiocqhttpMessageConverter.yiri2target(message)[0] aiocq_msg = AiocqhttpMessageConverter.yiri2target(message)[0]
if quote_origin: if quote_origin:

View File

@@ -24,6 +24,8 @@ class NakuruProjectMessageConverter(adapter_model.MessageConverter):
msg_list = message_chain.__root__ msg_list = message_chain.__root__
elif type(message_chain) is list: elif type(message_chain) is list:
msg_list = message_chain msg_list = message_chain
elif type(message_chain) is str:
msg_list = [mirai.Plain(message_chain)]
else: else:
raise Exception("Unknown message type: " + str(message_chain) + str(type(message_chain))) raise Exception("Unknown message type: " + str(message_chain) + str(type(message_chain)))

View File

@@ -17,6 +17,7 @@ import botpy.types.message as botpy_message_type
from .. import adapter as adapter_model from .. import adapter as adapter_model
from ...pipeline.longtext.strategies import forward from ...pipeline.longtext.strategies import forward
from ...core import app from ...core import app
from ...config import manager as cfg_mgr
class OfficialGroupMessage(mirai.GroupMessage): class OfficialGroupMessage(mirai.GroupMessage):
@@ -34,52 +35,92 @@ cached_message_ids = {}
id_index = 0 id_index = 0
def save_msg_id(message_id: str) -> int: def save_msg_id(message_id: str) -> int:
"""保存消息id""" """保存消息id"""
global id_index, cached_message_ids global id_index, cached_message_ids
crt_index = id_index crt_index = id_index
id_index += 1 id_index += 1
cached_message_ids[str(crt_index)] = message_id cached_message_ids[str(crt_index)] = message_id
return crt_index return crt_index
cached_member_openids = {}
"""QQ官方 用户的id是字符串而YiriMirai的用户id是整数所以需要一个索引来进行转换"""
member_openid_index = 100 def char_to_value(char):
"""将单个字符转换为相应的数值。"""
def save_member_openid(member_openid: str) -> int: if '0' <= char <= '9':
"""保存用户id""" return ord(char) - ord('0')
global member_openid_index, cached_member_openids elif 'A' <= char <= 'Z':
return ord(char) - ord('A') + 10
if member_openid in cached_member_openids.values(): return ord(char) - ord('a') + 36
return list(cached_member_openids.keys())[list(cached_member_openids.values()).index(member_openid)]
crt_index = member_openid_index
member_openid_index += 1
cached_member_openids[str(crt_index)] = member_openid
return crt_index
cached_group_openids = {} def digest(s: str) -> int:
"""QQ官方 群组的id是字符串而YiriMirai的群组id是整数所以需要一个索引来进行转换""" """计算字符串的hash值。"""
# 取末尾的8位
sub_s = s[-10:]
group_openid_index = 1000 number = 0
base = 36
def save_group_openid(group_openid: str) -> int: for i in range(len(sub_s)):
"""保存群组id""" number = number * base + char_to_value(sub_s[i])
global group_openid_index, cached_group_openids
return number
K = typing.TypeVar("K")
V = typing.TypeVar("V")
class OpenIDMapping(typing.Generic[K, V]):
map: dict[K, V]
dump_func: typing.Callable
digest_func: typing.Callable[[K], V]
def __init__(self, map: dict[K, V], dump_func: typing.Callable, digest_func: typing.Callable[[K], V] = digest):
self.map = map
self.dump_func = dump_func
self.digest_func = digest_func
def __getitem__(self, key: K) -> V:
return self.map[key]
def __setitem__(self, key: K, value: V):
self.map[key] = value
self.dump_func()
def __contains__(self, key: K) -> bool:
return key in self.map
def __delitem__(self, key: K):
del self.map[key]
self.dump_func()
def getkey(self, value: V) -> K:
return list(self.map.keys())[list(self.map.values()).index(value)]
if group_openid in cached_group_openids.values(): def save_openid(self, key: K) -> V:
return list(cached_group_openids.keys())[list(cached_group_openids.values()).index(group_openid)]
if key in self.map:
crt_index = group_openid_index return self.map[key]
group_openid_index += 1
cached_group_openids[str(crt_index)] = group_openid value = self.digest_func(key)
return crt_index
self.map[key] = value
self.dump_func()
return value
class OfficialMessageConverter(adapter_model.MessageConverter): class OfficialMessageConverter(adapter_model.MessageConverter):
"""QQ 官方消息转换器""" """QQ 官方消息转换器"""
@staticmethod @staticmethod
def yiri2target(message_chain: mirai.MessageChain): def yiri2target(message_chain: mirai.MessageChain):
"""将 YiriMirai 的消息链转换为 QQ 官方消息""" """将 YiriMirai 的消息链转换为 QQ 官方消息"""
@@ -89,9 +130,13 @@ class OfficialMessageConverter(adapter_model.MessageConverter):
msg_list = message_chain.__root__ msg_list = message_chain.__root__
elif type(message_chain) is list: elif type(message_chain) is list:
msg_list = message_chain msg_list = message_chain
elif type(message_chain) is str:
msg_list = [mirai.Plain(text=message_chain)]
else: else:
raise Exception("Unknown message type: " + str(message_chain) + str(type(message_chain))) raise Exception(
"Unknown message type: " + str(message_chain) + str(type(message_chain))
)
offcial_messages: list[dict] = [] offcial_messages: list[dict] = []
""" """
{ {
@@ -108,36 +153,24 @@ class OfficialMessageConverter(adapter_model.MessageConverter):
# 遍历并转换 # 遍历并转换
for component in msg_list: for component in msg_list:
if type(component) is mirai.Plain: if type(component) is mirai.Plain:
offcial_messages.append({ offcial_messages.append({"type": "text", "content": component.text})
"type": "text",
"content": component.text
})
elif type(component) is mirai.Image: elif type(component) is mirai.Image:
if component.url is not None: if component.url is not None:
offcial_messages.append( offcial_messages.append({"type": "image", "content": component.url})
{
"type": "image",
"content": component.url
}
)
elif component.path is not None: elif component.path is not None:
offcial_messages.append( offcial_messages.append(
{ {"type": "file_image", "content": component.path}
"type": "file_image",
"content": component.path
}
) )
elif type(component) is mirai.At: elif type(component) is mirai.At:
offcial_messages.append( offcial_messages.append({"type": "at", "content": ""})
{
"type": "at",
"content": ""
}
)
elif type(component) is mirai.AtAll: elif type(component) is mirai.AtAll:
print("上层组件要求发送 AtAll 消息,但 QQ 官方 API 不支持此消息类型,忽略此消息。") print(
"上层组件要求发送 AtAll 消息,但 QQ 官方 API 不支持此消息类型,忽略此消息。"
)
elif type(component) is mirai.Voice: elif type(component) is mirai.Voice:
print("上层组件要求发送 Voice 消息,但 QQ 官方 API 不支持此消息类型,忽略此消息。") print(
"上层组件要求发送 Voice 消息,但 QQ 官方 API 不支持此消息类型,忽略此消息。"
)
elif type(component) is forward.Forward: elif type(component) is forward.Forward:
# 转发消息 # 转发消息
yiri_forward_node_list = component.node_list yiri_forward_node_list = component.node_list
@@ -146,22 +179,33 @@ class OfficialMessageConverter(adapter_model.MessageConverter):
for yiri_forward_node in yiri_forward_node_list: for yiri_forward_node in yiri_forward_node_list:
try: try:
message_chain = yiri_forward_node.message_chain message_chain = yiri_forward_node.message_chain
# 平铺 # 平铺
offcial_messages.extend(OfficialMessageConverter.yiri2target(message_chain)) offcial_messages.extend(
OfficialMessageConverter.yiri2target(message_chain)
)
except Exception as e: except Exception as e:
import traceback import traceback
traceback.print_exc() traceback.print_exc()
return offcial_messages return offcial_messages
@staticmethod @staticmethod
def extract_message_chain_from_obj(message: typing.Union[botpy_message.Message, botpy_message.DirectMessage], message_id: str = None, bot_account_id: int = 0) -> mirai.MessageChain: def extract_message_chain_from_obj(
message: typing.Union[botpy_message.Message, botpy_message.DirectMessage],
message_id: str = None,
bot_account_id: int = 0,
) -> mirai.MessageChain:
yiri_msg_list = [] yiri_msg_list = []
# 存id # 存id
yiri_msg_list.append(mirai.models.message.Source(id=save_msg_id(message_id), time=datetime.datetime.now())) yiri_msg_list.append(
mirai.models.message.Source(
id=save_msg_id(message_id), time=datetime.datetime.now()
)
)
if type(message) is not botpy_message.DirectMessage: if type(message) is not botpy_message.DirectMessage:
yiri_msg_list.append(mirai.At(target=bot_account_id)) yiri_msg_list.append(mirai.At(target=bot_account_id))
@@ -177,7 +221,9 @@ class OfficialMessageConverter(adapter_model.MessageConverter):
if attachment.content_type == "image": if attachment.content_type == "image":
yiri_msg_list.append(mirai.Image(url=attachment.url)) yiri_msg_list.append(mirai.Image(url=attachment.url))
else: else:
logging.warning("不支持的附件类型:" + attachment.content_type + ",忽略此附件。") logging.warning(
"不支持的附件类型:" + attachment.content_type + ",忽略此附件。"
)
content = re.sub(r"<@!\d+>", "", str(message.content)) content = re.sub(r"<@!\d+>", "", str(message.content))
if content.strip() != "": if content.strip() != "":
@@ -186,29 +232,40 @@ class OfficialMessageConverter(adapter_model.MessageConverter):
chain = mirai.MessageChain(yiri_msg_list) chain = mirai.MessageChain(yiri_msg_list)
return chain return chain
class OfficialEventConverter(adapter_model.EventConverter): class OfficialEventConverter(adapter_model.EventConverter):
"""事件转换器""" """事件转换器"""
@staticmethod
def yiri2target(event: typing.Type[mirai.Event]): member_openid_mapping: OpenIDMapping[str, int]
group_openid_mapping: OpenIDMapping[str, int]
def __init__(self, member_openid_mapping: OpenIDMapping[str, int], group_openid_mapping: OpenIDMapping[str, int]):
self.member_openid_mapping = member_openid_mapping
self.group_openid_mapping = group_openid_mapping
def yiri2target(self, event: typing.Type[mirai.Event]):
if event == mirai.GroupMessage: if event == mirai.GroupMessage:
return botpy_message.Message return botpy_message.Message
elif event == mirai.FriendMessage: elif event == mirai.FriendMessage:
return botpy_message.DirectMessage return botpy_message.DirectMessage
else: else:
raise Exception("未支持转换的事件类型(YiriMirai -> Official): " + str(event)) raise Exception(
"未支持转换的事件类型(YiriMirai -> Official): " + str(event)
)
@staticmethod def target2yiri(
def target2yiri(event: typing.Union[botpy_message.Message, botpy_message.DirectMessage]) -> mirai.Event: self,
event: typing.Union[botpy_message.Message, botpy_message.DirectMessage]
) -> mirai.Event:
import mirai.models.entities as mirai_entities import mirai.models.entities as mirai_entities
if type(event) == botpy_message.Message: # 频道内,转群聊事件 if type(event) == botpy_message.Message: # 频道内,转群聊事件
permission = "MEMBER" permission = "MEMBER"
if '2' in event.member.roles: if "2" in event.member.roles:
permission = "ADMINISTRATOR" permission = "ADMINISTRATOR"
elif '4' in event.member.roles: elif "4" in event.member.roles:
permission = "OWNER" permission = "OWNER"
return mirai.GroupMessage( return mirai.GroupMessage(
@@ -219,15 +276,25 @@ class OfficialEventConverter(adapter_model.EventConverter):
group=mirai_entities.Group( group=mirai_entities.Group(
id=event.channel_id, id=event.channel_id,
name=event.author.username, name=event.author.username,
permission=mirai_entities.Permission.Member permission=mirai_entities.Permission.Member,
),
special_title="",
join_timestamp=int(
datetime.datetime.strptime(
event.member.joined_at, "%Y-%m-%dT%H:%M:%S%z"
).timestamp()
), ),
special_title='',
join_timestamp=int(datetime.datetime.strptime(event.member.joined_at, "%Y-%m-%dT%H:%M:%S%z").timestamp()),
last_speak_timestamp=datetime.datetime.now().timestamp(), last_speak_timestamp=datetime.datetime.now().timestamp(),
mute_time_remaining=0, mute_time_remaining=0,
), ),
message_chain=OfficialMessageConverter.extract_message_chain_from_obj(event, event.id), message_chain=OfficialMessageConverter.extract_message_chain_from_obj(
time=int(datetime.datetime.strptime(event.timestamp, "%Y-%m-%dT%H:%M:%S%z").timestamp()), event, event.id
),
time=int(
datetime.datetime.strptime(
event.timestamp, "%Y-%m-%dT%H:%M:%S%z"
).timestamp()
),
) )
elif type(event) == botpy_message.DirectMessage: # 私聊,转私聊事件 elif type(event) == botpy_message.DirectMessage: # 私聊,转私聊事件
return mirai.FriendMessage( return mirai.FriendMessage(
@@ -236,12 +303,18 @@ class OfficialEventConverter(adapter_model.EventConverter):
nickname=event.author.username, nickname=event.author.username,
remark=event.author.username, remark=event.author.username,
), ),
message_chain=OfficialMessageConverter.extract_message_chain_from_obj(event, event.id), message_chain=OfficialMessageConverter.extract_message_chain_from_obj(
time=int(datetime.datetime.strptime(event.timestamp, "%Y-%m-%dT%H:%M:%S%z").timestamp()), event, event.id
),
time=int(
datetime.datetime.strptime(
event.timestamp, "%Y-%m-%dT%H:%M:%S%z"
).timestamp()
),
) )
elif type(event) == botpy_message.GroupMessage: elif type(event) == botpy_message.GroupMessage:
replacing_member_id = save_member_openid(event.author.member_openid) replacing_member_id = self.member_openid_mapping.save_openid(event.author.member_openid)
return OfficialGroupMessage( return OfficialGroupMessage(
sender=mirai_entities.GroupMember( sender=mirai_entities.GroupMember(
@@ -249,29 +322,36 @@ class OfficialEventConverter(adapter_model.EventConverter):
member_name=replacing_member_id, member_name=replacing_member_id,
permission="MEMBER", permission="MEMBER",
group=mirai_entities.Group( group=mirai_entities.Group(
id=save_group_openid(event.group_openid), id=self.group_openid_mapping.save_openid(event.group_openid),
name=replacing_member_id, name=replacing_member_id,
permission=mirai_entities.Permission.Member permission=mirai_entities.Permission.Member,
), ),
special_title='', special_title="",
join_timestamp=int(0), join_timestamp=int(0),
last_speak_timestamp=datetime.datetime.now().timestamp(), last_speak_timestamp=datetime.datetime.now().timestamp(),
mute_time_remaining=0, mute_time_remaining=0,
), ),
message_chain=OfficialMessageConverter.extract_message_chain_from_obj(event, event.id), message_chain=OfficialMessageConverter.extract_message_chain_from_obj(
time=int(datetime.datetime.strptime(event.timestamp, "%Y-%m-%dT%H:%M:%S%z").timestamp()), event, event.id
),
time=int(
datetime.datetime.strptime(
event.timestamp, "%Y-%m-%dT%H:%M:%S%z"
).timestamp()
),
) )
@adapter_model.adapter_class("qq-botpy") @adapter_model.adapter_class("qq-botpy")
class OfficialAdapter(adapter_model.MessageSourceAdapter): class OfficialAdapter(adapter_model.MessageSourceAdapter):
"""QQ 官方消息适配器""" """QQ 官方消息适配器"""
bot: botpy.Client = None bot: botpy.Client = None
bot_account_id: int = 0 bot_account_id: int = 0
message_converter: OfficialMessageConverter = OfficialMessageConverter() message_converter: OfficialMessageConverter
# event_handler: adapter_model.EventHandler = adapter_model.EventHandler() event_converter: OfficialEventConverter
cfg: dict = None cfg: dict = None
@@ -283,6 +363,11 @@ class OfficialAdapter(adapter_model.MessageSourceAdapter):
ap: app.Application ap: app.Application
metadata: cfg_mgr.ConfigManager = None
member_openid_mapping: OpenIDMapping[str, int] = None
group_openid_mapping: OpenIDMapping[str, int] = None
def __init__(self, cfg: dict, ap: app.Application): def __init__(self, cfg: dict, ap: app.Application):
"""初始化适配器""" """初始化适配器"""
self.cfg = cfg self.cfg = cfg
@@ -290,86 +375,119 @@ class OfficialAdapter(adapter_model.MessageSourceAdapter):
switchs = {} switchs = {}
for intent in cfg['intents']: for intent in cfg["intents"]:
switchs[intent] = True switchs[intent] = True
del cfg['intents'] del cfg["intents"]
intents = botpy.Intents(**switchs) intents = botpy.Intents(**switchs)
self.bot = botpy.Client(intents=intents) self.bot = botpy.Client(intents=intents)
async def send_message( async def send_message(
self, self, target_type: str, target_id: str, message: mirai.MessageChain
target_type: str,
target_id: str,
message: mirai.MessageChain
): ):
pass message_list = self.message_converter.yiri2target(message)
for msg in message_list:
args = {}
if msg["type"] == "text":
args["content"] = msg["content"]
elif msg["type"] == "image":
args["image"] = msg["content"]
elif msg["type"] == "file_image":
args["file_image"] = msg["content"]
else:
continue
if target_type == "group":
args["channel_id"] = str(target_id)
await self.bot.api.post_message(**args)
elif target_type == "person":
args["guild_id"] = str(target_id)
await self.bot.api.post_dms(**args)
async def reply_message( async def reply_message(
self, self,
message_source: mirai.MessageEvent, message_source: mirai.MessageEvent,
message: mirai.MessageChain, message: mirai.MessageChain,
quote_origin: bool = False quote_origin: bool = False,
): ):
message_list = self.message_converter.yiri2target(message) message_list = self.message_converter.yiri2target(message)
tasks = []
msg_seq = 1 msg_seq = 1
for msg in message_list: for msg in message_list:
args = {} args = {}
if msg['type'] == 'text': if msg["type"] == "text":
args['content'] = msg['content'] args["content"] = msg["content"]
elif msg['type'] == 'image': elif msg["type"] == "image":
args['image'] = msg['content'] args["image"] = msg["content"]
elif msg['type'] == 'file_image': elif msg["type"] == "file_image":
args['file_image'] = msg["content"] args["file_image"] = msg["content"]
else: else:
continue continue
if quote_origin: if quote_origin:
args['message_reference'] = botpy_message_type.Reference(message_id=cached_message_ids[str(message_source.message_chain.message_id)]) args["message_reference"] = botpy_message_type.Reference(
message_id=cached_message_ids[
if type(message_source) == mirai.GroupMessage: str(message_source.message_chain.message_id)
args['channel_id'] = str(message_source.sender.group.id) ]
args['msg_id'] = cached_message_ids[str(message_source.message_chain.message_id)]
await self.bot.api.post_message(**args)
elif type(message_source) == mirai.FriendMessage:
args['guild_id'] = str(message_source.sender.id)
args['msg_id'] = cached_message_ids[str(message_source.message_chain.message_id)]
await self.bot.api.post_dms(**args)
elif type(message_source) == OfficialGroupMessage:
# args['guild_id'] = str(message_source.sender.group.id)
# args['msg_id'] = cached_message_ids[str(message_source.message_chain.message_id)]
# await self.bot.api.post_message(**args)
if 'image' in args or 'file_image' in args:
continue
args['group_openid'] = cached_group_openids[str(message_source.sender.group.id)]
args['msg_id'] = cached_message_ids[str(message_source.message_chain.message_id)]
args['msg_seq'] = msg_seq
msg_seq += 1
await self.bot.api.post_group_message(
**args
) )
if type(message_source) == mirai.GroupMessage:
args["channel_id"] = str(message_source.sender.group.id)
args["msg_id"] = cached_message_ids[
str(message_source.message_chain.message_id)
]
await self.bot.api.post_message(**args)
elif type(message_source) == mirai.FriendMessage:
args["guild_id"] = str(message_source.sender.id)
args["msg_id"] = cached_message_ids[
str(message_source.message_chain.message_id)
]
await self.bot.api.post_dms(**args)
elif type(message_source) == OfficialGroupMessage:
if "image" in args or "file_image" in args:
continue
args["group_openid"] = self.group_openid_mapping.getkey(
message_source.sender.group.id
)
args["msg_id"] = cached_message_ids[
str(message_source.message_chain.message_id)
]
args["msg_seq"] = msg_seq
msg_seq += 1
await self.bot.api.post_group_message(**args)
async def is_muted(self, group_id: int) -> bool: async def is_muted(self, group_id: int) -> bool:
return False return False
def register_listener( def register_listener(
self, self,
event_type: typing.Type[mirai.Event], event_type: typing.Type[mirai.Event],
callback: typing.Callable[[mirai.Event, adapter_model.MessageSourceAdapter], None] callback: typing.Callable[
[mirai.Event, adapter_model.MessageSourceAdapter], None
],
): ):
try: try:
async def wrapper(message: typing.Union[botpy_message.Message, botpy_message.DirectMessage, botpy_message.GroupMessage]): async def wrapper(
message: typing.Union[
botpy_message.Message,
botpy_message.DirectMessage,
botpy_message.GroupMessage,
]
):
self.cached_official_messages[str(message.id)] = message self.cached_official_messages[str(message.id)] = message
await callback(OfficialEventConverter.target2yiri(message), self) await callback(self.event_converter.target2yiri(message), self)
for event_handler in event_handler_mapping[event_type]: for event_handler in event_handler_mapping[event_type]:
setattr(self.bot, event_handler, wrapper) setattr(self.bot, event_handler, wrapper)
@@ -380,15 +498,33 @@ class OfficialAdapter(adapter_model.MessageSourceAdapter):
def unregister_listener( def unregister_listener(
self, self,
event_type: typing.Type[mirai.Event], event_type: typing.Type[mirai.Event],
callback: typing.Callable[[mirai.Event, adapter_model.MessageSourceAdapter], None] callback: typing.Callable[
[mirai.Event, adapter_model.MessageSourceAdapter], None
],
): ):
delattr(self.bot, event_handler_mapping[event_type]) delattr(self.bot, event_handler_mapping[event_type])
async def run_async(self): async def run_async(self):
self.ap.logger.info("运行 QQ 官方适配器")
await self.bot.start( self.metadata = self.ap.adapter_qq_botpy_meta
**self.cfg
self.member_openid_mapping = OpenIDMapping(
map=self.metadata.data["mapping"]["members"],
dump_func=self.metadata.dump_config_sync,
) )
self.group_openid_mapping = OpenIDMapping(
map=self.metadata.data["mapping"]["groups"],
dump_func=self.metadata.dump_config_sync,
)
self.message_converter = OfficialMessageConverter()
self.event_converter = OfficialEventConverter(
self.member_openid_mapping, self.group_openid_mapping
)
self.ap.logger.info("运行 QQ 官方适配器")
await self.bot.start(**self.cfg)
def kill(self) -> bool: def kill(self) -> bool:
return False return False

View File

@@ -9,10 +9,86 @@ from ..provider.tools import entities as tools_entities
from ..core import app from ..core import app
def register(
name: str,
description: str,
version: str,
author: str
) -> typing.Callable[[typing.Type[BasePlugin]], typing.Type[BasePlugin]]:
"""注册插件类
使用示例:
@register(
name="插件名称",
description="插件描述",
version="插件版本",
author="插件作者"
)
class MyPlugin(BasePlugin):
pass
"""
pass
def handler(
event: typing.Type[events.BaseEventModel]
) -> typing.Callable[[typing.Callable], typing.Callable]:
"""注册事件监听器
使用示例:
class MyPlugin(BasePlugin):
@handler(NormalMessageResponded)
async def on_normal_message_responded(self, ctx: EventContext):
pass
"""
pass
def llm_func(
name: str=None,
) -> typing.Callable:
"""注册内容函数
使用示例:
class MyPlugin(BasePlugin):
@llm_func("access_the_web_page")
async def _(self, query, url: str, brief_len: int):
\"""Call this function to search about the question before you answer any questions.
- Do not search through google.com at any time.
- If you need to search somthing, visit https://www.sogou.com/web?query=<something>.
- If user ask you to open a url (start with http:// or https://), visit it directly.
- Summary the plain content result by yourself, DO NOT directly output anything in the result you got.
Args:
url(str): url to visit
brief_len(int): max length of the plain text content, recommend 1024-4096, prefer 4096
Returns:
str: plain text content of the web page or error message(starts with 'error:')
\"""
"""
pass
class BasePlugin(metaclass=abc.ABCMeta): class BasePlugin(metaclass=abc.ABCMeta):
"""插件基类""" """插件基类"""
host: APIHost host: APIHost
"""API宿主"""
ap: app.Application
"""应用程序对象"""
def __init__(self, host: APIHost):
self.host = host
async def initialize(self):
"""初始化插件"""
pass
class APIHost: class APIHost:
@@ -61,8 +137,10 @@ class EventContext:
"""事件编号""" """事件编号"""
host: APIHost = None host: APIHost = None
"""API宿主"""
event: events.BaseEventModel = None event: events.BaseEventModel = None
"""此次事件的对象具体类型为handler注册时指定监听的类型可查看events.py中的定义"""
__prevent_default__ = False __prevent_default__ = False
"""是否阻止默认行为""" """是否阻止默认行为"""

View File

@@ -10,8 +10,10 @@ from ..provider import entities as llm_entities
class BaseEventModel(pydantic.BaseModel): class BaseEventModel(pydantic.BaseModel):
"""事件模型基类"""
query: typing.Union[core_entities.Query, None] query: typing.Union[core_entities.Query, None]
"""此次请求的query对象非请求过程的事件时为None"""
class Config: class Config:
arbitrary_types_allowed = True arbitrary_types_allowed = True

View File

@@ -1,3 +1,7 @@
# 此模块已过时
# 请从 pkg.plugin.context 引入 BasePlugin, EventContext 和 APIHost
# 最早将于 v3.4 移除此模块
from . events import * from . events import *
from . context import EventContext, APIHost as PluginHost from . context import EventContext, APIHost as PluginHost

View File

@@ -7,6 +7,7 @@ from ..core import app
class PluginInstaller(metaclass=abc.ABCMeta): class PluginInstaller(metaclass=abc.ABCMeta):
"""插件安装器抽象类"""
ap: app.Application ap: app.Application

View File

@@ -12,6 +12,8 @@ from ...utils import pkgmgr
class GitHubRepoInstaller(installer.PluginInstaller): class GitHubRepoInstaller(installer.PluginInstaller):
"""GitHub仓库插件安装器
"""
def get_github_plugin_repo_label(self, repo_url: str) -> list[str]: def get_github_plugin_repo_label(self, repo_url: str) -> list[str]:
"""获取username, repo""" """获取username, repo"""

View File

@@ -9,7 +9,7 @@ from . import context, events
class PluginLoader(metaclass=abc.ABCMeta): class PluginLoader(metaclass=abc.ABCMeta):
"""插件加载器""" """插件加载器抽象类"""
ap: app.Application ap: app.Application

View File

@@ -5,11 +5,10 @@ import pkgutil
import importlib import importlib
import traceback import traceback
from CallingGPT.entities.namespace import get_func_schema
from .. import loader, events, context, models, host from .. import loader, events, context, models, host
from ...core import entities as core_entities from ...core import entities as core_entities
from ...provider.tools import entities as tools_entities from ...provider.tools import entities as tools_entities
from ...utils import funcschema
class PluginLoader(loader.PluginLoader): class PluginLoader(loader.PluginLoader):
@@ -29,6 +28,10 @@ class PluginLoader(loader.PluginLoader):
setattr(models, 'on', self.on) setattr(models, 'on', self.on)
setattr(models, 'func', self.func) setattr(models, 'func', self.func)
setattr(context, 'register', self.register)
setattr(context, 'handler', self.handler)
setattr(context, 'llm_func', self.llm_func)
def register( def register(
self, self,
name: str, name: str,
@@ -57,6 +60,8 @@ class PluginLoader(loader.PluginLoader):
return wrapper return wrapper
# 过时
# 最早将于 v3.4 版本移除
def on( def on(
self, self,
event: typing.Type[events.BaseEventModel] event: typing.Type[events.BaseEventModel]
@@ -83,6 +88,8 @@ class PluginLoader(loader.PluginLoader):
return wrapper return wrapper
# 过时
# 最早将于 v3.4 版本移除
def func( def func(
self, self,
name: str=None, name: str=None,
@@ -91,10 +98,11 @@ class PluginLoader(loader.PluginLoader):
self.ap.logger.debug(f'注册内容函数 {name}') self.ap.logger.debug(f'注册内容函数 {name}')
def wrapper(func: typing.Callable) -> typing.Callable: def wrapper(func: typing.Callable) -> typing.Callable:
function_schema = get_func_schema(func) function_schema = funcschema.get_func_schema(func)
function_name = self._current_container.plugin_name + '-' + (func.__name__ if name is None else name) function_name = self._current_container.plugin_name + '-' + (func.__name__ if name is None else name)
async def handler( async def handler(
plugin: context.BasePlugin,
query: core_entities.Query, query: core_entities.Query,
*args, *args,
**kwargs **kwargs
@@ -116,6 +124,46 @@ class PluginLoader(loader.PluginLoader):
return wrapper return wrapper
def handler(
self,
event: typing.Type[events.BaseEventModel]
) -> typing.Callable[[typing.Callable], typing.Callable]:
"""注册事件处理器"""
self.ap.logger.debug(f'注册事件处理器 {event.__name__}')
def wrapper(func: typing.Callable) -> typing.Callable:
self._current_container.event_handlers[event] = func
return func
return wrapper
def llm_func(
self,
name: str=None,
) -> typing.Callable:
"""注册内容函数"""
self.ap.logger.debug(f'注册内容函数 {name}')
def wrapper(func: typing.Callable) -> typing.Callable:
function_schema = funcschema.get_func_schema(func)
function_name = self._current_container.plugin_name + '-' + (func.__name__ if name is None else name)
llm_function = tools_entities.LLMFunction(
name=function_name,
human_desc='',
description=function_schema['description'],
enable=True,
parameters=function_schema['parameters'],
func=func,
)
self._current_container.content_functions.append(llm_function)
return func
return wrapper
async def _walk_plugin_path( async def _walk_plugin_path(
self, self,
module, module,

View File

@@ -5,11 +5,12 @@ import traceback
from ..core import app from ..core import app
from . import context, loader, events, installer, setting, models from . import context, loader, events, installer, setting, models
from .loaders import legacy from .loaders import classic
from .installers import github from .installers import github
class PluginManager: class PluginManager:
"""插件管理器"""
ap: app.Application ap: app.Application
@@ -25,7 +26,7 @@ class PluginManager:
def __init__(self, ap: app.Application): def __init__(self, ap: app.Application):
self.ap = ap self.ap = ap
self.loader = legacy.PluginLoader(ap) self.loader = classic.PluginLoader(ap)
self.installer = github.GitHubRepoInstaller(ap) self.installer = github.GitHubRepoInstaller(ap)
self.setting = setting.SettingManager(ap) self.setting = setting.SettingManager(ap)
self.api_host = context.APIHost(ap) self.api_host = context.APIHost(ap)
@@ -51,6 +52,9 @@ class PluginManager:
for plugin in self.plugins: for plugin in self.plugins:
try: try:
plugin.plugin_inst = plugin.plugin_class(self.api_host) plugin.plugin_inst = plugin.plugin_class(self.api_host)
plugin.plugin_inst.ap = self.ap
plugin.plugin_inst.host = self.api_host
await plugin.plugin_inst.initialize()
except Exception as e: except Exception as e:
self.ap.logger.error(f'插件 {plugin.plugin_name} 初始化失败: {e}') self.ap.logger.error(f'插件 {plugin.plugin_name} 初始化失败: {e}')
self.ap.logger.exception(e) self.ap.logger.exception(e)
@@ -136,9 +140,8 @@ class PluginManager:
for plugin in self.plugins: for plugin in self.plugins:
if plugin.enabled: if plugin.enabled:
if event.__class__ in plugin.event_handlers: if event.__class__ in plugin.event_handlers:
self.ap.logger.debug(f'插件 {plugin.plugin_name} 触发事件 {event.__class__.__name__}')
emitted_plugins.append(plugin)
is_prevented_default_before_call = ctx.is_prevented_default() is_prevented_default_before_call = ctx.is_prevented_default()
try: try:
@@ -150,6 +153,8 @@ class PluginManager:
self.ap.logger.error(f'插件 {plugin.plugin_name} 触发事件 {event.__class__.__name__} 时发生错误: {e}') self.ap.logger.error(f'插件 {plugin.plugin_name} 触发事件 {event.__class__.__name__} 时发生错误: {e}')
self.ap.logger.debug(f"Traceback: {traceback.format_exc()}") self.ap.logger.debug(f"Traceback: {traceback.format_exc()}")
emitted_plugins.append(plugin)
if not is_prevented_default_before_call and ctx.is_prevented_default(): if not is_prevented_default_before_call and ctx.is_prevented_default():
self.ap.logger.debug(f'插件 {plugin.plugin_name} 阻止了默认行为执行') self.ap.logger.debug(f'插件 {plugin.plugin_name} 阻止了默认行为执行')

View File

@@ -1,3 +1,7 @@
# 此模块已过时,请引入 pkg.plugin.context 中的 register, handler 和 llm_func 来注册插件、事件处理函数和内容函数
# 各个事件模型请从 pkg.plugin.events 引入
# 最早将于 v3.4 移除此模块
from __future__ import annotations from __future__ import annotations
import typing import typing

View File

@@ -6,6 +6,7 @@ from . import context
class SettingManager: class SettingManager:
"""插件设置管理器"""
ap: app.Application ap: app.Application
@@ -15,10 +16,7 @@ class SettingManager:
self.ap = ap self.ap = ap
async def initialize(self): async def initialize(self):
self.settings = await cfg_mgr.load_json_config( self.settings = self.ap.plugin_setting_meta
'plugins/plugins.json',
'templates/plugin-settings.json'
)
async def sync_setting( async def sync_setting(
self, self,

View File

@@ -20,15 +20,22 @@ class ToolCall(pydantic.BaseModel):
class Message(pydantic.BaseModel): class Message(pydantic.BaseModel):
role: str # user, system, assistant, tool, command """消息"""
role: str # user, system, assistant, tool, command, plugin
"""消息的角色"""
name: typing.Optional[str] = None name: typing.Optional[str] = None
"""名称,仅函数调用返回时设置"""
content: typing.Optional[str] = None content: typing.Optional[str] = None
"""内容"""
function_call: typing.Optional[FunctionCall] = None function_call: typing.Optional[FunctionCall] = None
"""函数调用不再受支持请使用tool_calls"""
tool_calls: typing.Optional[list[ToolCall]] = None tool_calls: typing.Optional[list[ToolCall]] = None
"""工具调用"""
tool_call_id: typing.Optional[str] = None tool_call_id: typing.Optional[str] = None

View File

@@ -0,0 +1,52 @@
from __future__ import annotations
import abc
import typing
from ...core import app
from ...core import entities as core_entities
from .. import entities as llm_entities
preregistered_requesters: list[typing.Type[LLMAPIRequester]] = []
def requester_class(name: str):
def decorator(cls: typing.Type[LLMAPIRequester]) -> typing.Type[LLMAPIRequester]:
cls.name = name
preregistered_requesters.append(cls)
return cls
return decorator
class LLMAPIRequester(metaclass=abc.ABCMeta):
"""LLM API请求器
"""
name: str = None
ap: app.Application
def __init__(self, ap: app.Application):
self.ap = ap
async def initialize(self):
pass
@abc.abstractmethod
async def request(
self,
query: core_entities.Query,
) -> typing.AsyncGenerator[llm_entities.Message, None]:
"""请求API
对话前文可以从 query 对象中获取。
可以多次yield消息对象。
Args:
query (core_entities.Query): 本次请求的上下文对象
Yields:
pkg.provider.entities.Message: 返回消息对象
"""
raise NotImplementedError

View File

@@ -0,0 +1,82 @@
from __future__ import annotations
import typing
import traceback
import anthropic
from .. import api, entities, errors
from .. import api, entities, errors
from ....core import entities as core_entities
from ... import entities as llm_entities
from ...tools import entities as tools_entities
@api.requester_class("anthropic-messages")
class AnthropicMessages(api.LLMAPIRequester):
"""Anthropic Messages API 请求器"""
client: anthropic.AsyncAnthropic
async def initialize(self):
self.client = anthropic.AsyncAnthropic(
api_key="",
base_url=self.ap.provider_cfg.data['requester']['anthropic-messages']['base-url'],
timeout=self.ap.provider_cfg.data['requester']['anthropic-messages']['timeout'],
proxies=self.ap.proxy_mgr.get_forward_proxies()
)
async def request(
self,
query: core_entities.Query,
) -> typing.AsyncGenerator[llm_entities.Message, None]:
self.client.api_key = query.use_model.token_mgr.get_token()
args = self.ap.provider_cfg.data['requester']['anthropic-messages']['args'].copy()
args["model"] = query.use_model.name if query.use_model.model_name is None else query.use_model.model_name
req_messages = [ # req_messages 仅用于类内,外部同步由 query.messages 进行
m.dict(exclude_none=True) for m in query.prompt.messages
] + [m.dict(exclude_none=True) for m in query.messages]
# 删除所有 role=system & content='' 的消息
req_messages = [
m for m in req_messages if not (m["role"] == "system" and m["content"].strip() == "")
]
# 检查是否有 role=system 的消息,若有,改为 role=user并在后面加一个 role=assistant 的消息
system_role_index = []
for i, m in enumerate(req_messages):
if m["role"] == "system":
system_role_index.append(i)
m["role"] = "user"
if system_role_index:
for i in system_role_index[::-1]:
req_messages.insert(i + 1, {"role": "assistant", "content": "Okay, I'll follow."})
# 忽略掉空消息,用户可能发送空消息,而上层未过滤
req_messages = [
m for m in req_messages if m["content"].strip() != ""
]
args["messages"] = req_messages
try:
resp = await self.client.messages.create(**args)
yield llm_entities.Message(
content=resp.content[0].text,
role=resp.role
)
except anthropic.AuthenticationError as e:
raise errors.RequesterError(f'api-key 无效: {e.message}')
except anthropic.BadRequestError as e:
raise errors.RequesterError(str(e.message))
except anthropic.NotFoundError as e:
if 'model: ' in str(e):
raise errors.RequesterError(f'模型无效: {e.message}')
else:
raise errors.RequesterError(f'请求地址无效: {e.message}')

View File

@@ -9,22 +9,31 @@ import openai
import openai.types.chat.chat_completion as chat_completion import openai.types.chat.chat_completion as chat_completion
import httpx import httpx
from pkg.provider.entities import Message
from .. import api, entities, errors from .. import api, entities, errors
from ....core import entities as core_entities from ....core import entities as core_entities, app
from ... import entities as llm_entities from ... import entities as llm_entities
from ...tools import entities as tools_entities from ...tools import entities as tools_entities
class OpenAIChatCompletion(api.LLMAPIRequester): @api.requester_class("openai-chat-completions")
class OpenAIChatCompletions(api.LLMAPIRequester):
"""OpenAI ChatCompletion API 请求器"""
client: openai.AsyncClient client: openai.AsyncClient
requester_cfg: dict
def __init__(self, ap: app.Application):
self.ap = ap
self.requester_cfg = self.ap.provider_cfg.data['requester']['openai-chat-completions']
async def initialize(self): async def initialize(self):
self.client = openai.AsyncClient( self.client = openai.AsyncClient(
api_key="", api_key="",
base_url=self.ap.provider_cfg.data['openai-config']['base_url'], base_url=self.requester_cfg['base-url'],
timeout=self.ap.provider_cfg.data['openai-config']['request-timeout'], timeout=self.requester_cfg['timeout'],
http_client=httpx.AsyncClient( http_client=httpx.AsyncClient(
proxies=self.ap.proxy_mgr.get_forward_proxies() proxies=self.ap.proxy_mgr.get_forward_proxies()
) )
@@ -55,7 +64,7 @@ class OpenAIChatCompletion(api.LLMAPIRequester):
) -> llm_entities.Message: ) -> llm_entities.Message:
self.client.api_key = use_model.token_mgr.get_token() self.client.api_key = use_model.token_mgr.get_token()
args = self.ap.provider_cfg.data['openai-config']['chat-completions-params'].copy() args = self.requester_cfg['args'].copy()
args["model"] = use_model.name if use_model.model_name is None else use_model.model_name args["model"] = use_model.name if use_model.model_name is None else use_model.model_name
if use_model.tool_call_supported: if use_model.tool_call_supported:
@@ -124,14 +133,17 @@ class OpenAIChatCompletion(api.LLMAPIRequester):
req_messages.append(msg.dict(exclude_none=True)) req_messages.append(msg.dict(exclude_none=True))
async def request(self, query: core_entities.Query) -> AsyncGenerator[Message, None]: async def request(self, query: core_entities.Query) -> AsyncGenerator[llm_entities.Message, None]:
try: try:
async for msg in self._request(query): async for msg in self._request(query):
yield msg yield msg
except asyncio.TimeoutError: except asyncio.TimeoutError:
raise errors.RequesterError('请求超时') raise errors.RequesterError('请求超时')
except openai.BadRequestError as e: except openai.BadRequestError as e:
raise errors.RequesterError(f'请求错误: {e.message}') if 'context_length_exceeded' in e.message:
raise errors.RequesterError(f'上文过长,请重置会话: {e.message}')
else:
raise errors.RequesterError(f'请求参数错误: {e.message}')
except openai.AuthenticationError as e: except openai.AuthenticationError as e:
raise errors.RequesterError(f'无效的 api-key: {e.message}') raise errors.RequesterError(f'无效的 api-key: {e.message}')
except openai.NotFoundError as e: except openai.NotFoundError as e:

View File

@@ -0,0 +1,15 @@
from __future__ import annotations
from ....core import app
from . import chatcmpl
from .. import api
@api.requester_class("moonshot-chat-completions")
class MoonshotChatCompletions(chatcmpl.OpenAIChatCompletions):
"""Moonshot ChatCompletion API 请求器"""
def __init__(self, ap: app.Application):
self.requester_cfg = ap.provider_cfg.data['requester']['moonshot-chat-completions']
self.ap = ap

View File

@@ -5,7 +5,7 @@ import typing
import pydantic import pydantic
from . import api from . import api
from . import token, tokenizer from . import token
class LLMModelInfo(pydantic.BaseModel): class LLMModelInfo(pydantic.BaseModel):
@@ -19,11 +19,7 @@ class LLMModelInfo(pydantic.BaseModel):
requester: api.LLMAPIRequester requester: api.LLMAPIRequester
tokenizer: 'tokenizer.LLMTokenizer'
tool_call_supported: typing.Optional[bool] = False tool_call_supported: typing.Optional[bool] = False
max_tokens: typing.Optional[int] = 2048
class Config: class Config:
arbitrary_types_allowed = True arbitrary_types_allowed = True

View File

@@ -2,4 +2,4 @@ class RequesterError(Exception):
"""Base class for all Requester errors.""" """Base class for all Requester errors."""
def __init__(self, message: str): def __init__(self, message: str):
super().__init__("模型请求失败: "+message) super().__init__("模型请求失败: "+message)

View File

@@ -0,0 +1,109 @@
from __future__ import annotations
import aiohttp
from . import entities
from ...core import app
from . import token, api
from .apis import chatcmpl, anthropicmsgs, moonshotchatcmpl
FETCH_MODEL_LIST_URL = "https://api.qchatgpt.rockchin.top/api/v2/fetch/model_list"
class ModelManager:
"""模型管理器"""
ap: app.Application
model_list: list[entities.LLMModelInfo]
requesters: dict[str, api.LLMAPIRequester]
token_mgrs: dict[str, token.TokenManager]
def __init__(self, ap: app.Application):
self.ap = ap
self.model_list = []
self.requesters = {}
self.token_mgrs = {}
async def get_model_by_name(self, name: str) -> entities.LLMModelInfo:
"""通过名称获取模型
"""
for model in self.model_list:
if model.name == name:
return model
raise ValueError(f"无法确定模型 {name} 的信息,请在元数据中配置")
async def initialize(self):
# 初始化token_mgr, requester
for k, v in self.ap.provider_cfg.data['keys'].items():
self.token_mgrs[k] = token.TokenManager(k, v)
for api_cls in api.preregistered_requesters:
api_inst = api_cls(self.ap)
await api_inst.initialize()
self.requesters[api_inst.name] = api_inst
# 尝试从api获取最新的模型信息
try:
async with aiohttp.ClientSession() as session:
async with session.request(
method="GET",
url=FETCH_MODEL_LIST_URL,
# 参数
params={
"version": self.ap.ver_mgr.get_current_version()
},
) as resp:
model_list = (await resp.json())['data']['list']
for model in model_list:
for index, local_model in enumerate(self.ap.llm_models_meta.data['list']):
if model['name'] == local_model['name']:
self.ap.llm_models_meta.data['list'][index] = model
break
else:
self.ap.llm_models_meta.data['list'].append(model)
await self.ap.llm_models_meta.dump_config()
except Exception as e:
self.ap.logger.debug(f'获取最新模型列表失败: {e}')
default_model_info: entities.LLMModelInfo = None
for model in self.ap.llm_models_meta.data['list']:
if model['name'] == 'default':
default_model_info = entities.LLMModelInfo(
name=model['name'],
model_name=None,
token_mgr=self.token_mgrs[model['token_mgr']],
requester=self.requesters[model['requester']],
tool_call_supported=model['tool_call_supported']
)
break
for model in self.ap.llm_models_meta.data['list']:
try:
model_name = model.get('model_name', default_model_info.model_name)
token_mgr = self.token_mgrs[model['token_mgr']] if 'token_mgr' in model else default_model_info.token_mgr
requester = self.requesters[model['requester']] if 'requester' in model else default_model_info.requester
tool_call_supported = model.get('tool_call_supported', default_model_info.tool_call_supported)
model_info = entities.LLMModelInfo(
name=model['name'],
model_name=model_name,
token_mgr=token_mgr,
requester=requester,
tool_call_supported=tool_call_supported
)
self.model_list.append(model_info)
except Exception as e:
self.ap.logger.error(f"初始化模型 {model['name']} 失败: {e} ,请检查配置文件")

View File

@@ -6,6 +6,8 @@ import pydantic
class TokenManager(): class TokenManager():
"""鉴权 Token 管理器
"""
provider: str provider: str

View File

@@ -1,29 +0,0 @@
from __future__ import annotations
import abc
import typing
from ...core import app
from ...core import entities as core_entities
from .. import entities as llm_entities
class LLMAPIRequester(metaclass=abc.ABCMeta):
"""LLM API请求器
"""
ap: app.Application
def __init__(self, ap: app.Application):
self.ap = ap
async def initialize(self):
pass
@abc.abstractmethod
async def request(
self,
query: core_entities.Query,
) -> typing.AsyncGenerator[llm_entities.Message, None]:
"""请求
"""
raise NotImplementedError

View File

@@ -1,242 +0,0 @@
from __future__ import annotations
from . import entities
from ...core import app
from .apis import chatcmpl
from . import token
from .tokenizers import tiktoken
class ModelManager:
ap: app.Application
model_list: list[entities.LLMModelInfo]
def __init__(self, ap: app.Application):
self.ap = ap
self.model_list = []
async def get_model_by_name(self, name: str) -> entities.LLMModelInfo:
"""通过名称获取模型
"""
for model in self.model_list:
if model.name == name:
return model
raise ValueError(f"不支持模型: {name} , 请检查配置文件")
async def initialize(self):
openai_chat_completion = chatcmpl.OpenAIChatCompletion(self.ap)
await openai_chat_completion.initialize()
openai_token_mgr = token.TokenManager(self.ap, list(self.ap.provider_cfg.data['openai-config']['api-keys']))
tiktoken_tokenizer = tiktoken.Tiktoken(self.ap)
model_list = [
entities.LLMModelInfo(
name="gpt-3.5-turbo",
token_mgr=openai_token_mgr,
requester=openai_chat_completion,
tool_call_supported=True,
tokenizer=tiktoken_tokenizer,
max_tokens=4096
),
entities.LLMModelInfo(
name="gpt-3.5-turbo-1106",
token_mgr=openai_token_mgr,
requester=openai_chat_completion,
tool_call_supported=True,
tokenizer=tiktoken_tokenizer,
max_tokens=16385
),
entities.LLMModelInfo(
name="gpt-3.5-turbo-16k",
token_mgr=openai_token_mgr,
requester=openai_chat_completion,
tool_call_supported=True,
tokenizer=tiktoken_tokenizer,
max_tokens=16385
),
entities.LLMModelInfo(
name="gpt-3.5-turbo-0613",
token_mgr=openai_token_mgr,
requester=openai_chat_completion,
tool_call_supported=True,
tokenizer=tiktoken_tokenizer,
max_tokens=4096
),
entities.LLMModelInfo(
name="gpt-3.5-turbo-16k-0613",
token_mgr=openai_token_mgr,
requester=openai_chat_completion,
tool_call_supported=True,
tokenizer=tiktoken_tokenizer,
max_tokens=16385
),
entities.LLMModelInfo(
name="gpt-3.5-turbo-0301",
token_mgr=openai_token_mgr,
requester=openai_chat_completion,
tool_call_supported=True,
tokenizer=tiktoken_tokenizer,
max_tokens=4096
)
]
self.model_list.extend(model_list)
gpt4_model_list = [
entities.LLMModelInfo(
name="gpt-4-0125-preview",
token_mgr=openai_token_mgr,
requester=openai_chat_completion,
tool_call_supported=True,
tokenizer=tiktoken_tokenizer,
max_tokens=128000
),
entities.LLMModelInfo(
name="gpt-4-turbo-preview",
token_mgr=openai_token_mgr,
requester=openai_chat_completion,
tool_call_supported=True,
tokenizer=tiktoken_tokenizer,
max_tokens=128000
),
entities.LLMModelInfo(
name="gpt-4-1106-preview",
token_mgr=openai_token_mgr,
requester=openai_chat_completion,
tool_call_supported=True,
tokenizer=tiktoken_tokenizer,
max_tokens=128000
),
entities.LLMModelInfo(
name="gpt-4-vision-preview",
token_mgr=openai_token_mgr,
requester=openai_chat_completion,
tool_call_supported=True,
tokenizer=tiktoken_tokenizer,
max_tokens=128000
),
entities.LLMModelInfo(
name="gpt-4",
token_mgr=openai_token_mgr,
requester=openai_chat_completion,
tool_call_supported=True,
tokenizer=tiktoken_tokenizer,
max_tokens=8192
),
entities.LLMModelInfo(
name="gpt-4-0613",
token_mgr=openai_token_mgr,
requester=openai_chat_completion,
tool_call_supported=True,
tokenizer=tiktoken_tokenizer,
max_tokens=8192
),
entities.LLMModelInfo(
name="gpt-4-32k",
token_mgr=openai_token_mgr,
requester=openai_chat_completion,
tool_call_supported=True,
tokenizer=tiktoken_tokenizer,
max_tokens=32768
),
entities.LLMModelInfo(
name="gpt-4-32k-0613",
token_mgr=openai_token_mgr,
requester=openai_chat_completion,
tool_call_supported=True,
tokenizer=tiktoken_tokenizer,
max_tokens=32768
)
]
self.model_list.extend(gpt4_model_list)
one_api_model_list = [
entities.LLMModelInfo(
name="OneAPI/SparkDesk",
model_name='SparkDesk',
token_mgr=openai_token_mgr,
requester=openai_chat_completion,
tool_call_supported=False,
tokenizer=tiktoken_tokenizer,
max_tokens=8192
),
entities.LLMModelInfo(
name="OneAPI/chatglm_pro",
model_name='chatglm_pro',
token_mgr=openai_token_mgr,
requester=openai_chat_completion,
tool_call_supported=False,
tokenizer=tiktoken_tokenizer,
max_tokens=128000
),
entities.LLMModelInfo(
name="OneAPI/chatglm_std",
model_name='chatglm_std',
token_mgr=openai_token_mgr,
requester=openai_chat_completion,
tool_call_supported=False,
tokenizer=tiktoken_tokenizer,
max_tokens=128000
),
entities.LLMModelInfo(
name="OneAPI/chatglm_lite",
model_name='chatglm_lite',
token_mgr=openai_token_mgr,
requester=openai_chat_completion,
tool_call_supported=False,
tokenizer=tiktoken_tokenizer,
max_tokens=128000
),
entities.LLMModelInfo(
name="OneAPI/qwen-v1",
model_name='qwen-v1',
token_mgr=openai_token_mgr,
requester=openai_chat_completion,
tool_call_supported=False,
tokenizer=tiktoken_tokenizer,
max_tokens=6000
),
entities.LLMModelInfo(
name="OneAPI/qwen-plus-v1",
model_name='qwen-plus-v1',
token_mgr=openai_token_mgr,
requester=openai_chat_completion,
tool_call_supported=False,
tokenizer=tiktoken_tokenizer,
max_tokens=30000
),
entities.LLMModelInfo(
name="OneAPI/ERNIE-Bot",
model_name='ERNIE-Bot',
token_mgr=openai_token_mgr,
requester=openai_chat_completion,
tool_call_supported=False,
tokenizer=tiktoken_tokenizer,
max_tokens=2000
),
entities.LLMModelInfo(
name="OneAPI/ERNIE-Bot-turbo",
model_name='ERNIE-Bot-turbo',
token_mgr=openai_token_mgr,
requester=openai_chat_completion,
tool_call_supported=False,
tokenizer=tiktoken_tokenizer,
max_tokens=7000
),
entities.LLMModelInfo(
name="OneAPI/gemini-pro",
model_name='gemini-pro',
token_mgr=openai_token_mgr,
requester=openai_chat_completion,
tool_call_supported=False,
tokenizer=tiktoken_tokenizer,
max_tokens=30720
),
]
self.model_list.extend(one_api_model_list)

View File

@@ -1,29 +0,0 @@
from __future__ import annotations
import abc
import typing
from ...core import app
from .. import entities as llm_entities
from . import entities
class LLMTokenizer(metaclass=abc.ABCMeta):
ap: app.Application
def __init__(self, ap: app.Application):
self.ap = ap
async def initialize(self):
"""初始化分词器
"""
pass
@abc.abstractmethod
async def count_token(
self,
messages: list[llm_entities.Message],
model: entities.LLMModelInfo
) -> int:
pass

View File

@@ -1,28 +0,0 @@
from __future__ import annotations
import tiktoken
from .. import tokenizer
from ... import entities as llm_entities
from .. import entities
class Tiktoken(tokenizer.LLMTokenizer):
async def count_token(
self,
messages: list[llm_entities.Message],
model: entities.LLMModelInfo
) -> int:
try:
encoding = tiktoken.encoding_for_model(model.name)
except KeyError:
# print("Warning: model not found. Using cl100k_base encoding.")
encoding = tiktoken.get_encoding("cl100k_base")
num_tokens = 0
for message in messages:
num_tokens += len(encoding.encode(message.role))
num_tokens += len(encoding.encode(message.content if message.content is not None else ''))
num_tokens += 3 # every reply is primed with <|start|>assistant<|message|>
return num_tokens

View File

@@ -6,6 +6,8 @@ from ...core import app, entities as core_entities
class SessionManager: class SessionManager:
"""会话管理器
"""
ap: app.Application ap: app.Application
@@ -39,6 +41,8 @@ class SessionManager:
return session return session
async def get_conversation(self, session: core_entities.Session) -> core_entities.Conversation: async def get_conversation(self, session: core_entities.Session) -> core_entities.Conversation:
"""获取对话或创建对话"""
if not session.conversations: if not session.conversations:
session.conversations = [] session.conversations = []
@@ -46,7 +50,7 @@ class SessionManager:
conversation = core_entities.Conversation( conversation = core_entities.Conversation(
prompt=await self.ap.prompt_mgr.get_prompt(session.use_prompt_name), prompt=await self.ap.prompt_mgr.get_prompt(session.use_prompt_name),
messages=[], messages=[],
use_model=await self.ap.model_mgr.get_model_by_name(self.ap.provider_cfg.data['openai-config']['chat-completions-params']['model']), use_model=await self.ap.model_mgr.get_model_by_name(self.ap.provider_cfg.data['model']),
use_funcs=await self.ap.tool_mgr.get_all_functions(), use_funcs=await self.ap.tool_mgr.get_all_functions(),
) )
session.conversations.append(conversation) session.conversations.append(conversation)

View File

@@ -10,5 +10,7 @@ class Prompt(pydantic.BaseModel):
"""供AI使用的Prompt""" """供AI使用的Prompt"""
name: str name: str
"""名称"""
messages: list[entities.Message] messages: list[entities.Message]
"""消息列表"""

View File

@@ -1,13 +1,27 @@
from __future__ import annotations from __future__ import annotations
import abc import abc
import typing
from ...core import app from ...core import app
from . import entities from . import entities
preregistered_loaders: list[typing.Type[PromptLoader]] = []
def loader_class(name: str):
def decorator(cls: typing.Type[PromptLoader]) -> typing.Type[PromptLoader]:
cls.name = name
preregistered_loaders.append(cls)
return cls
return decorator
class PromptLoader(metaclass=abc.ABCMeta): class PromptLoader(metaclass=abc.ABCMeta):
"""Prompt加载器抽象类 """Prompt加载器抽象类
""" """
name: str
ap: app.Application ap: app.Application
@@ -22,7 +36,7 @@ class PromptLoader(metaclass=abc.ABCMeta):
@abc.abstractmethod @abc.abstractmethod
async def load(self): async def load(self):
"""加载Prompt """加载Prompt存放到prompts列表中
""" """
raise NotImplementedError raise NotImplementedError

View File

@@ -8,14 +8,15 @@ from .. import entities
from ....provider import entities as llm_entities from ....provider import entities as llm_entities
@loader.loader_class("full-scenario")
class ScenarioPromptLoader(loader.PromptLoader): class ScenarioPromptLoader(loader.PromptLoader):
"""加载scenario目录下的json""" """加载scenario目录下的json"""
async def load(self): async def load(self):
"""加载Prompt """加载Prompt
""" """
for file in os.listdir("data/scenarios"): for file in os.listdir("data/scenario"):
with open("data/scenarios/{}".format(file), "r", encoding="utf-8") as f: with open("data/scenario/{}".format(file), "r", encoding="utf-8") as f:
file_str = f.read() file_str = f.read()
file_name = file.split(".")[0] file_name = file.split(".")[0]
file_json = json.loads(file_str) file_json = json.loads(file_str)

View File

@@ -6,6 +6,7 @@ from .. import entities
from ....provider import entities as llm_entities from ....provider import entities as llm_entities
@loader.loader_class("normal")
class SingleSystemPromptLoader(loader.PromptLoader): class SingleSystemPromptLoader(loader.PromptLoader):
"""配置文件中的单条system prompt的prompt加载器 """配置文件中的单条system prompt的prompt加载器
""" """

View File

@@ -6,6 +6,8 @@ from .loaders import single, scenario
class PromptManager: class PromptManager:
"""Prompt管理器
"""
ap: app.Application ap: app.Application
@@ -18,14 +20,18 @@ class PromptManager:
async def initialize(self): async def initialize(self):
loader_map = { mode_name = self.ap.provider_cfg.data['prompt-mode']
"normal": single.SingleSystemPromptLoader,
"full_scenario": scenario.ScenarioPromptLoader
}
loader_cls = loader_map[self.ap.provider_cfg.data['prompt-mode']] loader_class = None
self.loader_inst: loader.PromptLoader = loader_cls(self.ap) for loader_cls in loader.preregistered_loaders:
if loader_cls.name == mode_name:
loader_class = loader_cls
break
else:
raise ValueError(f'未知的 Prompt 加载器: {mode_name}')
self.loader_inst: loader.PromptLoader = loader_class(self.ap)
await self.loader_inst.initialize() await self.loader_inst.initialize()
await self.loader_inst.load() await self.loader_inst.load()

View File

@@ -5,6 +5,7 @@ import traceback
from ...core import app, entities as core_entities from ...core import app, entities as core_entities
from . import entities from . import entities
from ...plugin import context as plugin_context
class ToolManager: class ToolManager:
@@ -28,6 +29,15 @@ class ToolManager:
return function return function
return None return None
async def get_function_and_plugin(self, name: str) -> typing.Tuple[entities.LLMFunction, plugin_context.BasePlugin]:
"""获取函数和插件
"""
for plugin in self.ap.plugin_mgr.plugins:
for function in plugin.content_functions:
if function.name == name:
return function, plugin
return None, None
async def get_all_functions(self) -> list[entities.LLMFunction]: async def get_all_functions(self) -> list[entities.LLMFunction]:
"""获取所有函数 """获取所有函数
""" """
@@ -68,7 +78,7 @@ class ToolManager:
try: try:
function = await self.get_function(name) function, plugin = await self.get_function_and_plugin(name)
if function is None: if function is None:
return None return None
@@ -79,7 +89,7 @@ class ToolManager:
**parameters **parameters
} }
return await function.func(**parameters) return await function.func(plugin, **parameters)
except Exception as e: except Exception as e:
self.ap.logger.error(f'执行函数 {name} 时发生错误: {e}') self.ap.logger.error(f'执行函数 {name} 时发生错误: {e}')
traceback.print_exc() traceback.print_exc()

File diff suppressed because one or more lines are too long

116
pkg/utils/funcschema.py Normal file
View File

@@ -0,0 +1,116 @@
import sys
import re
import inspect
def get_func_schema(function: callable) -> dict:
"""
Return the data schema of a function.
{
"function": function,
"description": "function description",
"parameters": {
"type": "object",
"properties": {
"parameter_a": {
"type": "str",
"description": "parameter_a description"
},
"parameter_b": {
"type": "int",
"description": "parameter_b description"
},
"parameter_c": {
"type": "str",
"description": "parameter_c description",
"enum": ["a", "b", "c"]
},
},
"required": ["parameter_a", "parameter_b"]
}
}
"""
func_doc = function.__doc__
# Google Style Docstring
if func_doc is None:
raise Exception("Function {} has no docstring.".format(function.__name__))
func_doc = func_doc.strip().replace(' ','').replace('\t', '')
# extract doc of args from docstring
doc_spt = func_doc.split('\n\n')
desc = doc_spt[0]
args = doc_spt[1] if len(doc_spt) > 1 else ""
returns = doc_spt[2] if len(doc_spt) > 2 else ""
# extract args
# delete the first line of args
arg_lines = args.split('\n')[1:]
arg_doc_list = re.findall(r'(\w+)(\((\w+)\))?:\s*(.*)', args)
args_doc = {}
for arg_line in arg_lines:
doc_tuple = re.findall(r'(\w+)(\(([\w\[\]]+)\))?:\s*(.*)', arg_line)
if len(doc_tuple) == 0:
continue
args_doc[doc_tuple[0][0]] = doc_tuple[0][3]
# extract returns
return_doc_list = re.findall(r'(\w+):\s*(.*)', returns)
params = enumerate(inspect.signature(function).parameters.values())
parameters = {
"type": "object",
"required": [],
"properties": {},
}
for i, param in params:
# 排除 self, query
if param.name in ['self', 'query']:
continue
param_type = param.annotation.__name__
type_name_mapping = {
"str": "string",
"int": "integer",
"float": "number",
"bool": "boolean",
"list": "array",
"dict": "object",
}
if param_type in type_name_mapping:
param_type = type_name_mapping[param_type]
parameters['properties'][param.name] = {
"type": param_type,
"description": args_doc[param.name],
}
# add schema for array
if param_type == "array":
# extract type of array, the int of list[int]
# use re
array_type_tuple = re.findall(r'list\[(\w+)\]', str(param.annotation))
array_type = 'string'
if len(array_type_tuple) > 0:
array_type = array_type_tuple[0]
if array_type in type_name_mapping:
array_type = type_name_mapping[array_type]
parameters['properties'][param.name]["items"] = {
"type": array_type,
}
if param.default is inspect.Parameter.empty:
parameters["required"].append(param.name)
return {
"function": function,
"description": desc,
"parameters": parameters,
}

View File

@@ -4,4 +4,8 @@ import sys
def get_platform() -> str: def get_platform() -> str:
"""获取当前平台""" """获取当前平台"""
# 检查是不是在 docker 里
if os.path.exists('/.dockerenv'):
return 'docker'
return sys.platform return sys.platform

View File

@@ -7,6 +7,9 @@ from ..core import app
class ProxyManager: class ProxyManager:
"""代理管理器
"""
ap: app.Application ap: app.Application
forward_proxies: dict[str, str] forward_proxies: dict[str, str]

View File

@@ -10,6 +10,8 @@ from . import constants
class VersionManager: class VersionManager:
"""版本管理器
"""
ap: app.Application ap: app.Application

View File

@@ -1,12 +1,12 @@
requests requests
openai>1.0.0 openai>1.0.0
anthropic
colorlog~=6.6.0 colorlog~=6.6.0
yiri-mirai-rc yiri-mirai-rc
aiocqhttp aiocqhttp
qq-botpy qq-botpy
nakuru-project-idk nakuru-project-idk
Pillow Pillow
CallingGPT
tiktoken tiktoken
PyYaml PyYaml
aiohttp aiohttp

View File

@@ -1,26 +1,8 @@
[ [
{ {
"id": 2, "id": 6,
"time": "2023-08-01 10:49:26", "time": "2024-03-08 22:30:00",
"timestamp": 1690858166, "timestamp": 1709908200,
"content": "现已支持GPT函数调用功能欢迎了解https://github.com/RockChinQ/QChatGPT/wiki/%E6%8F%92%E4%BB%B6%E4%BD%BF%E7%94%A8-%E5%86%85%E5%AE%B9%E5%87%BD%E6%95%B0" "content": "QChatGPT 3.x 已发布,若您仍在使用不再维护的 2.x 版本,请尽快迁移至 3.x 版本https://github.com/RockChinQ/QChatGPT/discussions/690"
},
{
"id": 3,
"time": "2023-11-10 12:20:09",
"timestamp": 1699590009,
"content": "OpenAI 库1.0版本已发行,若出现 OpenAI 调用问题,请更新 QChatGPT 版本。详见项目主页https://github.com/RockChinQ/QChatGPT"
},
{
"id": 4,
"time": "2023-11-13 18:02:39",
"timestamp": 1699869759,
"content": "近期 OpenAI 接口改动频繁正在积极适配并添加新功能请尽快更新到最新版本更新方式https://github.com/RockChinQ/QChatGPT/discussions/595"
},
{
"id": 5,
"time": "2023-12-07 9:20:00",
"timestamp": 1701912000,
"content": "QChatGPT 一周年啦感谢大家的选择和支持RockChinQ 在此衷心感谢素未谋面但又至关重要的你们每一个人,愿 AI 与我们同在欢迎前往https://github.com/RockChinQ/QChatGPT/discussions/627 参与讨论。"
} }
] ]

Some files were not shown because too many files have changed in this diff Show More