Compare commits

...

121 Commits

Author SHA1 Message Date
Junyan Qin
0094056def chore: release v3.4.0.1 2024-11-22 23:55:24 +08:00
Junyan Qin
a9f305a1c6 feat: 添加对pydantic v1的兼容性 2024-11-22 23:37:46 +08:00
Junyan Qin
e8cc048901 deps: bump pydantic to v2 2024-11-22 23:29:12 +08:00
Junyan Qin
05da43f606 chore: 更新模型信息 2024-11-22 22:33:05 +08:00
Junyan Qin
a81faa7d8e fix: gitee ai 配置 schema 2024-11-22 20:01:35 +08:00
Junyan Qin
18ba7d1da7 Merge pull request #929 from RockChinQ/feat/gitee-ai
feat: 添加对 Gitee AI 的支持
2024-11-22 19:59:25 +08:00
Junyan Qin
875adfcbaa feat: 添加对 Gitee AI 的支持 2024-11-21 23:28:19 +08:00
Junyan Qin
6e9c213893 fix: 登录失败时无提示 2024-11-20 21:03:02 +08:00
Junyan Qin
753066ccb9 fix: webui 访问提示在Windows上的编码问题 2024-11-19 19:33:58 +08:00
Junyan Qin
8b36782c25 chore: 更新 docker-compose.yaml 2024-11-18 23:31:49 +08:00
Junyan Qin
da9dde6bd2 doc: update README 2024-11-17 21:30:53 +08:00
Junyan Qin
07f6e69b93 doc: update README.md 2024-11-17 21:11:21 +08:00
Junyan Qin
31a7503df3 chore: release v3.4.0 2024-11-17 20:48:45 +08:00
Junyan Qin
11db8d8d17 Merge pull request #904 from RockChinQ/version/3.4.0
Version/3.4.0
2024-11-17 20:47:25 +08:00
Junyan Qin
93ee8d51bc Merge branch 'master' into version/3.4.0 2024-11-17 20:45:24 +08:00
Junyan Qin
83e80f324e perf: webui 文件存在性检查 2024-11-17 20:43:40 +08:00
Junyan Qin
c51eac717e doc: 修复 README 的死链 2024-11-17 20:37:42 +08:00
Junyan Qin
db7d5dcce3 chore: 替换多处 qchatgpt.rockchin.top 2024-11-17 20:35:39 +08:00
Junyan Qin
0d25578e22 perf: 配置文件页放到单独组件 2024-11-17 19:54:07 +08:00
Junyan Qin
1a457be823 Merge pull request #921 from RockChinQ/feat/authenticating
Feat: 用户鉴权
2024-11-17 19:13:11 +08:00
Junyan Qin
20e3edba8f feat: 用户账户系统 2024-11-17 19:11:44 +08:00
Junyan Qin
036c2182a5 chore: 修改aiocqhttp适配器默认端口为2280 2024-11-17 10:18:56 +08:00
Junyan Qin
6238f430e8 Merge pull request #900 from RockChinQ/feat/webui
Feat: webui
2024-11-16 19:27:35 +08:00
Junyan Qin
9fc891ec01 chore: Hello LangBot ! 2024-11-16 17:57:39 +08:00
Junyan Qin
491d977d9e ci: fix 2024-11-16 17:47:50 +08:00
Junyan Qin
9a4bcda9bc ci: 添加 build-artifacts 工作流在 release 分布时执行 2024-11-16 17:44:02 +08:00
Junyan Qin
2c2374a763 ci: fix 2024-11-16 17:39:34 +08:00
Junyan Qin
a76e0b287e ci: typo 2024-11-16 17:36:33 +08:00
Junyan Qin
1d6f1e3c7c fix: chore 2024-11-16 17:34:15 +08:00
Junyan Qin
896fd982a1 ci: upload artifacts 2024-11-16 17:33:24 +08:00
Junyan Qin
c031ab20da Merge pull request #920 from RockChinQ/feat/lifetime-controlling
Feat: 生命周期和热重载
2024-11-16 17:19:42 +08:00
Junyan Qin
318b6e6bf1 typo: minor fix 2024-11-16 17:16:40 +08:00
Junyan Qin
ca3999d251 feat: 插件文件更改热重载 2024-11-16 16:45:13 +08:00
Junyan Qin
658eb278c4 refactor: 重构部分插件管理逻辑 2024-11-16 16:13:02 +08:00
Junyan Qin
bb219889e5 feat: 消息平台热重载 2024-11-16 12:40:57 +08:00
Junyan Qin
3239c9ec3f feat: 彻底移除 yirimirai 2024-11-15 20:03:49 +08:00
Junyan Qin
16153dc573 perf: 设置页标题改为小写 2024-11-12 20:04:03 +08:00
Junyan Qin
e0d9a295ab perf: 优化部分 UI 显示 2024-11-12 18:57:43 +08:00
Junyan Qin
eabdda5eb1 feat: 改为 WebHashHistory 2024-11-12 18:29:37 +08:00
Junyan Qin
43f45f9184 feat: 修改 apibase 2024-11-12 18:14:53 +08:00
Junyan Qin
7c19785a17 fix: http_proxy 环境变量为空检查 2024-11-12 17:56:59 +08:00
Junyan Qin
78005f8b4e ci: 删除 refs-heads 2024-11-12 17:49:00 +08:00
Junyan Qin
0d4784d098 feat: 代理设置同步到环境变量 2024-11-11 19:12:30 +08:00
Junyan Qin
805454e037 chore: 更新 docker-compose.yaml 2024-11-10 16:37:44 +08:00
Junyan Qin
bf383bbf9c ci: fix 2024-11-10 16:29:40 +08:00
Junyan Qin
73ffd67792 ci: 添加构建 ci 2024-11-10 16:27:50 +08:00
Junyan Qin
54bbfc8eda perf: dashboard 添加图表更新提示 2024-11-10 15:38:06 +08:00
Junyan Qin
a3e234c979 perf: debug模式改为绿色 2024-11-10 12:03:34 +08:00
Junyan Qin
9336abff8b perf: 使用 json-editor-vue 作为json编辑器 2024-11-10 11:46:41 +08:00
Junyan Qin
0fe161cd7f pref: 优化日志显示 2024-11-10 11:04:29 +08:00
Junyan Qin
7cc55eab3e feat: 仪表盘基本数据 2024-11-10 00:05:47 +08:00
Junyan Qin
15482e398b feat: 插件删除功能 2024-11-09 23:25:26 +08:00
Junyan Qin
601fa0ac7f feat: 关于 LangBot 对话框 2024-11-09 22:44:56 +08:00
Junyan Qin
2819da5f2f fix: github下载未使用环境变量代理 2024-11-09 18:09:39 +08:00
Junyan Qin
3cb3562477 doc(README): fix deadlinks 2024-11-08 23:14:52 +08:00
Junyan Qin
cee205994f doc: update logo 2024-11-05 15:42:48 +08:00
Junyan Qin
e44df0a3dd feat: dashboard 基本组件 2024-11-04 21:54:02 +08:00
Junyan Qin
84a51cb26d feat: 插件安装支持 2024-11-04 00:01:07 +08:00
Junyan Qin
db02d9c126 feat: 完成任务列表功能 2024-11-03 23:22:33 +08:00
Junyan Qin
709b86b724 refactor: 使插件更新过程全异步 2024-11-03 22:27:31 +08:00
Junyan Qin
68184b0e47 Merge pull request #911 from RockChinQ/feat/trackable-async-tasks
Feat: 用户级任务系统
2024-11-01 22:42:11 +08:00
Junyan Qin
6d2a4c038d feat: 完成异步任务跟踪架构基础 2024-11-01 22:41:26 +08:00
Junyan Qin
2f05f5b456 feat: 添加任务列表框架 2024-10-24 18:28:57 +08:00
Junyan Qin
d5e3120350 chore: 确保 pydantic<2.0 2024-10-24 14:26:18 +08:00
Junyan Qin
a4589327a6 feat: 添加 python 版本检查 2024-10-22 18:17:09 +08:00
Junyan Qin
c151665419 feat: 添加任务管理模块 2024-10-22 18:09:18 +08:00
Junyan Qin
947790e8d1 Update README.md 2024-10-22 13:51:05 +08:00
Junyan Qin
26770439bb fix: 关闭编排对话框时错误的插件顺序 2024-10-21 19:18:40 +08:00
Junyan Qin
7da9171dde feat: 插件优先级更改功能 2024-10-20 22:20:35 +08:00
Junyan Qin
16b386eaf7 feat: 插件页展示功能 2024-10-19 18:38:01 +08:00
Junyan Qin
c330aab48b Merge pull request #902 from RockChinQ/feat/settings-form-render
Feat: 设置项可视化编辑器
2024-10-16 22:32:40 +08:00
Junyan Qin
5f998a0852 perf: settings 页面的一些提示 2024-10-16 22:24:15 +08:00
Junyan Qin
c3dfbb64a6 feat: 异常处理 2024-10-16 21:55:55 +08:00
Junyan Qin
3db52282b8 fix: 修复子字段值为空时导致字段丢失的问题 2024-10-16 16:08:58 +08:00
Junyan Qin
a313ae5f97 feat: 添加多个可视化编辑schema 2024-10-16 15:34:30 +08:00
Junyan Qin
18cce189a4 feat: 完成 pipeline 的可视化配置 2024-10-16 13:57:41 +08:00
Junyan Qin
fb308d576b fix(settings): 切换tab时的异步问题 2024-10-16 12:58:52 +08:00
Junyan Qin
8c976303a4 feat: system.json 的可视化编辑 2024-10-15 21:42:05 +08:00
Junyan Qin
12f1f3609d feat: 引入 vjsf 渲染表单 2024-10-15 16:16:39 +08:00
Junyan Qin
661fdeb6a1 perf: 重新切换到 settings tab 时加载之前编辑的内容 2024-10-15 14:28:06 +08:00
Junyan Qin
d52f9b9543 feat(settings): json 编辑器 2024-10-15 14:23:56 +08:00
Junyan Qin
7174742886 feat: settings 基础组件 2024-10-15 00:07:40 +08:00
Junyan Qin
cd0a8fb24b perf: 使内容背景稍微灰一些 2024-10-14 21:30:10 +08:00
Junyan Qin
1fbc92bc6d perf: 首页展示版本信息 2024-10-14 21:18:36 +08:00
Junyan Qin
231dca956d feat: 日志页面 2024-10-14 18:52:28 +08:00
RockChinQ
0dd74c825b feat: 前端基础框架 2024-10-13 22:34:35 +08:00
RockChinQ
9703fc0366 perf: 优化日志增量获取逻辑 2024-10-13 22:33:51 +08:00
RockChinQ
7c3557e943 feat: 持久化和 web 接口基础架构 2024-10-11 22:27:53 +08:00
RockChinQ
21f153e5c3 chore: webui 前端模板 2024-10-11 22:23:08 +08:00
Junyan Qin
ea6a0af5a7 Merge pull request #890 from RockChinQ/feat/more-platforms
Refactor: 移除 YiriMirai 组件
2024-09-26 14:41:03 +08:00
RockChinQ
c53ffaca6c fix: 处理插件 import mirai 时的兼容性问题 2024-09-26 14:38:18 +08:00
RockChinQ
3469515e04 feat: 删除代码中对 mirai 的引用 2024-09-26 13:01:45 +08:00
RockChinQ
e8da26cb8a fix: missing break 2024-09-26 11:23:37 +08:00
RockChinQ
1235fc1339 chore: release v3.3.1.1 2024-09-26 10:39:35 +08:00
Junyan Qin
47e308b99d Merge pull request #889 from YunZLu/add-check-role
Fix: Add Role Check to Prevent Validation Error
2024-09-26 09:31:25 +08:00
RockChinQ
fdba470e9a perf: 将 platform 的 组件导入包 __init__ 中 2024-09-26 00:28:57 +08:00
Junyan Qin
a1ccceefd2 Merge branch 'master' into feat/more-platforms 2024-09-26 00:26:17 +08:00
RockChinQ
1c4a700d92 refactor: 将 yirimirai 的组件集成进 platform 包 2024-09-26 00:23:03 +08:00
YunZL
81c2c3c0e5 Add Role Check to Prevent Validation Error 2024-09-23 23:25:54 +08:00
Junyan Qin
3c2db5097a Merge pull request #888 from Tigrex-Dai/master
fix: 添加了针对报错内容对event.sender中'role'的存在性检查
2024-09-22 16:50:55 +08:00
Tigrex Dai
ce56f79687 Update aiocqhttp.py
针对报错对"role"做存在性检查
2024-09-22 15:39:48 +08:00
RockChinQ
ee0d6dcdae chore: release v3.3.1.0 2024-09-08 15:14:24 +08:00
Junyan Qin
bcf1d92f73 Merge pull request #881 from RockChinQ/version/3.3.1.0
Version/3.3.1.0
2024-09-08 15:13:39 +08:00
RockChinQ
ffdec16ce6 docs: wiki 所有页面加上已弃用说明 2024-09-08 14:52:35 +08:00
RockChinQ
b2f6e84adc typo: 优化插件执行日志信息 2024-09-08 14:51:39 +08:00
Junyan Qin
f76c457e1f Update README.md 2024-09-03 20:07:41 +08:00
RockChinQ
80bd0a20df doc: 修复 README 中的logo图片 2024-08-30 14:48:23 +08:00
RockChinQ
efeaf73339 doc: 修改README图片链接 2024-08-30 11:13:04 +08:00
Junyan Qin
91b5100a24 Merge pull request #872 from RockChinQ/feat/config-file-api
Feat: 添加yaml配置文件的支持
2024-08-24 20:55:19 +08:00
RockChinQ
d1a06f4730 feat: 添加yaml配置文件的支持 2024-08-24 20:54:36 +08:00
Junyan Qin
b0b186e951 Merge pull request #871 from RockChinQ/feat/qq-c2c
Feat: 添加对 QQ 官方 API 私聊场景的支持
2024-08-24 17:04:41 +08:00
RockChinQ
4c8fedef6e feat: QQ官方api群聊和私聊支持图片 2024-08-24 17:01:35 +08:00
RockChinQ
718c221d01 feat: 支持官方机器人私信接口 2024-08-24 16:26:47 +08:00
Junyan Qin
077e77eee5 Merge pull request #869 from ligen131/lg/fix_image_format
fix: 发送正确的图片格式而不是默认的 `image/jpeg`
2024-08-24 15:47:55 +08:00
ligen131
b51ca06c7c fix: 发送正确的图片格式而不是默认的 image/jpeg 2024-08-19 00:00:29 +08:00
RockChinQ
2f092f4a87 chore: release v3.3.0.2 2024-08-01 23:14:07 +08:00
Junyan Qin
f1ff9c05c4 Merge pull request #864 from RockChinQ/version/3.3.0.2
fix: 消息忽略规则失效 (#854)
2024-08-01 23:12:33 +08:00
RockChinQ
c9c8603ccc fix: 消息忽略规则失效 (#854) 2024-08-01 23:01:28 +08:00
RockChinQ
47e281fb61 chore: release v3.3.0.1 2024-07-28 22:47:49 +08:00
RockChinQ
dc625647eb fix: ollama 依赖检查 2024-07-28 22:47:19 +08:00
RockChinQ
66cf1b05be chore: 优化issue和pr模板 2024-07-28 21:32:22 +08:00
176 changed files with 13716 additions and 2203 deletions

View File

@@ -3,61 +3,37 @@ description: 报错或漏洞请使用这个模板创建,不使用此模板创
title: "[Bug]: " title: "[Bug]: "
labels: ["bug?"] labels: ["bug?"]
body: body:
- type: dropdown
attributes:
label: 部署方式
description: "主程序使用的部署方式"
options:
- 手动部署
- 安装器部署
- 一键安装包部署
- Docker部署
validations:
required: true
- type: dropdown - type: dropdown
attributes: attributes:
label: 消息平台适配器 label: 消息平台适配器
description: "连接QQ使用的框架" description: "连接QQ使用的框架"
options: options:
- yiri-miraiMirai
- Nakurugo-cqhttp - Nakurugo-cqhttp
- aiocqhttp使用 OneBot 协议接入的) - aiocqhttp使用 OneBot 协议接入的)
- qq-botpyQQ官方API - qq-botpyQQ官方API
- yiri-miraiMirai
validations: validations:
required: false required: false
- type: input - type: input
attributes: attributes:
label: 系统环境 label: 运行环境
description: 操作系统、系统架构、**主机地理位置**,地理位置最好写清楚,涉及网络问题排查。 description: 操作系统、系统架构、**Python版本**、**主机地理位置**
placeholder: 例如: CentOS x64 中国大陆、Windows11 美国 placeholder: 例如: CentOS x64 Python 3.10.3、Docker 的直接写 Docker 就行
validations: validations:
required: true required: true
- type: input - type: input
attributes: attributes:
label: Python环境 label: LangBot 版本
description: 运行程序的Python版本 description: LangBot (QChatGPT) 版本
placeholder: 例如: Python 3.10 placeholder: 例如:v3.3.0,可以使用`!version`命令查看,或者到 pkg/utils/constants.py 查看
validations:
required: true
- type: input
attributes:
label: QChatGPT版本
description: QChatGPT版本号
placeholder: 例如: v2.6.0,可以使用`!version`命令查看
validations: validations:
required: true required: true
- type: textarea - type: textarea
attributes: attributes:
label: 异常情况 label: 异常情况
description: 完整描述异常情况,什么时候发生的、发生了什么,尽可能详细 description: 完整描述异常情况,什么时候发生的、发生了什么。**请附带日志信息。**
validations: validations:
required: true required: true
- type: textarea
attributes:
label: 日志信息
description: 请提供完整的 **登录框架 和 QChatGPT控制台**的相关日志信息(若有),不提供日志信息**无法**为您排查问题,请尽可能详细
validations:
required: false
- type: textarea - type: textarea
attributes: attributes:
label: 启用的插件 label: 启用的插件

View File

@@ -10,5 +10,4 @@ updates:
schedule: schedule:
interval: "weekly" interval: "weekly"
allow: allow:
- dependency-name: "yiri-mirai-rc"
- dependency-name: "openai" - dependency-name: "openai"

View File

@@ -2,24 +2,19 @@
实现/解决/优化的内容: 实现/解决/优化的内容:
### 事务 ## 检查清单
- [ ] 已阅读仓库[贡献指引](https://github.com/RockChinQ/QChatGPT/blob/master/CONTRIBUTING.md) ### PR 作者完成
- [ ] 已与维护者在issues或其他平台沟通此PR大致内容
## 以下内容可在起草PR后、合并PR前逐步完成 *请在方括号间写`x`以打勾
### 功能 - [ ] 阅读仓库[贡献指引](https://github.com/RockChinQ/LangBot/blob/master/CONTRIBUTING.md)了吗?
- [ ] 与项目所有者沟通过了吗?
- [ ] 我确定已自行测试所作的更改,确保功能符合预期。
- [ ] 已编写完善的配置文件字段说明(若有新增) ### 项目所有者完成
- [ ] 已编写面向用户的新功能说明(若有必要)
- [ ] 已测试新功能或更改
### 兼容性 - [ ] 相关 issues 链接了吗?
- [ ] 配置项写好了吗?迁移写好了吗?生效了吗?
- [ ] 已处理版本兼容性 - [ ] 依赖写到 requirements.txt 和 core/bootutils/deps.py 了吗
- [ ] 已处理插件兼容问题 - [ ] 文档编写了吗?
### 风险
可能导致或已知的问题:

24
.github/workflows/build-dev-image.yaml vendored Normal file
View File

@@ -0,0 +1,24 @@
name: Build Dev Image
on:
push:
workflow_dispatch:
jobs:
build-dev-image:
runs-on: ubuntu-latest
steps:
- name: Checkout
uses: actions/checkout@v2
- name: Generate Tag
id: generate_tag
run: |
# 获取分支名称,把/替换为-
echo ${{ github.ref }} | sed 's/refs\/heads\///g' | sed 's/\//-/g'
echo ::set-output name=tag::$(echo ${{ github.ref }} | sed 's/refs\/heads\///g' | sed 's/\//-/g')
- name: Login to Registry
run: docker login --username=${{ secrets.DOCKER_USERNAME }} --password ${{ secrets.DOCKER_PASSWORD }}
- name: Build Docker Image
run: |
docker buildx create --name mybuilder --use
docker build -t rockchin/langbot:${{ steps.generate_tag.outputs.tag }} . --push

View File

@@ -19,12 +19,6 @@ jobs:
export GITHUB_REF=${{ github.ref }} export GITHUB_REF=${{ github.ref }}
echo $GITHUB_REF echo $GITHUB_REF
fi fi
# - name: Check GITHUB_REF env
# run: echo $GITHUB_REF
# - name: Get version # 在 GitHub Actions 运行环境
# id: get_version
# if: (startsWith(env.GITHUB_REF, 'refs/tags/')||startsWith(github.ref, 'refs/tags/')) && startsWith(github.repository, 'RockChinQ/QChatGPT')
# run: export GITHUB_REF=${GITHUB_REF/refs\/tags\//}
- name: Check version - name: Check version
id: check_version id: check_version
run: | run: |
@@ -44,5 +38,5 @@ jobs:
run: docker login --username=${{ secrets.DOCKER_USERNAME }} --password ${{ secrets.DOCKER_PASSWORD }} run: docker login --username=${{ secrets.DOCKER_USERNAME }} --password ${{ secrets.DOCKER_PASSWORD }}
- name: Create Buildx - name: Create Buildx
run: docker buildx create --name mybuilder --use run: docker buildx create --name mybuilder --use
- name: Build # image name: rockchin/qchatgpt:<VERSION> - name: Build # image name: rockchin/langbot:<VERSION>
run: docker buildx build --platform linux/arm64,linux/amd64 -t rockchin/qchatgpt:${{ steps.check_version.outputs.version }} -t rockchin/qchatgpt:latest . --push run: docker buildx build --platform linux/arm64,linux/amd64 -t rockchin/langbot:${{ steps.check_version.outputs.version }} -t rockchin/langbot:latest . --push

View File

@@ -0,0 +1,52 @@
name: Build Release Artifacts
on:
workflow_dispatch:
## 发布release的时候会自动构建
release:
types: [published]
jobs:
build-artifacts:
runs-on: ubuntu-latest
steps:
- name: Checkout
uses: actions/checkout@v2
- name: Check version
id: check_version
run: |
echo $GITHUB_REF
# 如果是tag则去掉refs/tags/前缀
if [[ $GITHUB_REF == refs/tags/* ]]; then
echo "It's a tag"
echo $GITHUB_REF
echo $GITHUB_REF | awk -F '/' '{print $3}'
echo ::set-output name=version::$(echo $GITHUB_REF | awk -F '/' '{print $3}')
else
echo "It's not a tag"
echo $GITHUB_REF
echo ::set-output name=version::${GITHUB_REF}
fi
- name: Make Temp Directory
run: |
mkdir -p /tmp/langbot_build_web
cp -r . /tmp/langbot_build_web
- name: Setup Node
uses: actions/setup-node@v2
with:
node-version: '22'
- name: Build Web
run: |
cd /tmp/langbot_build_web/web
npm install
npm run build
- name: Package Output
run: |
cp -r /tmp/langbot_build_web/web/dist ./web
- name: Upload Artifact
uses: actions/upload-artifact@v4
with:
name: langbot-${{ steps.check_version.outputs.version }}-all
path: .

View File

@@ -1,43 +0,0 @@
name: Update Wiki
on:
push:
branches:
- master
paths:
- 'res/wiki/**'
jobs:
update-wiki:
runs-on: ubuntu-latest
steps:
- name: Checkout
uses: actions/checkout@v2
- name: Setup Git
run: |
git config --global user.name "GitHub Actions"
git config --global user.email "github-actions[bot]@users.noreply.github.com"
- name: Clone Wiki Repository
uses: actions/checkout@v2
with:
repository: RockChinQ/QChatGPT.wiki
path: wiki
- name: Delete old wiki content
run: |
rm -rf wiki/*
- name: Copy res/wiki content to wiki
run: |
cp -r res/wiki/* wiki/
- name: Check for changes
run: |
cd wiki
if git diff --quiet; then
echo "No changes to commit."
exit 0
fi
- name: Commit and Push Changes
run: |
cd wiki
git add .
git commit -m "Update wiki"
git push

8
.gitignore vendored
View File

@@ -3,9 +3,10 @@
__pycache__/ __pycache__/
database.db database.db
qchatgpt.log qchatgpt.log
langbot.log
/banlist.py /banlist.py
plugins/ /plugins/
!plugins/__init__.py !/plugins/__init__.py
/revcfg.py /revcfg.py
prompts/ prompts/
logs/ logs/
@@ -34,4 +35,5 @@ bard.json
res/instance_id.json res/instance_id.json
.DS_Store .DS_Store
/data /data
botpy.log* botpy.log*
/poc

View File

@@ -1,8 +1,19 @@
FROM node:22-alpine AS node
WORKDIR /app
COPY web ./web
RUN cd web && npm install && npm run build
FROM python:3.10.13-slim FROM python:3.10.13-slim
WORKDIR /app WORKDIR /app
COPY . . COPY . .
COPY --from=node /app/web/dist ./web/dist
RUN apt update \ RUN apt update \
&& apt install gcc -y \ && apt install gcc -y \
&& python -m pip install -r requirements.txt \ && python -m pip install -r requirements.txt \

View File

@@ -1,17 +1,14 @@
<p align="center"> <p align="center">
<img src="https://qchatgpt.rockchin.top/logo.png" alt="QChatGPT" width="180" /> <img src="https://docs.langbot.app/langbot-logo-0.5x.png" alt="QChatGPT" width="180" />
</p> </p>
<div align="center"> <div align="center">
# QChatGPT # LangBot
<a href="https://trendshift.io/repositories/6187" target="_blank"><img src="https://trendshift.io/api/badge/repositories/6187" alt="RockChinQ%2FQChatGPT | Trendshift" style="width: 250px; height: 55px;" width="250" height="55"/></a> <a href="https://trendshift.io/repositories/6187" target="_blank"><img src="https://trendshift.io/api/badge/repositories/6187" alt="RockChinQ%2FQChatGPT | Trendshift" style="width: 250px; height: 55px;" width="250" height="55"/></a>
[![GitHub release (latest by date)](https://img.shields.io/github/v/release/RockChinQ/QChatGPT)](https://github.com/RockChinQ/QChatGPT/releases/latest) [![GitHub release (latest by date)](https://img.shields.io/github/v/release/RockChinQ/LangBot)](https://github.com/RockChinQ/LangBot/releases/latest)
<a href="https://hub.docker.com/repository/docker/rockchin/qchatgpt">
<img src="https://img.shields.io/docker/pulls/rockchin/qchatgpt?color=blue" alt="docker pull">
</a>
![Dynamic JSON Badge](https://img.shields.io/badge/dynamic/json?url=https%3A%2F%2Fapi.qchatgpt.rockchin.top%2Fapi%2Fv2%2Fview%2Frealtime%2Fcount_query%3Fminute%3D10080&query=%24.data.count&label=%E4%BD%BF%E7%94%A8%E9%87%8F%EF%BC%887%E6%97%A5%EF%BC%89) ![Dynamic JSON Badge](https://img.shields.io/badge/dynamic/json?url=https%3A%2F%2Fapi.qchatgpt.rockchin.top%2Fapi%2Fv2%2Fview%2Frealtime%2Fcount_query%3Fminute%3D10080&query=%24.data.count&label=%E4%BD%BF%E7%94%A8%E9%87%8F%EF%BC%887%E6%97%A5%EF%BC%89)
![Wakapi Count](https://wakapi.rockchin.top/api/badge/RockChinQ/interval:any/project:QChatGPT) ![Wakapi Count](https://wakapi.rockchin.top/api/badge/RockChinQ/interval:any/project:QChatGPT)
<br/> <br/>
@@ -22,18 +19,15 @@
<a href="https://qm.qq.com/q/PClALFK242"> <a href="https://qm.qq.com/q/PClALFK242">
<img alt="Static Badge" src="https://img.shields.io/badge/%E7%A4%BE%E5%8C%BA%E7%BE%A4-619154800-purple"> <img alt="Static Badge" src="https://img.shields.io/badge/%E7%A4%BE%E5%8C%BA%E7%BE%A4-619154800-purple">
</a> </a>
<a href="https://codecov.io/gh/RockChinQ/QChatGPT" >
<img src="https://codecov.io/gh/RockChinQ/QChatGPT/graph/badge.svg?token=pjxYIL2kbC"/>
</a>
## 使用文档 ## 使用文档
<a href="https://qchatgpt.rockchin.top">项目主页</a> <a href="https://docs.langbot.app">项目主页</a>
<a href="https://qchatgpt.rockchin.top/posts/feature.html">功能介绍</a> <a href="https://docs.langbot.app/insight/intro.htmll">功能介绍</a>
<a href="https://qchatgpt.rockchin.top/posts/deploy/">部署文档</a> <a href="https://docs.langbot.app/insight/guide.html">部署文档</a>
<a href="https://qchatgpt.rockchin.top/posts/error/">常见问题</a> <a href="https://docs.langbot.app/usage/faq.html">常见问题</a>
<a href="https://qchatgpt.rockchin.top/posts/plugin/intro.html">插件介绍</a> <a href="https://docs.langbot.app/plugin/plugin-intro.html">插件介绍</a>
<a href="https://github.com/RockChinQ/QChatGPT/issues/new?assignees=&labels=%E7%8B%AC%E7%AB%8B%E6%8F%92%E4%BB%B6&projects=&template=submit-plugin.yml&title=%5BPlugin%5D%3A+%E8%AF%B7%E6%B1%82%E7%99%BB%E8%AE%B0%E6%96%B0%E6%8F%92%E4%BB%B6">提交插件</a> <a href="https://github.com/RockChinQ/LangBot/issues/new?assignees=&labels=%E7%8B%AC%E7%AB%8B%E6%8F%92%E4%BB%B6&projects=&template=submit-plugin.yml&title=%5BPlugin%5D%3A+%E8%AF%B7%E6%B1%82%E7%99%BB%E8%AE%B0%E6%96%B0%E6%8F%92%E4%BB%B6">提交插件</a>
## 相关链接 ## 相关链接
@@ -42,5 +36,5 @@
<a href="https://github.com/RockChinQ/qcg-center">遥测服务端源码</a> <a href="https://github.com/RockChinQ/qcg-center">遥测服务端源码</a>
<a href="https://github.com/the-lazy-me/QChatGPT-Wiki">官方文档储存库</a> <a href="https://github.com/the-lazy-me/QChatGPT-Wiki">官方文档储存库</a>
<img alt="回复效果(带有联网插件)" src="https://qchatgpt.rockchin.top/assets/image/QChatGPT-0516.png" width="500px"/> <img alt="回复效果(带有联网插件)" src="https://docs.langbot.app/QChatGPT-0516.png" width="500px"/>
</div> </div>

View File

@@ -1,10 +1,14 @@
version: "3" version: "3"
services: services:
qchatgpt: langbot:
image: rockchin/qchatgpt:latest image: rockchin/langbot:latest
container_name: langbot
volumes: volumes:
- ./data:/app/data - ./data:/app/data
- ./plugins:/app/plugins - ./plugins:/app/plugins
restart: on-failure restart: on-failure
# 根据具体环境配置网络 ports:
- 5300:5300 # 供 WebUI 使用
- 2280-2290:2280-2290 # 供消息平台适配器方向连接
# 根据具体环境配置网络

49
main.py
View File

@@ -1,19 +1,23 @@
# QChatGPT 终端启动入口 # LangBot 终端启动入口
# 在此层级解决依赖项检查。 # 在此层级解决依赖项检查。
# QChatGPT/main.py # LangBot/main.py
asciiart = r""" asciiart = r"""
___ ___ _ _ ___ ___ _____ _ ___ _
/ _ \ / __| |_ __ _| |_ / __| _ \_ _| | | __ _ _ _ __ _| _ ) ___| |_
| (_) | (__| ' \/ _` | _| (_ | _/ | | | |__/ _` | ' \/ _` | _ \/ _ \ _|
\__\_\\___|_||_\__,_|\__|\___|_| |_| |____\__,_|_||_\__, |___/\___/\__|
|___/
⭐️开源地址: https://github.com/RockChinQ/QChatGPT ⭐️开源地址: https://github.com/RockChinQ/LangBot
📖文档地址: https://q.rkcn.top 📖文档地址: https://docs.langbot.app
""" """
async def main_entry(): import asyncio
async def main_entry(loop: asyncio.AbstractEventLoop):
print(asciiart) print(asciiart)
import sys import sys
@@ -32,6 +36,12 @@ async def main_entry():
print("已自动安装缺失的依赖包,请重启程序。") print("已自动安装缺失的依赖包,请重启程序。")
sys.exit(0) sys.exit(0)
# 检查pydantic版本如果没有 pydantic.v1则把 pydantic 映射为 v1
import pydantic.version
if pydantic.version.VERSION < '2.0':
import pydantic
sys.modules['pydantic.v1'] = pydantic
# 检查配置文件 # 检查配置文件
from pkg.core.bootutils import files from pkg.core.bootutils import files
@@ -46,13 +56,20 @@ async def main_entry():
sys.exit(0) sys.exit(0)
from pkg.core import boot from pkg.core import boot
await boot.main() await boot.main(loop)
if __name__ == '__main__': if __name__ == '__main__':
import os import os
import sys
# 检查本目录是否有main.py且包含QChatGPT字符串 # 必须大于 3.10.1
if sys.version_info < (3, 10, 1):
print("需要 Python 3.10.1 及以上版本,当前 Python 版本为:", sys.version)
input("按任意键退出...")
exit(1)
# 检查本目录是否有main.py且包含LangBot字符串
invalid_pwd = False invalid_pwd = False
if not os.path.exists('main.py'): if not os.path.exists('main.py'):
@@ -60,13 +77,13 @@ if __name__ == '__main__':
else: else:
with open('main.py', 'r', encoding='utf-8') as f: with open('main.py', 'r', encoding='utf-8') as f:
content = f.read() content = f.read()
if "QChatGPT/main.py" not in content: if "LangBot/main.py" not in content:
invalid_pwd = True invalid_pwd = True
if invalid_pwd: if invalid_pwd:
print("请在QChatGPT项目根目录下以命令形式运行此程序。") print("请在 LangBot 项目根目录下以命令形式运行此程序。")
input("按任意键退出...") input("按任意键退出...")
exit(0) exit(1)
import asyncio loop = asyncio.new_event_loop()
asyncio.run(main_entry()) loop.run_until_complete(main_entry(loop))

View File

View File

@@ -0,0 +1,107 @@
from __future__ import annotations
import abc
import typing
import enum
import quart
from quart.typing import RouteCallable
from ....core import app
preregistered_groups: list[type[RouterGroup]] = []
"""RouterGroup 的预注册列表"""
def group_class(name: str, path: str) -> None:
"""注册一个 RouterGroup"""
def decorator(cls: typing.Type[RouterGroup]) -> typing.Type[RouterGroup]:
cls.name = name
cls.path = path
preregistered_groups.append(cls)
return cls
return decorator
class AuthType(enum.Enum):
"""认证类型"""
NONE = 'none'
USER_TOKEN = 'user-token'
class RouterGroup(abc.ABC):
name: str
path: str
ap: app.Application
quart_app: quart.Quart
def __init__(self, ap: app.Application, quart_app: quart.Quart) -> None:
self.ap = ap
self.quart_app = quart_app
@abc.abstractmethod
async def initialize(self) -> None:
pass
def route(self, rule: str, auth_type: AuthType = AuthType.USER_TOKEN, **options: typing.Any) -> typing.Callable[[RouteCallable], RouteCallable]: # decorator
"""注册一个路由"""
def decorator(f: RouteCallable) -> RouteCallable:
nonlocal rule
rule = self.path + rule
async def handler_error(*args, **kwargs):
if auth_type == AuthType.USER_TOKEN:
# 从Authorization头中获取token
token = quart.request.headers.get('Authorization', '').replace('Bearer ', '')
if not token:
return self.http_status(401, -1, '未提供有效的用户令牌')
try:
user_email = await self.ap.user_service.verify_jwt_token(token)
# 检查f是否接受user_email参数
if 'user_email' in f.__code__.co_varnames:
kwargs['user_email'] = user_email
except Exception as e:
return self.http_status(401, -1, str(e))
try:
return await f(*args, **kwargs)
except Exception as e: # 自动 500
return self.http_status(500, -2, str(e))
new_f = handler_error
new_f.__name__ = (self.name + rule).replace('/', '__')
new_f.__doc__ = f.__doc__
self.quart_app.route(rule, **options)(new_f)
return f
return decorator
def success(self, data: typing.Any = None) -> quart.Response:
"""返回一个 200 响应"""
return quart.jsonify({
'code': 0,
'msg': 'ok',
'data': data,
})
def fail(self, code: int, msg: str) -> quart.Response:
"""返回一个异常响应"""
return quart.jsonify({
'code': code,
'msg': msg,
})
def http_status(self, status: int, code: int, msg: str) -> quart.Response:
"""返回一个指定状态码的响应"""
return self.fail(code, msg), status

View File

@@ -0,0 +1,32 @@
from __future__ import annotations
import traceback
import quart
from .....core import app
from .. import group
@group.group_class('logs', '/api/v1/logs')
class LogsRouterGroup(group.RouterGroup):
async def initialize(self) -> None:
@self.route('', methods=['GET'])
async def _() -> str:
start_page_number = int(quart.request.args.get('start_page_number', 0))
start_offset = int(quart.request.args.get('start_offset', 0))
logs_str, end_page_number, end_offset = self.ap.log_cache.get_log_by_pointer(
start_page_number=start_page_number,
start_offset=start_offset
)
return self.success(
data={
"logs": logs_str,
"end_page_number": end_page_number,
"end_offset": end_offset
}
)

View File

@@ -0,0 +1,84 @@
from __future__ import annotations
import traceback
import quart
from .....core import app, taskmgr
from .. import group
@group.group_class('plugins', '/api/v1/plugins')
class PluginsRouterGroup(group.RouterGroup):
async def initialize(self) -> None:
@self.route('', methods=['GET'])
async def _() -> str:
plugins = self.ap.plugin_mgr.plugins()
plugins_data = [plugin.model_dump() for plugin in plugins]
return self.success(data={
'plugins': plugins_data
})
@self.route('/<author>/<plugin_name>/toggle', methods=['PUT'])
async def _(author: str, plugin_name: str) -> str:
data = await quart.request.json
target_enabled = data.get('target_enabled')
await self.ap.plugin_mgr.update_plugin_switch(plugin_name, target_enabled)
return self.success()
@self.route('/<author>/<plugin_name>/update', methods=['POST'])
async def _(author: str, plugin_name: str) -> str:
ctx = taskmgr.TaskContext.new()
wrapper = self.ap.task_mgr.create_user_task(
self.ap.plugin_mgr.update_plugin(plugin_name, task_context=ctx),
kind="plugin-operation",
name=f"plugin-update-{plugin_name}",
label=f"更新插件 {plugin_name}",
context=ctx
)
return self.success(data={
'task_id': wrapper.id
})
@self.route('/<author>/<plugin_name>', methods=['DELETE'])
async def _(author: str, plugin_name: str) -> str:
ctx = taskmgr.TaskContext.new()
wrapper = self.ap.task_mgr.create_user_task(
self.ap.plugin_mgr.uninstall_plugin(plugin_name, task_context=ctx),
kind="plugin-operation",
name=f'plugin-remove-{plugin_name}',
label=f'删除插件 {plugin_name}',
context=ctx
)
return self.success(data={
'task_id': wrapper.id
})
@self.route('/reorder', methods=['PUT'])
async def _() -> str:
data = await quart.request.json
await self.ap.plugin_mgr.reorder_plugins(data.get('plugins'))
return self.success()
@self.route('/install/github', methods=['POST'])
async def _() -> str:
data = await quart.request.json
ctx = taskmgr.TaskContext.new()
short_source_str = data['source'][-8:]
wrapper = self.ap.task_mgr.create_user_task(
self.ap.plugin_mgr.install_plugin(data['source'], task_context=ctx),
kind="plugin-operation",
name=f'plugin-install-github',
label=f'安装插件 ...{short_source_str}',
context=ctx
)
return self.success(data={
'task_id': wrapper.id
})

View File

@@ -0,0 +1,62 @@
import quart
from .....core import app
from .. import group
@group.group_class('settings', '/api/v1/settings')
class SettingsRouterGroup(group.RouterGroup):
async def initialize(self) -> None:
@self.route('', methods=['GET'])
async def _() -> str:
return self.success(
data={
"managers": [
{
"name": m.name,
"description": m.description,
}
for m in self.ap.settings_mgr.get_manager_list()
]
}
)
@self.route('/<manager_name>', methods=['GET'])
async def _(manager_name: str) -> str:
manager = self.ap.settings_mgr.get_manager(manager_name)
if manager is None:
return self.fail(1, '配置管理器不存在')
return self.success(
data={
"manager": {
"name": manager.name,
"description": manager.description,
"schema": manager.schema,
"file": manager.file.config_file_name,
"data": manager.data,
"doc_link": manager.doc_link
}
}
)
@self.route('/<manager_name>/data', methods=['PUT'])
async def _(manager_name: str) -> str:
data = await quart.request.json
manager = self.ap.settings_mgr.get_manager(manager_name)
if manager is None:
return self.fail(code=1, msg='配置管理器不存在')
# manager.data = data['data']
for k, v in data['data'].items():
manager.data[k] = v
await manager.dump_config()
return self.success(data={
"data": manager.data
})

View File

@@ -0,0 +1,23 @@
import quart
import asyncio
from .....core import app, taskmgr
from .. import group
@group.group_class('stats', '/api/v1/stats')
class StatsRouterGroup(group.RouterGroup):
async def initialize(self) -> None:
@self.route('/basic', methods=['GET'])
async def _() -> str:
conv_count = 0
for session in self.ap.sess_mgr.session_list:
conv_count += len(session.conversations if session.conversations is not None else [])
return self.success(data={
'active_session_count': len(self.ap.sess_mgr.session_list),
'conversation_count': conv_count,
'query_count': self.ap.query_pool.query_id_counter,
})

View File

@@ -0,0 +1,63 @@
import quart
import asyncio
from .....core import app, taskmgr
from .. import group
from .....utils import constants
@group.group_class('system', '/api/v1/system')
class SystemRouterGroup(group.RouterGroup):
async def initialize(self) -> None:
@self.route('/info', methods=['GET'], auth_type=group.AuthType.NONE)
async def _() -> str:
return self.success(
data={
"version": constants.semantic_version,
"debug": constants.debug_mode,
"enabled_platform_count": len(self.ap.platform_mgr.adapters)
}
)
@self.route('/tasks', methods=['GET'])
async def _() -> str:
task_type = quart.request.args.get("type")
if task_type == '':
task_type = None
return self.success(
data=self.ap.task_mgr.get_tasks_dict(task_type)
)
@self.route('/tasks/<task_id>', methods=['GET'])
async def _(task_id: str) -> str:
task = self.ap.task_mgr.get_task_by_id(int(task_id))
if task is None:
return self.http_status(404, 404, "Task not found")
return self.success(data=task.to_dict())
@self.route('/reload', methods=['POST'])
async def _() -> str:
json_data = await quart.request.json
scope = json_data.get("scope")
await self.ap.reload(
scope=scope
)
return self.success()
@self.route('/_debug/exec', methods=['POST'])
async def _() -> str:
if not constants.debug_mode:
return self.http_status(403, 403, "Forbidden")
py_code = await quart.request.data
ap = self.ap
return self.success(data=exec(py_code, {"ap": ap}))

View File

@@ -0,0 +1,47 @@
import quart
import sqlalchemy
import argon2
from .. import group
from .....persistence.entities import user
@group.group_class('user', '/api/v1/user')
class UserRouterGroup(group.RouterGroup):
async def initialize(self) -> None:
@self.route('/init', methods=['GET', 'POST'], auth_type=group.AuthType.NONE)
async def _() -> str:
if quart.request.method == 'GET':
return self.success(data={
'initialized': await self.ap.user_service.is_initialized()
})
if await self.ap.user_service.is_initialized():
return self.fail(1, '系统已初始化')
json_data = await quart.request.json
user_email = json_data['user']
password = json_data['password']
await self.ap.user_service.create_user(user_email, password)
return self.success()
@self.route('/auth', methods=['POST'], auth_type=group.AuthType.NONE)
async def _() -> str:
json_data = await quart.request.json
try:
token = await self.ap.user_service.authenticate(json_data['user'], json_data['password'])
except argon2.exceptions.VerifyMismatchError:
return self.fail(1, '用户名或密码错误')
return self.success(data={
'token': token
})
@self.route('/check-token', methods=['GET'])
async def _() -> str:
return self.success()

View File

@@ -0,0 +1,73 @@
from __future__ import annotations
import asyncio
import os
import quart
import quart_cors
from ....core import app, entities as core_entities
from .groups import logs, system, settings, plugins, stats, user
from . import group
class HTTPController:
ap: app.Application
quart_app: quart.Quart
def __init__(self, ap: app.Application) -> None:
self.ap = ap
self.quart_app = quart.Quart(__name__)
quart_cors.cors(self.quart_app, allow_origin="*")
async def initialize(self) -> None:
await self.register_routes()
async def run(self) -> None:
if self.ap.system_cfg.data["http-api"]["enable"]:
async def shutdown_trigger_placeholder():
while True:
await asyncio.sleep(1)
async def exception_handler(*args, **kwargs):
try:
await self.quart_app.run_task(
*args, **kwargs
)
except Exception as e:
self.ap.logger.error(f"启动 HTTP 服务失败: {e}")
self.ap.task_mgr.create_task(
exception_handler(
host=self.ap.system_cfg.data["http-api"]["host"],
port=self.ap.system_cfg.data["http-api"]["port"],
shutdown_trigger=shutdown_trigger_placeholder,
),
name="http-api-quart",
scopes=[core_entities.LifecycleControlScope.APPLICATION],
)
# await asyncio.sleep(5)
async def register_routes(self) -> None:
@self.quart_app.route("/healthz")
async def healthz():
return {"code": 0, "msg": "ok"}
for g in group.preregistered_groups:
ginst = g(self.ap, self.quart_app)
await ginst.initialize()
frontend_path = "web/dist"
@self.quart_app.route("/")
async def index():
return await quart.send_from_directory(frontend_path, "index.html")
@self.quart_app.route("/<path:path>")
async def static_file(path: str):
return await quart.send_from_directory(frontend_path, path)

View File

View File

@@ -0,0 +1,73 @@
from __future__ import annotations
import sqlalchemy
import argon2
import jwt
import datetime
from ....core import app
from ....persistence.entities import user
from ....utils import constants
class UserService:
ap: app.Application
def __init__(self, ap: app.Application) -> None:
self.ap = ap
async def is_initialized(self) -> bool:
result = await self.ap.persistence_mgr.execute_async(
sqlalchemy.select(user.User).limit(1)
)
result_list = result.all()
return result_list is not None and len(result_list) > 0
async def create_user(self, user_email: str, password: str) -> None:
ph = argon2.PasswordHasher()
hashed_password = ph.hash(password)
await self.ap.persistence_mgr.execute_async(
sqlalchemy.insert(user.User).values(
user=user_email,
password=hashed_password
)
)
async def authenticate(self, user_email: str, password: str) -> str | None:
result = await self.ap.persistence_mgr.execute_async(
sqlalchemy.select(user.User).where(user.User.user == user_email)
)
result_list = result.all()
if result_list is None or len(result_list) == 0:
raise ValueError('用户不存在')
user_obj = result_list[0]
ph = argon2.PasswordHasher()
ph.verify(user_obj.password, password)
return await self.generate_jwt_token(user_email)
async def generate_jwt_token(self, user_email: str) -> str:
jwt_secret = self.ap.instance_secret_meta.data['jwt_secret']
jwt_expire = self.ap.system_cfg.data['http-api']['jwt-expire']
payload = {
'user': user_email,
'iss': 'LangBot-'+constants.edition,
'exp': datetime.datetime.now() + datetime.timedelta(seconds=jwt_expire)
}
return jwt.encode(payload, jwt_secret, algorithm='HS256')
async def verify_jwt_token(self, token: str) -> str:
jwt_secret = self.ap.instance_secret_meta.data['jwt_secret']
return jwt.decode(token, jwt_secret, algorithms=['HS256'])['user']

View File

@@ -9,11 +9,12 @@ import asyncio
import aiohttp import aiohttp
import requests import requests
from ...core import app from ...core import app, entities as core_entities
class APIGroup(metaclass=abc.ABCMeta): class APIGroup(metaclass=abc.ABCMeta):
"""API 组抽象类""" """API 组抽象类"""
_basic_info: dict = None _basic_info: dict = None
_runtime_info: dict = None _runtime_info: dict = None
@@ -32,33 +33,28 @@ class APIGroup(metaclass=abc.ABCMeta):
data: dict = None, data: dict = None,
params: dict = None, params: dict = None,
headers: dict = {}, headers: dict = {},
**kwargs **kwargs,
): ):
""" """
执行请求 执行请求
""" """
self._runtime_info['account_id'] = "-1" self._runtime_info["account_id"] = "-1"
url = self.prefix + path url = self.prefix + path
data = json.dumps(data) data = json.dumps(data)
headers['Content-Type'] = 'application/json' headers["Content-Type"] = "application/json"
try: try:
async with aiohttp.ClientSession() as session: async with aiohttp.ClientSession() as session:
async with session.request( async with session.request(
method, method, url, data=data, params=params, headers=headers, **kwargs
url,
data=data,
params=params,
headers=headers,
**kwargs
) as resp: ) as resp:
self.ap.logger.debug("data: %s", data) self.ap.logger.debug("data: %s", data)
self.ap.logger.debug("ret: %s", await resp.text()) self.ap.logger.debug("ret: %s", await resp.text())
except Exception as e: except Exception as e:
self.ap.logger.debug(f'上报失败: {e}') self.ap.logger.debug(f"上报失败: {e}")
async def do( async def do(
self, self,
method: str, method: str,
@@ -66,27 +62,27 @@ class APIGroup(metaclass=abc.ABCMeta):
data: dict = None, data: dict = None,
params: dict = None, params: dict = None,
headers: dict = {}, headers: dict = {},
**kwargs **kwargs,
) -> asyncio.Task: ) -> asyncio.Task:
"""执行请求""" """执行请求"""
asyncio.create_task(self._do(method, path, data, params, headers, **kwargs))
def gen_rid( return self.ap.task_mgr.create_task(
self self._do(method, path, data, params, headers, **kwargs),
): kind="telemetry-operation",
name=f"{method} {path}",
scopes=[core_entities.LifecycleControlScope.APPLICATION],
).task
def gen_rid(self):
"""生成一个请求 ID""" """生成一个请求 ID"""
return str(uuid.uuid4()) return str(uuid.uuid4())
def basic_info( def basic_info(self):
self
):
"""获取基本信息""" """获取基本信息"""
basic_info = APIGroup._basic_info.copy() basic_info = APIGroup._basic_info.copy()
basic_info['rid'] = self.gen_rid() basic_info["rid"] = self.gen_rid()
return basic_info return basic_info
def runtime_info( def runtime_info(self):
self
):
"""获取运行时信息""" """获取运行时信息"""
return APIGroup._runtime_info return APIGroup._runtime_info

View File

@@ -2,11 +2,11 @@ from __future__ import annotations
import typing import typing
import pydantic import pydantic.v1 as pydantic
import mirai
from ..core import app, entities as core_entities from ..core import app, entities as core_entities
from . import errors, operator from . import errors, operator
from ..platform.types import message as platform_message
class CommandReturn(pydantic.BaseModel): class CommandReturn(pydantic.BaseModel):
@@ -17,7 +17,7 @@ class CommandReturn(pydantic.BaseModel):
"""文本 """文本
""" """
image: typing.Optional[mirai.Image] = None image: typing.Optional[platform_message.Image] = None
"""弃用""" """弃用"""
image_url: typing.Optional[str] = None image_url: typing.Optional[str] = None

View File

@@ -2,6 +2,7 @@ from __future__ import annotations
from typing import AsyncGenerator from typing import AsyncGenerator
from .. import operator, entities, cmdmgr from .. import operator, entities, cmdmgr
from ...plugin import context as plugin_context
@operator.operator_class(name="func", help="查看所有已注册的内容函数", usage='!func') @operator.operator_class(name="func", help="查看所有已注册的内容函数", usage='!func')
@@ -9,16 +10,18 @@ class FuncOperator(operator.CommandOperator):
async def execute( async def execute(
self, context: entities.ExecuteContext self, context: entities.ExecuteContext
) -> AsyncGenerator[entities.CommandReturn, None]: ) -> AsyncGenerator[entities.CommandReturn, None]:
reply_str = "当前已加载的内容函数: \n\n" reply_str = "当前已启用的内容函数: \n\n"
index = 1 index = 1
all_functions = await self.ap.tool_mgr.get_all_functions() all_functions = await self.ap.tool_mgr.get_all_functions(
plugin_enabled=True,
plugin_status=plugin_context.RuntimeContainerStatus.INITIALIZED,
)
for func in all_functions: for func in all_functions:
reply_str += "{}. {}{}:\n{}\n\n".format( reply_str += "{}. {}:\n{}\n\n".format(
index, index,
("(已禁用) " if not func.enable else ""),
func.name, func.name,
func.description, func.description,
) )

View File

@@ -18,7 +18,7 @@ class PluginOperator(operator.CommandOperator):
context: entities.ExecuteContext context: entities.ExecuteContext
) -> typing.AsyncGenerator[entities.CommandReturn, None]: ) -> typing.AsyncGenerator[entities.CommandReturn, None]:
plugin_list = self.ap.plugin_mgr.plugins plugin_list = self.ap.plugin_mgr.plugins()
reply_str = "所有插件({}):\n".format(len(plugin_list)) reply_str = "所有插件({}):\n".format(len(plugin_list))
idx = 0 idx = 0
for plugin in plugin_list: for plugin in plugin_list:
@@ -110,7 +110,7 @@ class PluginUpdateAllOperator(operator.CommandOperator):
try: try:
plugins = [ plugins = [
p.plugin_name p.plugin_name
for p in self.ap.plugin_mgr.plugins for p in self.ap.plugin_mgr.plugins()
] ]
if plugins: if plugins:
@@ -163,24 +163,6 @@ class PluginDelOperator(operator.CommandOperator):
yield entities.CommandReturn(error=errors.CommandError("插件删除失败: "+str(e))) yield entities.CommandReturn(error=errors.CommandError("插件删除失败: "+str(e)))
async def update_plugin_status(plugin_name: str, new_status: bool, ap: app.Application):
if ap.plugin_mgr.get_plugin_by_name(plugin_name) is not None:
for plugin in ap.plugin_mgr.plugins:
if plugin.plugin_name == plugin_name:
plugin.enabled = new_status
for func in plugin.content_functions:
func.enable = new_status
await ap.plugin_mgr.setting.dump_container_setting(ap.plugin_mgr.plugins)
break
return True
else:
return False
@operator.operator_class( @operator.operator_class(
name="on", name="on",
help="启用插件", help="启用插件",
@@ -200,7 +182,7 @@ class PluginEnableOperator(operator.CommandOperator):
plugin_name = context.crt_params[0] plugin_name = context.crt_params[0]
try: try:
if await update_plugin_status(plugin_name, True, self.ap): if await self.ap.plugin_mgr.update_plugin_switch(plugin_name, True):
yield entities.CommandReturn(text="已启用插件: {}".format(plugin_name)) yield entities.CommandReturn(text="已启用插件: {}".format(plugin_name))
else: else:
yield entities.CommandReturn(error=errors.CommandError("插件状态修改失败: 未找到插件 {}".format(plugin_name))) yield entities.CommandReturn(error=errors.CommandError("插件状态修改失败: 未找到插件 {}".format(plugin_name)))
@@ -228,7 +210,7 @@ class PluginDisableOperator(operator.CommandOperator):
plugin_name = context.crt_params[0] plugin_name = context.crt_params[0]
try: try:
if await update_plugin_status(plugin_name, False, self.ap): if await self.ap.plugin_mgr.update_plugin_switch(plugin_name, False):
yield entities.CommandReturn(text="已禁用插件: {}".format(plugin_name)) yield entities.CommandReturn(text="已禁用插件: {}".format(plugin_name))
else: else:
yield entities.CommandReturn(error=errors.CommandError("插件状态修改失败: 未找到插件 {}".format(plugin_name))) yield entities.CommandReturn(error=errors.CommandError("插件状态修改失败: 未找到插件 {}".format(plugin_name)))

59
pkg/config/impls/yaml.py Normal file
View File

@@ -0,0 +1,59 @@
import os
import shutil
import yaml
from .. import model as file_model
class YAMLConfigFile(file_model.ConfigFile):
"""YAML配置文件"""
def __init__(
self, config_file_name: str, template_file_name: str = None, template_data: dict = None
) -> None:
self.config_file_name = config_file_name
self.template_file_name = template_file_name
self.template_data = template_data
def exists(self) -> bool:
return os.path.exists(self.config_file_name)
async def create(self):
if self.template_file_name is not None:
shutil.copyfile(self.template_file_name, self.config_file_name)
elif self.template_data is not None:
with open(self.config_file_name, "w", encoding="utf-8") as f:
yaml.dump(self.template_data, f, indent=4, allow_unicode=True)
else:
raise ValueError("template_file_name or template_data must be provided")
async def load(self, completion: bool=True) -> dict:
if not self.exists():
await self.create()
if self.template_file_name is not None:
with open(self.template_file_name, "r", encoding="utf-8") as f:
self.template_data = yaml.load(f, Loader=yaml.FullLoader)
with open(self.config_file_name, "r", encoding="utf-8") as f:
try:
cfg = yaml.load(f, Loader=yaml.FullLoader)
except yaml.YAMLError as e:
raise Exception(f"配置文件 {self.config_file_name} 语法错误: {e}")
if completion:
for key in self.template_data:
if key not in cfg:
cfg[key] = self.template_data[key]
return cfg
async def save(self, cfg: dict):
with open(self.config_file_name, "w", encoding="utf-8") as f:
yaml.dump(cfg, f, indent=4, allow_unicode=True)
def save_sync(self, cfg: dict):
with open(self.config_file_name, "w", encoding="utf-8") as f:
yaml.dump(cfg, f, indent=4, allow_unicode=True)

View File

@@ -1,14 +1,22 @@
from __future__ import annotations from __future__ import annotations
from . import model as file_model from . import model as file_model
from .impls import pymodule, json as json_file from .impls import pymodule, json as json_file, yaml as yaml_file
managers: ConfigManager = []
class ConfigManager: class ConfigManager:
"""配置文件管理器""" """配置文件管理器"""
name: str = None
"""配置管理器名"""
description: str = None
"""配置管理器描述"""
schema: dict = None
"""配置文件 schema
需要符合 JSON Schema Draft 7 规范
"""
file: file_model.ConfigFile = None file: file_model.ConfigFile = None
"""配置文件实例""" """配置文件实例"""
@@ -16,6 +24,9 @@ class ConfigManager:
data: dict = None data: dict = None
"""配置数据""" """配置数据"""
doc_link: str = None
"""配置文件文档链接"""
def __init__(self, cfg_file: file_model.ConfigFile) -> None: def __init__(self, cfg_file: file_model.ConfigFile) -> None:
self.file = cfg_file self.file = cfg_file
self.data = {} self.data = {}
@@ -31,7 +42,16 @@ class ConfigManager:
async def load_python_module_config(config_name: str, template_name: str, completion: bool=True) -> ConfigManager: async def load_python_module_config(config_name: str, template_name: str, completion: bool=True) -> ConfigManager:
"""加载Python模块配置文件""" """加载Python模块配置文件
Args:
config_name (str): 配置文件名
template_name (str): 模板文件名
completion (bool): 是否自动补全内存中的配置文件
Returns:
ConfigManager: 配置文件管理器
"""
cfg_inst = pymodule.PythonModuleConfigFile( cfg_inst = pymodule.PythonModuleConfigFile(
config_name, config_name,
template_name template_name
@@ -44,7 +64,14 @@ async def load_python_module_config(config_name: str, template_name: str, comple
async def load_json_config(config_name: str, template_name: str=None, template_data: dict=None, completion: bool=True) -> ConfigManager: async def load_json_config(config_name: str, template_name: str=None, template_data: dict=None, completion: bool=True) -> ConfigManager:
"""加载JSON配置文件""" """加载JSON配置文件
Args:
config_name (str): 配置文件名
template_name (str): 模板文件名
template_data (dict): 模板数据
completion (bool): 是否自动补全内存中的配置文件
"""
cfg_inst = json_file.JSONConfigFile( cfg_inst = json_file.JSONConfigFile(
config_name, config_name,
template_name, template_name,
@@ -54,4 +81,28 @@ async def load_json_config(config_name: str, template_name: str=None, template_d
cfg_mgr = ConfigManager(cfg_inst) cfg_mgr = ConfigManager(cfg_inst)
await cfg_mgr.load_config(completion=completion) await cfg_mgr.load_config(completion=completion)
return cfg_mgr return cfg_mgr
async def load_yaml_config(config_name: str, template_name: str=None, template_data: dict=None, completion: bool=True) -> ConfigManager:
"""加载YAML配置文件
Args:
config_name (str): 配置文件名
template_name (str): 模板文件名
template_data (dict): 模板数据
completion (bool): 是否自动补全内存中的配置文件
Returns:
ConfigManager: 配置文件管理器
"""
cfg_inst = yaml_file.YAMLConfigFile(
config_name,
template_name,
template_data
)
cfg_mgr = ConfigManager(cfg_inst)
await cfg_mgr.load_config(completion=completion)
return cfg_mgr

75
pkg/config/settings.py Normal file
View File

@@ -0,0 +1,75 @@
from __future__ import annotations
from . import manager as config_manager
from ..core import app
class SettingsManager:
"""设置管理器
保存、管理多个配置文件管理器
"""
ap: app.Application
managers: list[config_manager.ConfigManager] = []
"""配置文件管理器列表"""
def __init__(self, ap: app.Application) -> None:
self.ap = ap
self.managers = []
async def initialize(self) -> None:
pass
def register_manager(
self,
name: str,
description: str,
manager: config_manager.ConfigManager,
schema: dict=None,
doc_link: str=None,
) -> None:
"""注册配置管理器
Args:
name (str): 配置管理器名
description (str): 配置管理器描述
manager (ConfigManager): 配置管理器
schema (dict): 配置文件 schema符合 JSON Schema Draft 7 规范
"""
for m in self.managers:
if m.name == name:
raise ValueError(f'配置管理器名 {name} 已存在')
manager.name = name
manager.description = description
manager.schema = schema
manager.doc_link = doc_link
self.managers.append(manager)
def get_manager(self, name: str) -> config_manager.ConfigManager | None:
"""获取配置管理器
Args:
name (str): 配置管理器名
Returns:
ConfigManager: 配置管理器
"""
for m in self.managers:
if m.name == name:
return m
return None
def get_manager_list(self) -> list[config_manager.ConfigManager]:
"""获取配置管理器列表
Returns:
list[ConfigManager]: 配置管理器列表
"""
return self.managers

View File

@@ -2,7 +2,11 @@ from __future__ import annotations
import logging import logging
import asyncio import asyncio
import threading
import traceback import traceback
import enum
import sys
import os
from ..platform import manager as im_mgr from ..platform import manager as im_mgr
from ..provider.session import sessionmgr as llm_session_mgr from ..provider.session import sessionmgr as llm_session_mgr
@@ -11,17 +15,29 @@ from ..provider.sysprompt import sysprompt as llm_prompt_mgr
from ..provider.tools import toolmgr as llm_tool_mgr from ..provider.tools import toolmgr as llm_tool_mgr
from ..provider import runnermgr from ..provider import runnermgr
from ..config import manager as config_mgr from ..config import manager as config_mgr
from ..config import settings as settings_mgr
from ..audit.center import v2 as center_mgr from ..audit.center import v2 as center_mgr
from ..command import cmdmgr from ..command import cmdmgr
from ..plugin import manager as plugin_mgr from ..plugin import manager as plugin_mgr
from ..pipeline import pool from ..pipeline import pool
from ..pipeline import controller, stagemgr from ..pipeline import controller, stagemgr
from ..utils import version as version_mgr, proxy as proxy_mgr, announce as announce_mgr from ..utils import version as version_mgr, proxy as proxy_mgr, announce as announce_mgr
from ..persistence import mgr as persistencemgr
from ..api.http.controller import main as http_controller
from ..api.http.service import user as user_service
from ..utils import logcache, ip
from . import taskmgr
from . import entities as core_entities
class Application: class Application:
"""运行时应用对象和上下文""" """运行时应用对象和上下文"""
event_loop: asyncio.AbstractEventLoop = None
# asyncio_tasks: list[asyncio.Task] = []
task_mgr: taskmgr.AsyncTaskManager = None
platform_mgr: im_mgr.PlatformManager = None platform_mgr: im_mgr.PlatformManager = None
cmd_mgr: cmdmgr.CommandManager = None cmd_mgr: cmdmgr.CommandManager = None
@@ -36,6 +52,8 @@ class Application:
runner_mgr: runnermgr.RunnerManager = None runner_mgr: runnermgr.RunnerManager = None
settings_mgr: settings_mgr.SettingsManager = None
# ======= 配置管理器 ======= # ======= 配置管理器 =======
command_cfg: config_mgr.ConfigManager = None command_cfg: config_mgr.ConfigManager = None
@@ -58,6 +76,8 @@ class Application:
llm_models_meta: config_mgr.ConfigManager = None llm_models_meta: config_mgr.ConfigManager = None
instance_secret_meta: config_mgr.ConfigManager = None
# ========================= # =========================
ctr_mgr: center_mgr.V2CenterAPI = None ctr_mgr: center_mgr.V2CenterAPI = None
@@ -78,6 +98,16 @@ class Application:
logger: logging.Logger = None logger: logging.Logger = None
persistence_mgr: persistencemgr.PersistenceManager = None
http_ctrl: http_controller.HTTPController = None
log_cache: logcache.LogCache = None
# ========= HTTP Services =========
user_service: user_service.UserService = None
def __init__(self): def __init__(self):
pass pass
@@ -85,34 +115,89 @@ class Application:
pass pass
async def run(self): async def run(self):
await self.plugin_mgr.initialize_plugins()
tasks = []
try: try:
await self.plugin_mgr.initialize_plugins()
tasks = [ # 后续可能会允许动态重启其他任务
asyncio.create_task(self.platform_mgr.run()), # 故为了防止程序在非 Ctrl-C 情况下退出,这里创建一个不会结束的协程
asyncio.create_task(self.ctrl.run()) async def never_ending():
] while True:
await asyncio.sleep(1)
# 挂信号处理 self.task_mgr.create_task(self.platform_mgr.run(), name="platform-manager", scopes=[core_entities.LifecycleControlScope.APPLICATION, core_entities.LifecycleControlScope.PLATFORM])
self.task_mgr.create_task(self.ctrl.run(), name="query-controller", scopes=[core_entities.LifecycleControlScope.APPLICATION])
import signal self.task_mgr.create_task(self.http_ctrl.run(), name="http-api-controller", scopes=[core_entities.LifecycleControlScope.APPLICATION])
self.task_mgr.create_task(never_ending(), name="never-ending-task", scopes=[core_entities.LifecycleControlScope.APPLICATION])
def signal_handler(sig, frame):
for task in tasks:
task.cancel()
self.logger.info("程序退出.")
exit(0)
signal.signal(signal.SIGINT, signal_handler)
await asyncio.gather(*tasks, return_exceptions=True)
await self.print_web_access_info()
await self.task_mgr.wait_all()
except asyncio.CancelledError: except asyncio.CancelledError:
pass pass
except Exception as e: except Exception as e:
self.logger.error(f"应用运行致命异常: {e}") self.logger.error(f"应用运行致命异常: {e}")
self.logger.debug(f"Traceback: {traceback.format_exc()}") self.logger.debug(f"Traceback: {traceback.format_exc()}")
async def print_web_access_info(self):
"""打印访问 webui 的提示"""
if not os.path.exists(os.path.join(".", "web/dist")):
self.logger.warning("WebUI 文件缺失请根据文档获取https://docs.langbot.app/webui/intro.html")
return
import socket
host_ip = socket.gethostbyname(socket.gethostname())
public_ip = await ip.get_myip()
port = self.system_cfg.data['http-api']['port']
tips = f"""
=======================================
✨ 您可通过以下方式访问管理面板
🏠 本地地址http://{host_ip}:{port}/
🌐 公网地址http://{public_ip}:{port}/
📌 如果您在容器中运行此程序,请确保容器的 {port} 端口已对外暴露
🔗 若要使用公网地址访问,请阅读以下须知
1. 公网地址仅供参考,请以您的主机公网 IP 为准;
2. 要使用公网地址访问,请确保您的主机具有公网 IP并且系统防火墙已放行 {port} 端口;
🤯 WebUI 仍处于 Beta 测试阶段,如有问题或建议请反馈到 https://github.com/RockChinQ/LangBot/issues
=======================================
""".strip()
for line in tips.split("\n"):
self.logger.info(line)
async def reload(
self,
scope: core_entities.LifecycleControlScope,
):
match scope:
case core_entities.LifecycleControlScope.PLATFORM.value:
self.logger.info("执行热重载 scope="+scope)
await self.platform_mgr.shutdown()
self.platform_mgr = im_mgr.PlatformManager(self)
await self.platform_mgr.initialize()
self.task_mgr.create_task(self.platform_mgr.run(), name="platform-manager", scopes=[core_entities.LifecycleControlScope.APPLICATION, core_entities.LifecycleControlScope.PLATFORM])
case core_entities.LifecycleControlScope.PLUGIN.value:
self.logger.info("执行热重载 scope="+scope)
await self.plugin_mgr.destroy_plugins()
# 删除 sys.module 中所有的 plugins/* 下的模块
for mod in list(sys.modules.keys()):
if mod.startswith("plugins."):
del sys.modules[mod]
self.plugin_mgr = plugin_mgr.PluginManager(self)
await self.plugin_mgr.initialize()
await self.plugin_mgr.initialize_plugins()
await self.plugin_mgr.load_plugins()
await self.plugin_mgr.initialize_plugins()
case _:
pass

View File

@@ -1,10 +1,13 @@
from __future__ import print_function from __future__ import print_function
import traceback import traceback
import asyncio
import os
from . import app from . import app
from ..audit import identifier from ..audit import identifier
from . import stage from . import stage
from ..utils import constants
# 引入启动阶段实现以便注册 # 引入启动阶段实现以便注册
from .stages import load_config, setup_logger, build_app, migrate, show_notes from .stages import load_config, setup_logger, build_app, migrate, show_notes
@@ -19,13 +22,19 @@ stage_order = [
] ]
async def make_app() -> app.Application: async def make_app(loop: asyncio.AbstractEventLoop) -> app.Application:
# 生成标识符 # 生成标识符
identifier.init() identifier.init()
# 确定是否为调试模式
if "DEBUG" in os.environ and os.environ["DEBUG"] in ["true", "1"]:
constants.debug_mode = True
ap = app.Application() ap = app.Application()
ap.event_loop = loop
# 执行启动阶段 # 执行启动阶段
for stage_name in stage_order: for stage_name in stage_order:
stage_cls = stage.preregistered_stages[stage_name] stage_cls = stage.preregistered_stages[stage_name]
@@ -38,9 +47,23 @@ async def make_app() -> app.Application:
return ap return ap
async def main(): async def main(loop: asyncio.AbstractEventLoop):
try: try:
app_inst = await make_app()
# 挂系统信号处理
import signal
ap: app.Application
def signal_handler(sig, frame):
print("[Signal] 程序退出.")
# ap.shutdown()
os._exit(0)
signal.signal(signal.SIGINT, signal_handler)
app_inst = await make_app(loop)
ap = app_inst
await app_inst.run() await app_inst.run()
except Exception as e: except Exception as e:
traceback.print_exc() traceback.print_exc()

View File

@@ -5,9 +5,8 @@ required_deps = {
"openai": "openai", "openai": "openai",
"anthropic": "anthropic", "anthropic": "anthropic",
"colorlog": "colorlog", "colorlog": "colorlog",
"mirai": "yiri-mirai-rc",
"aiocqhttp": "aiocqhttp", "aiocqhttp": "aiocqhttp",
"botpy": "qq-botpy", "botpy": "qq-botpy-rc",
"PIL": "pillow", "PIL": "pillow",
"nakuru": "nakuru-project-idk", "nakuru": "nakuru-project-idk",
"tiktoken": "tiktoken", "tiktoken": "tiktoken",
@@ -15,6 +14,15 @@ required_deps = {
"aiohttp": "aiohttp", "aiohttp": "aiohttp",
"psutil": "psutil", "psutil": "psutil",
"async_lru": "async-lru", "async_lru": "async-lru",
"ollama": "ollama",
"quart": "quart",
"quart_cors": "quart-cors",
"sqlalchemy": "sqlalchemy[asyncio]",
"aiosqlite": "aiosqlite",
"aiofiles": "aiofiles",
"aioshutil": "aioshutil",
"argon2": "argon2-cffi",
"jwt": "pyjwt",
} }

View File

@@ -5,6 +5,8 @@ import time
import colorlog import colorlog
from ...utils import constants
log_colors_config = { log_colors_config = {
"DEBUG": "green", # cyan white "DEBUG": "green", # cyan white
@@ -15,18 +17,18 @@ log_colors_config = {
} }
async def init_logging() -> logging.Logger: async def init_logging(extra_handlers: list[logging.Handler] = None) -> logging.Logger:
# 删除所有现有的logger # 删除所有现有的logger
for handler in logging.root.handlers[:]: for handler in logging.root.handlers[:]:
logging.root.removeHandler(handler) logging.root.removeHandler(handler)
level = logging.INFO level = logging.INFO
if "DEBUG" in os.environ and os.environ["DEBUG"] in ["true", "1"]: if constants.debug_mode:
level = logging.DEBUG level = logging.DEBUG
log_file_name = "data/logs/qcg-%s.log" % time.strftime( log_file_name = "data/logs/langbot-%s.log" % time.strftime(
"%Y-%m-%d-%H-%M-%S", time.localtime() "%Y-%m-%d", time.localtime()
) )
qcg_logger = logging.getLogger("qcg") qcg_logger = logging.getLogger("qcg")
@@ -34,14 +36,18 @@ async def init_logging() -> logging.Logger:
qcg_logger.setLevel(level) qcg_logger.setLevel(level)
color_formatter = colorlog.ColoredFormatter( color_formatter = colorlog.ColoredFormatter(
fmt="%(log_color)s[%(asctime)s.%(msecs)03d] %(pathname)s (%(lineno)d) - [%(levelname)s] :\n %(message)s", fmt="%(log_color)s[%(asctime)s.%(msecs)03d] %(filename)s (%(lineno)d) - [%(levelname)s] : %(message)s",
datefmt="%Y-%m-%d %H:%M:%S", datefmt="%m-%d %H:%M:%S",
log_colors=log_colors_config, log_colors=log_colors_config,
) )
stream_handler = logging.StreamHandler(sys.stdout) stream_handler = logging.StreamHandler(sys.stdout)
# stream_handler.setLevel(level)
# stream_handler.setFormatter(color_formatter)
stream_handler.stream = open(sys.stdout.fileno(), mode='w', encoding='utf-8', buffering=1)
log_handlers: logging.Handler = [stream_handler, logging.FileHandler(log_file_name)] log_handlers: list[logging.Handler] = [stream_handler, logging.FileHandler(log_file_name, encoding='utf-8')]
log_handlers += extra_handlers if extra_handlers is not None else []
for handler in log_handlers: for handler in log_handlers:
handler.setLevel(level) handler.setLevel(level)

View File

@@ -5,14 +5,24 @@ import typing
import datetime import datetime
import asyncio import asyncio
import pydantic import pydantic.v1 as pydantic
import mirai
from ..provider import entities as llm_entities from ..provider import entities as llm_entities
from ..provider.modelmgr import entities from ..provider.modelmgr import entities
from ..provider.sysprompt import entities as sysprompt_entities from ..provider.sysprompt import entities as sysprompt_entities
from ..provider.tools import entities as tools_entities from ..provider.tools import entities as tools_entities
from ..platform import adapter as msadapter from ..platform import adapter as msadapter
from ..platform.types import message as platform_message
from ..platform.types import events as platform_events
from ..platform.types import entities as platform_entities
class LifecycleControlScope(enum.Enum):
APPLICATION = "application"
PLATFORM = "platform"
PLUGIN = "plugin"
class LauncherTypes(enum.Enum): class LauncherTypes(enum.Enum):
@@ -40,10 +50,10 @@ class Query(pydantic.BaseModel):
sender_id: int sender_id: int
"""发送者IDplatform处理阶段设置""" """发送者IDplatform处理阶段设置"""
message_event: mirai.MessageEvent message_event: platform_events.MessageEvent
"""事件platform收到的原始事件""" """事件platform收到的原始事件"""
message_chain: mirai.MessageChain message_chain: platform_message.MessageChain
"""消息链platform收到的原始消息链""" """消息链platform收到的原始消息链"""
adapter: msadapter.MessageSourceAdapter adapter: msadapter.MessageSourceAdapter
@@ -67,10 +77,10 @@ class Query(pydantic.BaseModel):
use_funcs: typing.Optional[list[tools_entities.LLMFunction]] = None use_funcs: typing.Optional[list[tools_entities.LLMFunction]] = None
"""使用的函数,由前置处理器阶段设置""" """使用的函数,由前置处理器阶段设置"""
resp_messages: typing.Optional[list[llm_entities.Message]] | typing.Optional[list[mirai.MessageChain]] = [] resp_messages: typing.Optional[list[llm_entities.Message]] | typing.Optional[list[platform_message.MessageChain]] = []
"""由Process阶段生成的回复消息对象列表""" """由Process阶段生成的回复消息对象列表"""
resp_message_chain: typing.Optional[list[mirai.MessageChain]] = None resp_message_chain: typing.Optional[list[platform_message.MessageChain]] = None
"""回复消息链从resp_messages包装而得""" """回复消息链从resp_messages包装而得"""
# ======= 内部保留 ======= # ======= 内部保留 =======
@@ -108,7 +118,7 @@ class Session(pydantic.BaseModel):
using_conversation: typing.Optional[Conversation] = None using_conversation: typing.Optional[Conversation] = None
conversations: typing.Optional[list[Conversation]] = [] conversations: typing.Optional[list[Conversation]] = pydantic.Field(default_factory=list)
create_time: typing.Optional[datetime.datetime] = pydantic.Field(default_factory=datetime.datetime.now) create_time: typing.Optional[datetime.datetime] = pydantic.Field(default_factory=datetime.datetime.now)

View File

@@ -0,0 +1,31 @@
from __future__ import annotations
from .. import migration
@migration.migration_class("http-api-config", 13)
class HttpApiConfigMigration(migration.Migration):
"""迁移"""
async def need_migrate(self) -> bool:
"""判断当前环境是否需要运行此迁移"""
return 'http-api' not in self.ap.system_cfg.data or "persistence" not in self.ap.system_cfg.data
async def run(self):
"""执行迁移"""
self.ap.system_cfg.data['http-api'] = {
"enable": True,
"host": "0.0.0.0",
"port": 5300,
"jwt-expire": 604800
}
self.ap.system_cfg.data['persistence'] = {
"sqlite": {
"path": "data/persistence.db"
},
"use": "sqlite"
}
await self.ap.system_cfg.dump_config()

View File

@@ -0,0 +1,22 @@
from __future__ import annotations
from .. import migration
@migration.migration_class("force-delay-config", 14)
class ForceDelayConfigMigration(migration.Migration):
"""迁移"""
async def need_migrate(self) -> bool:
"""判断当前环境是否需要运行此迁移"""
return type(self.ap.platform_cfg.data['force-delay']) == list
async def run(self):
"""执行迁移"""
self.ap.platform_cfg.data['force-delay'] = {
"min": self.ap.platform_cfg.data['force-delay'][0],
"max": self.ap.platform_cfg.data['force-delay'][1]
}
await self.ap.platform_cfg.dump_config()

View File

@@ -0,0 +1,26 @@
from __future__ import annotations
from .. import migration
@migration.migration_class("gitee-ai-config", 15)
class GiteeAIConfigMigration(migration.Migration):
"""迁移"""
async def need_migrate(self) -> bool:
"""判断当前环境是否需要运行此迁移"""
return 'gitee-ai-chat-completions' not in self.ap.provider_cfg.data['requester'] or 'gitee-ai' not in self.ap.provider_cfg.data['keys']
async def run(self):
"""执行迁移"""
self.ap.provider_cfg.data['requester']['gitee-ai-chat-completions'] = {
"base-url": "https://ai.gitee.com/v1",
"args": {},
"timeout": 120
}
self.ap.provider_cfg.data['keys']['gitee-ai'] = [
"XXXXX"
]
await self.ap.provider_cfg.dump_config()

View File

@@ -15,6 +15,12 @@ from ...provider.sysprompt import sysprompt as llm_prompt_mgr
from ...provider.tools import toolmgr as llm_tool_mgr from ...provider.tools import toolmgr as llm_tool_mgr
from ...provider import runnermgr from ...provider import runnermgr
from ...platform import manager as im_mgr from ...platform import manager as im_mgr
from ...persistence import mgr as persistencemgr
from ...api.http.controller import main as http_controller
from ...api.http.service import user as user_service
from ...utils import logcache
from .. import taskmgr
@stage.stage_class("BuildAppStage") @stage.stage_class("BuildAppStage")
class BuildAppStage(stage.BootingStage): class BuildAppStage(stage.BootingStage):
@@ -24,6 +30,7 @@ class BuildAppStage(stage.BootingStage):
async def run(self, ap: app.Application): async def run(self, ap: app.Application):
"""构建app对象的各个组件对象并初始化 """构建app对象的各个组件对象并初始化
""" """
ap.task_mgr = taskmgr.AsyncTaskManager(ap)
proxy_mgr = proxy.ProxyManager(ap) proxy_mgr = proxy.ProxyManager(ap)
await proxy_mgr.initialize() await proxy_mgr.initialize()
@@ -58,6 +65,13 @@ class BuildAppStage(stage.BootingStage):
ap.query_pool = pool.QueryPool() ap.query_pool = pool.QueryPool()
log_cache = logcache.LogCache()
ap.log_cache = log_cache
persistence_mgr_inst = persistencemgr.PersistenceManager(ap)
await persistence_mgr_inst.initialize()
ap.persistence_mgr = persistence_mgr_inst
plugin_mgr_inst = plugin_mgr.PluginManager(ap) plugin_mgr_inst = plugin_mgr.PluginManager(ap)
await plugin_mgr_inst.initialize() await plugin_mgr_inst.initialize()
ap.plugin_mgr = plugin_mgr_inst ap.plugin_mgr = plugin_mgr_inst
@@ -95,6 +109,12 @@ class BuildAppStage(stage.BootingStage):
await stage_mgr.initialize() await stage_mgr.initialize()
ap.stage_mgr = stage_mgr ap.stage_mgr = stage_mgr
http_ctrl = http_controller.HTTPController(ap)
await http_ctrl.initialize()
ap.http_ctrl = http_ctrl
user_service_inst = user_service.UserService(ap)
ap.user_service = user_service_inst
ctrl = controller.Controller(ap) ctrl = controller.Controller(ap)
ap.ctrl = ctrl ap.ctrl = ctrl

View File

@@ -1,7 +1,11 @@
from __future__ import annotations from __future__ import annotations
import secrets
from .. import stage, app from .. import stage, app
from ..bootutils import config from ..bootutils import config
from ...config import settings as settings_mgr
from ...utils import schema
@stage.stage_class("LoadConfigStage") @stage.stage_class("LoadConfigStage")
@@ -12,12 +16,56 @@ class LoadConfigStage(stage.BootingStage):
async def run(self, ap: app.Application): async def run(self, ap: app.Application):
"""启动 """启动
""" """
ap.settings_mgr = settings_mgr.SettingsManager(ap)
await ap.settings_mgr.initialize()
ap.command_cfg = await config.load_json_config("data/config/command.json", "templates/command.json", completion=False) ap.command_cfg = await config.load_json_config("data/config/command.json", "templates/command.json", completion=False)
ap.pipeline_cfg = await config.load_json_config("data/config/pipeline.json", "templates/pipeline.json", completion=False) ap.pipeline_cfg = await config.load_json_config("data/config/pipeline.json", "templates/pipeline.json", completion=False)
ap.platform_cfg = await config.load_json_config("data/config/platform.json", "templates/platform.json", completion=False) ap.platform_cfg = await config.load_json_config("data/config/platform.json", "templates/platform.json", completion=False)
ap.provider_cfg = await config.load_json_config("data/config/provider.json", "templates/provider.json", completion=False) ap.provider_cfg = await config.load_json_config("data/config/provider.json", "templates/provider.json", completion=False)
ap.system_cfg = await config.load_json_config("data/config/system.json", "templates/system.json", completion=False) ap.system_cfg = await config.load_json_config("data/config/system.json", "templates/system.json", completion=False)
ap.settings_mgr.register_manager(
name="command.json",
description="命令配置",
manager=ap.command_cfg,
schema=schema.CONFIG_COMMAND_SCHEMA,
doc_link="https://docs.langbot.app/config/function/command.html"
)
ap.settings_mgr.register_manager(
name="pipeline.json",
description="消息处理流水线配置",
manager=ap.pipeline_cfg,
schema=schema.CONFIG_PIPELINE_SCHEMA,
doc_link="https://docs.langbot.app/config/function/pipeline.html"
)
ap.settings_mgr.register_manager(
name="platform.json",
description="消息平台配置",
manager=ap.platform_cfg,
schema=schema.CONFIG_PLATFORM_SCHEMA,
doc_link="https://docs.langbot.app/config/function/platform.html"
)
ap.settings_mgr.register_manager(
name="provider.json",
description="大模型能力配置",
manager=ap.provider_cfg,
schema=schema.CONFIG_PROVIDER_SCHEMA,
doc_link="https://docs.langbot.app/config/function/provider.html"
)
ap.settings_mgr.register_manager(
name="system.json",
description="系统配置",
manager=ap.system_cfg,
schema=schema.CONFIG_SYSTEM_SCHEMA,
doc_link="https://docs.langbot.app/config/function/system.html"
)
ap.plugin_setting_meta = await config.load_json_config("plugins/plugins.json", "templates/plugin-settings.json") ap.plugin_setting_meta = await config.load_json_config("plugins/plugins.json", "templates/plugin-settings.json")
await ap.plugin_setting_meta.dump_config() await ap.plugin_setting_meta.dump_config()
@@ -29,3 +77,8 @@ class LoadConfigStage(stage.BootingStage):
ap.llm_models_meta = await config.load_json_config("data/metadata/llm-models.json", "templates/metadata/llm-models.json") ap.llm_models_meta = await config.load_json_config("data/metadata/llm-models.json", "templates/metadata/llm-models.json")
await ap.llm_models_meta.dump_config() await ap.llm_models_meta.dump_config()
ap.instance_secret_meta = await config.load_json_config("data/metadata/instance-secret.json", template_data={
'jwt_secret': secrets.token_hex(16)
})
await ap.instance_secret_meta.dump_config()

View File

@@ -6,7 +6,8 @@ from .. import stage, app
from .. import migration from .. import migration
from ..migrations import m001_sensitive_word_migration, m002_openai_config_migration, m003_anthropic_requester_cfg_completion, m004_moonshot_cfg_completion from ..migrations import m001_sensitive_word_migration, m002_openai_config_migration, m003_anthropic_requester_cfg_completion, m004_moonshot_cfg_completion
from ..migrations import m005_deepseek_cfg_completion, m006_vision_config, m007_qcg_center_url, m008_ad_fixwin_config_migrate, m009_msg_truncator_cfg from ..migrations import m005_deepseek_cfg_completion, m006_vision_config, m007_qcg_center_url, m008_ad_fixwin_config_migrate, m009_msg_truncator_cfg
from ..migrations import m010_ollama_requester_config, m011_command_prefix_config, m012_runner_config from ..migrations import m010_ollama_requester_config, m011_command_prefix_config, m012_runner_config, m013_http_api_config, m014_force_delay_config
from ..migrations import m015_gitee_ai_config
@stage.stage_class("MigrationStage") @stage.stage_class("MigrationStage")
@@ -28,3 +29,4 @@ class MigrationStage(stage.BootingStage):
if await migration_instance.need_migrate(): if await migration_instance.need_migrate():
await migration_instance.run() await migration_instance.run()
print(f'已执行迁移 {migration_instance.name}')

View File

@@ -1,9 +1,38 @@
from __future__ import annotations from __future__ import annotations
import logging
import asyncio
from datetime import datetime
from .. import stage, app from .. import stage, app
from ..bootutils import log from ..bootutils import log
class PersistenceHandler(logging.Handler, object):
"""
保存日志到数据库
"""
ap: app.Application
def __init__(self, name, ap: app.Application):
logging.Handler.__init__(self)
self.ap = ap
def emit(self, record):
"""
emit函数为自定义handler类时必重写的函数这里可以根据需要对日志消息做一些处理比如发送日志到服务器
发出记录(Emit a record)
"""
try:
msg = self.format(record)
if self.ap.log_cache is not None:
self.ap.log_cache.add_log(msg)
except Exception:
self.handleError(record)
@stage.stage_class("SetupLoggerStage") @stage.stage_class("SetupLoggerStage")
class SetupLoggerStage(stage.BootingStage): class SetupLoggerStage(stage.BootingStage):
"""设置日志器阶段 """设置日志器阶段
@@ -12,4 +41,9 @@ class SetupLoggerStage(stage.BootingStage):
async def run(self, ap: app.Application): async def run(self, ap: app.Application):
"""启动 """启动
""" """
ap.logger = await log.init_logging() persistence_handler = PersistenceHandler('LoggerHandler', ap)
extra_handlers = []
extra_handlers = [persistence_handler]
ap.logger = await log.init_logging(extra_handlers)

235
pkg/core/taskmgr.py Normal file
View File

@@ -0,0 +1,235 @@
from __future__ import annotations
import asyncio
import typing
import datetime
import traceback
from . import app
from . import entities as core_entities
class TaskContext:
"""任务跟踪上下文"""
current_action: str
"""当前正在执行的动作"""
log: str
"""记录日志"""
def __init__(self):
self.current_action = "default"
self.log = ""
def _log(self, msg: str):
self.log += msg + "\n"
def set_current_action(self, action: str):
self.current_action = action
def trace(
self,
msg: str,
action: str = None,
):
if action is not None:
self.set_current_action(action)
self._log(
f"{datetime.datetime.now().strftime('%Y-%m-%d %H:%M:%S')} | {self.current_action} | {msg}"
)
def to_dict(self) -> dict:
return {"current_action": self.current_action, "log": self.log}
@staticmethod
def new() -> TaskContext:
return TaskContext()
@staticmethod
def placeholder() -> TaskContext:
global placeholder_context
if placeholder_context is None:
placeholder_context = TaskContext()
return placeholder_context
placeholder_context: TaskContext | None = None
class TaskWrapper:
"""任务包装器"""
_id_index: int = 0
"""任务ID索引"""
id: int
"""任务ID"""
task_type: str = "system" # 任务类型: system 或 user
"""任务类型"""
kind: str = "system_task" # 由发起者确定任务种类,通常同质化的任务种类相同
"""任务种类"""
name: str = ""
"""任务唯一名称"""
label: str = ""
"""任务显示名称"""
task_context: TaskContext
"""任务上下文"""
task: asyncio.Task
"""任务"""
task_stack: list = None
"""任务堆栈"""
ap: app.Application
"""应用实例"""
scopes: list[core_entities.LifecycleControlScope]
"""任务所属生命周期控制范围"""
def __init__(
self,
ap: app.Application,
coro: typing.Coroutine,
task_type: str = "system",
kind: str = "system_task",
name: str = "",
label: str = "",
context: TaskContext = None,
scopes: list[core_entities.LifecycleControlScope] = [core_entities.LifecycleControlScope.APPLICATION],
):
self.id = TaskWrapper._id_index
TaskWrapper._id_index += 1
self.ap = ap
self.task_context = context or TaskContext()
self.task = self.ap.event_loop.create_task(coro)
self.task_type = task_type
self.kind = kind
self.name = name
self.label = label if label != "" else name
self.task.set_name(name)
self.scopes = scopes
def assume_exception(self):
try:
exception = self.task.exception()
if self.task_stack is None:
self.task_stack = self.task.get_stack()
return exception
except:
return None
def assume_result(self):
try:
return self.task.result()
except:
return None
def to_dict(self) -> dict:
exception_traceback = None
if self.assume_exception() is not None:
exception_traceback = 'Traceback (most recent call last):\n'
for frame in self.task_stack:
exception_traceback += f" File \"{frame.f_code.co_filename}\", line {frame.f_lineno}, in {frame.f_code.co_name}\n"
exception_traceback += f" {self.assume_exception().__str__()}\n"
return {
"id": self.id,
"task_type": self.task_type,
"kind": self.kind,
"name": self.name,
"label": self.label,
"scopes": [scope.value for scope in self.scopes],
"task_context": self.task_context.to_dict(),
"runtime": {
"done": self.task.done(),
"state": self.task._state,
"exception": self.assume_exception().__str__() if self.assume_exception() is not None else None,
"exception_traceback": exception_traceback,
"result": self.assume_result().__str__() if self.assume_result() is not None else None,
},
}
def cancel(self):
self.task.cancel()
class AsyncTaskManager:
"""保存app中的所有异步任务
包含系统级的和用户级(插件安装、更新等由用户直接发起的)的"""
ap: app.Application
tasks: list[TaskWrapper]
"""所有任务"""
def __init__(self, ap: app.Application):
self.ap = ap
self.tasks = []
def create_task(
self,
coro: typing.Coroutine,
task_type: str = "system",
kind: str = "system-task",
name: str = "",
label: str = "",
context: TaskContext = None,
scopes: list[core_entities.LifecycleControlScope] = [core_entities.LifecycleControlScope.APPLICATION],
) -> TaskWrapper:
wrapper = TaskWrapper(self.ap, coro, task_type, kind, name, label, context, scopes)
self.tasks.append(wrapper)
return wrapper
def create_user_task(
self,
coro: typing.Coroutine,
kind: str = "user-task",
name: str = "",
label: str = "",
context: TaskContext = None,
scopes: list[core_entities.LifecycleControlScope] = [core_entities.LifecycleControlScope.APPLICATION],
) -> TaskWrapper:
return self.create_task(coro, "user", kind, name, label, context, scopes)
async def wait_all(self):
await asyncio.gather(*[t.task for t in self.tasks], return_exceptions=True)
def get_all_tasks(self) -> list[TaskWrapper]:
return self.tasks
def get_tasks_dict(
self,
type: str = None,
) -> dict:
return {
"tasks": [
t.to_dict() for t in self.tasks if type is None or t.task_type == type
],
"id_index": TaskWrapper._id_index,
}
def get_task_by_id(self, id: int) -> TaskWrapper | None:
for t in self.tasks:
if t.id == id:
return t
return None
def cancel_by_scope(self, scope: core_entities.LifecycleControlScope):
for wrapper in self.tasks:
if not wrapper.task.done() and scope in wrapper.scopes:
wrapper.task.cancel()

View File

View File

@@ -0,0 +1,40 @@
from __future__ import annotations
import abc
import sqlalchemy.ext.asyncio as sqlalchemy_asyncio
from ..core import app
preregistered_managers: list[type[BaseDatabaseManager]] = []
def manager_class(name: str) -> None:
"""注册一个数据库管理类"""
def decorator(cls: type[BaseDatabaseManager]) -> type[BaseDatabaseManager]:
cls.name = name
preregistered_managers.append(cls)
return cls
return decorator
class BaseDatabaseManager(abc.ABC):
"""基础数据库管理类"""
name: str
ap: app.Application
engine: sqlalchemy_asyncio.AsyncEngine
def __init__(self, ap: app.Application) -> None:
self.ap = ap
@abc.abstractmethod
async def initialize(self) -> None:
pass
def get_engine(self) -> sqlalchemy_asyncio.AsyncEngine:
return self.engine

View File

View File

@@ -0,0 +1,13 @@
from __future__ import annotations
import sqlalchemy.ext.asyncio as sqlalchemy_asyncio
from .. import database
@database.manager_class("sqlite")
class SQLiteDatabaseManager(database.BaseDatabaseManager):
"""SQLite 数据库管理类"""
async def initialize(self) -> None:
self.engine = sqlalchemy_asyncio.create_async_engine(f"sqlite+aiosqlite:///{self.ap.system_cfg.data['persistence']['sqlite']['path']}")

View File

View File

@@ -0,0 +1,5 @@
import sqlalchemy.orm
class Base(sqlalchemy.orm.DeclarativeBase):
pass

View File

@@ -0,0 +1,11 @@
import sqlalchemy
from .base import Base
class User(Base):
__tablename__ = 'users'
id = sqlalchemy.Column(sqlalchemy.Integer, primary_key=True)
user = sqlalchemy.Column(sqlalchemy.String(255), nullable=False)
password = sqlalchemy.Column(sqlalchemy.String(255), nullable=False)

57
pkg/persistence/mgr.py Normal file
View File

@@ -0,0 +1,57 @@
from __future__ import annotations
import asyncio
import datetime
import sqlalchemy.ext.asyncio as sqlalchemy_asyncio
import sqlalchemy
from . import database
from .entities import user, base
from ..core import app
from .databases import sqlite
class PersistenceManager:
"""持久化模块管理器"""
ap: app.Application
db: database.BaseDatabaseManager
"""数据库管理器"""
meta: sqlalchemy.MetaData
def __init__(self, ap: app.Application):
self.ap = ap
self.meta = base.Base.metadata
async def initialize(self):
for manager in database.preregistered_managers:
self.db = manager(self.ap)
await self.db.initialize()
await self.create_tables()
async def create_tables(self):
# TODO: 对扩展友好
# 日志
async with self.get_db_engine().connect() as conn:
await conn.run_sync(self.meta.create_all)
await conn.commit()
async def execute_async(
self,
*args,
**kwargs
) -> sqlalchemy.engine.cursor.CursorResult:
async with self.get_db_engine().connect() as conn:
result = await conn.execute(*args, **kwargs)
await conn.commit()
return result
def get_db_engine(self) -> sqlalchemy_asyncio.AsyncEngine:
return self.db.get_engine()

View File

@@ -1,7 +1,5 @@
from __future__ import annotations from __future__ import annotations
import mirai
from ...core import app from ...core import app
from .. import stage, entities, stagemgr from .. import stage, entities, stagemgr
@@ -10,6 +8,9 @@ from ...config import manager as cfg_mgr
from . import filter as filter_model, entities as filter_entities from . import filter as filter_model, entities as filter_entities
from .filters import cntignore, banwords, baiduexamine from .filters import cntignore, banwords, baiduexamine
from ...provider import entities as llm_entities from ...provider import entities as llm_entities
from ...platform.types import message as platform_message
from ...platform.types import events as platform_events
from ...platform.types import entities as platform_entities
@stage.stage_class('PostContentFilterStage') @stage.stage_class('PostContentFilterStage')
@@ -63,6 +64,7 @@ class ContentFilterStage(stage.PipelineStage):
"""请求llm前处理消息 """请求llm前处理消息
只要有一个不通过就不放行,只放行 PASS 的消息 只要有一个不通过就不放行,只放行 PASS 的消息
""" """
if not self.ap.pipeline_cfg.data['income-msg-check']: if not self.ap.pipeline_cfg.data['income-msg-check']:
return entities.StageProcessResult( return entities.StageProcessResult(
result_type=entities.ResultType.CONTINUE, result_type=entities.ResultType.CONTINUE,
@@ -86,8 +88,8 @@ class ContentFilterStage(stage.PipelineStage):
elif result.level == filter_entities.ResultLevel.PASS: # 传到下一个 elif result.level == filter_entities.ResultLevel.PASS: # 传到下一个
message = result.replacement message = result.replacement
query.message_chain = mirai.MessageChain( query.message_chain = platform_message.MessageChain(
mirai.Plain(message) platform_message.Plain(message)
) )
return entities.StageProcessResult( return entities.StageProcessResult(
@@ -145,11 +147,13 @@ class ContentFilterStage(stage.PipelineStage):
contain_non_text = False contain_non_text = False
text_components = [platform_message.Plain, platform_message.Source]
for me in query.message_chain: for me in query.message_chain:
if not isinstance(me, mirai.Plain): if type(me) not in text_components:
contain_non_text = True contain_non_text = True
break break
if contain_non_text: if contain_non_text:
self.ap.logger.debug(f"消息中包含非文本消息,跳过内容过滤器检查。") self.ap.logger.debug(f"消息中包含非文本消息,跳过内容过滤器检查。")
return entities.StageProcessResult( return entities.StageProcessResult(

View File

@@ -2,7 +2,7 @@
import typing import typing
import enum import enum
import pydantic import pydantic.v1 as pydantic
from ...provider import entities as llm_entities from ...provider import entities as llm_entities

View File

@@ -4,11 +4,10 @@ import asyncio
import typing import typing
import traceback import traceback
import mirai
from ..core import app, entities from ..core import app, entities
from . import entities as pipeline_entities from . import entities as pipeline_entities
from ..plugin import events from ..plugin import events
from ..platform.types import message as platform_message
class Controller: class Controller:
@@ -59,8 +58,13 @@ class Controller:
(await self.ap.sess_mgr.get_session(selected_query)).semaphore.release() (await self.ap.sess_mgr.get_session(selected_query)).semaphore.release()
# 通知其他协程,有新的请求可以处理了 # 通知其他协程,有新的请求可以处理了
self.ap.query_pool.condition.notify_all() self.ap.query_pool.condition.notify_all()
self.ap.task_mgr.create_task(
asyncio.create_task(_process_query(selected_query)) _process_query(selected_query),
kind="query",
name=f"query-{selected_query.query_id}",
scopes=[entities.LifecycleControlScope.APPLICATION, entities.LifecycleControlScope.PLATFORM],
)
except Exception as e: except Exception as e:
# traceback.print_exc() # traceback.print_exc()
self.ap.logger.error(f"控制器循环出错: {e}") self.ap.logger.error(f"控制器循环出错: {e}")
@@ -73,11 +77,11 @@ class Controller:
# 处理str类型 # 处理str类型
if isinstance(result.user_notice, str): if isinstance(result.user_notice, str):
result.user_notice = mirai.MessageChain( result.user_notice = platform_message.MessageChain(
mirai.Plain(result.user_notice) platform_message.Plain(result.user_notice)
) )
elif isinstance(result.user_notice, list): elif isinstance(result.user_notice, list):
result.user_notice = mirai.MessageChain( result.user_notice = platform_message.MessageChain(
*result.user_notice *result.user_notice
) )
@@ -159,6 +163,23 @@ class Controller:
async def process_query(self, query: entities.Query): async def process_query(self, query: entities.Query):
"""处理请求 """处理请求
""" """
# ======== 触发 MessageReceived 事件 ========
event_type = events.PersonMessageReceived if query.launcher_type == entities.LauncherTypes.PERSON else events.GroupMessageReceived
event_ctx = await self.ap.plugin_mgr.emit_event(
event=event_type(
launcher_type=query.launcher_type.value,
launcher_id=query.launcher_id,
sender_id=query.sender_id,
message_chain=query.message_chain,
query=query
)
)
if event_ctx.is_prevented_default():
return
self.ap.logger.debug(f"Processing query {query}") self.ap.logger.debug(f"Processing query {query}")
try: try:
@@ -166,7 +187,6 @@ class Controller:
except Exception as e: except Exception as e:
self.ap.logger.error(f"处理请求时出错 query_id={query.query_id} stage={query.current_stage.inst_name} : {e}") self.ap.logger.error(f"处理请求时出错 query_id={query.query_id} stage={query.current_stage.inst_name} : {e}")
self.ap.logger.debug(f"Traceback: {traceback.format_exc()}") self.ap.logger.debug(f"Traceback: {traceback.format_exc()}")
# traceback.print_exc()
finally: finally:
self.ap.logger.debug(f"Query {query} processed") self.ap.logger.debug(f"Query {query} processed")

View File

@@ -3,9 +3,8 @@ from __future__ import annotations
import enum import enum
import typing import typing
import pydantic import pydantic.v1 as pydantic
import mirai from ..platform.types import message as platform_message
import mirai.models.message as mirai_message
from ..core import entities from ..core import entities
@@ -25,13 +24,9 @@ class StageProcessResult(pydantic.BaseModel):
new_query: entities.Query new_query: entities.Query
user_notice: typing.Optional[typing.Union[str, list[mirai_message.MessageComponent], mirai.MessageChain, None]] = [] user_notice: typing.Optional[typing.Union[str, list[platform_message.MessageComponent], platform_message.MessageChain, None]] = []
"""只要设置了就会发送给用户""" """只要设置了就会发送给用户"""
# TODO delete
# admin_notice: typing.Optional[typing.Union[str, list[mirai_message.MessageComponent], mirai.MessageChain, None]] = []
"""只要设置了就会发送给管理员"""
console_notice: typing.Optional[str] = '' console_notice: typing.Optional[str] = ''
"""只要设置了就会输出到控制台""" """只要设置了就会输出到控制台"""

View File

@@ -3,7 +3,6 @@ import os
import traceback import traceback
from PIL import Image, ImageDraw, ImageFont from PIL import Image, ImageDraw, ImageFont
from mirai.models.message import MessageComponent, Plain, MessageChain
from ...core import app from ...core import app
from . import strategy from . import strategy
@@ -11,6 +10,7 @@ from .strategies import image, forward
from .. import stage, entities, stagemgr from .. import stage, entities, stagemgr
from ...core import entities as core_entities from ...core import entities as core_entities
from ...config import manager as cfg_mgr from ...config import manager as cfg_mgr
from ...platform.types import message as platform_message
@stage.stage_class("LongTextProcessStage") @stage.stage_class("LongTextProcessStage")
@@ -63,14 +63,14 @@ class LongTextProcessStage(stage.PipelineStage):
contains_non_plain = False contains_non_plain = False
for msg in query.resp_message_chain[-1]: for msg in query.resp_message_chain[-1]:
if not isinstance(msg, Plain): if not isinstance(msg, platform_message.Plain):
contains_non_plain = True contains_non_plain = True
break break
if contains_non_plain: if contains_non_plain:
self.ap.logger.debug("消息中包含非 Plain 组件,跳过长消息处理。") self.ap.logger.debug("消息中包含非 Plain 组件,跳过长消息处理。")
elif len(str(query.resp_message_chain[-1])) > self.ap.platform_cfg.data['long-text-process']['threshold']: elif len(str(query.resp_message_chain[-1])) > self.ap.platform_cfg.data['long-text-process']['threshold']:
query.resp_message_chain[-1] = MessageChain(await self.strategy_impl.process(str(query.resp_message_chain[-1]), query)) query.resp_message_chain[-1] = platform_message.MessageChain(await self.strategy_impl.process(str(query.resp_message_chain[-1]), query))
return entities.StageProcessResult( return entities.StageProcessResult(
result_type=entities.ResultType.CONTINUE, result_type=entities.ResultType.CONTINUE,

View File

@@ -2,15 +2,14 @@
from __future__ import annotations from __future__ import annotations
import typing import typing
from mirai.models import MessageChain import pydantic.v1 as pydantic
from mirai.models.message import MessageComponent, ForwardMessageNode
from mirai.models.base import MiraiBaseModel
from .. import strategy as strategy_model from .. import strategy as strategy_model
from ....core import entities as core_entities from ....core import entities as core_entities
from ....platform.types import message as platform_message
class ForwardMessageDiaplay(MiraiBaseModel): class ForwardMessageDiaplay(pydantic.BaseModel):
title: str = "群聊的聊天记录" title: str = "群聊的聊天记录"
brief: str = "[聊天记录]" brief: str = "[聊天记录]"
source: str = "聊天记录" source: str = "聊天记录"
@@ -18,13 +17,13 @@ class ForwardMessageDiaplay(MiraiBaseModel):
summary: str = "查看x条转发消息" summary: str = "查看x条转发消息"
class Forward(MessageComponent): class Forward(platform_message.MessageComponent):
"""合并转发。""" """合并转发。"""
type: str = "Forward" type: str = "Forward"
"""消息组件类型。""" """消息组件类型。"""
display: ForwardMessageDiaplay display: ForwardMessageDiaplay
"""显示信息""" """显示信息"""
node_list: typing.List[ForwardMessageNode] node_list: typing.List[platform_message.ForwardMessageNode]
"""转发消息节点列表。""" """转发消息节点列表。"""
def __init__(self, *args, **kwargs): def __init__(self, *args, **kwargs):
if len(args) == 1: if len(args) == 1:
@@ -39,7 +38,7 @@ class Forward(MessageComponent):
@strategy_model.strategy_class("forward") @strategy_model.strategy_class("forward")
class ForwardComponentStrategy(strategy_model.LongTextStrategy): class ForwardComponentStrategy(strategy_model.LongTextStrategy):
async def process(self, message: str, query: core_entities.Query) -> list[MessageComponent]: async def process(self, message: str, query: core_entities.Query) -> list[platform_message.MessageComponent]:
display = ForwardMessageDiaplay( display = ForwardMessageDiaplay(
title="群聊的聊天记录", title="群聊的聊天记录",
brief="[聊天记录]", brief="[聊天记录]",
@@ -49,10 +48,10 @@ class ForwardComponentStrategy(strategy_model.LongTextStrategy):
) )
node_list = [ node_list = [
ForwardMessageNode( platform_message.ForwardMessageNode(
sender_id=query.adapter.bot_account_id, sender_id=query.adapter.bot_account_id,
sender_name='QQ用户', sender_name='QQ用户',
message_chain=MessageChain([message]) message_chain=platform_message.MessageChain([message])
) )
] ]

View File

@@ -8,8 +8,7 @@ import re
from PIL import Image, ImageDraw, ImageFont from PIL import Image, ImageDraw, ImageFont
from mirai.models import MessageChain, Image as ImageComponent from ....platform.types import message as platform_message
from mirai.models.message import MessageComponent
from .. import strategy as strategy_model from .. import strategy as strategy_model
from ....core import entities as core_entities from ....core import entities as core_entities
@@ -23,7 +22,7 @@ class Text2ImageStrategy(strategy_model.LongTextStrategy):
async def initialize(self): async def initialize(self):
self.text_render_font = ImageFont.truetype(self.ap.platform_cfg.data['long-text-process']['font-path'], 32, encoding="utf-8") self.text_render_font = ImageFont.truetype(self.ap.platform_cfg.data['long-text-process']['font-path'], 32, encoding="utf-8")
async def process(self, message: str, query: core_entities.Query) -> list[MessageComponent]: async def process(self, message: str, query: core_entities.Query) -> list[platform_message.MessageComponent]:
img_path = self.text_to_image( img_path = self.text_to_image(
text_str=message, text_str=message,
save_as='temp/{}.png'.format(int(time.time())) save_as='temp/{}.png'.format(int(time.time()))
@@ -46,7 +45,7 @@ class Text2ImageStrategy(strategy_model.LongTextStrategy):
os.remove(compressed_path) os.remove(compressed_path)
return [ return [
ImageComponent( platform_message.Image(
base64=b64.decode('utf-8'), base64=b64.decode('utf-8'),
) )
] ]
@@ -59,7 +58,7 @@ class Text2ImageStrategy(strategy_model.LongTextStrategy):
""" """
kv = [] kv = []
nums = [] nums = []
beforeDatas = re.findall('[\d]+', path) beforeDatas = re.findall('[\\d]+', path)
for num in beforeDatas: for num in beforeDatas:
indexV = [] indexV = []
times = path.count(num) times = path.count(num)

View File

@@ -2,11 +2,10 @@ from __future__ import annotations
import abc import abc
import typing import typing
import mirai
from mirai.models.message import MessageComponent
from ...core import app from ...core import app
from ...core import entities as core_entities from ...core import entities as core_entities
from ...platform.types import message as platform_message
preregistered_strategies: list[typing.Type[LongTextStrategy]] = [] preregistered_strategies: list[typing.Type[LongTextStrategy]] = []
@@ -51,7 +50,7 @@ class LongTextStrategy(metaclass=abc.ABCMeta):
pass pass
@abc.abstractmethod @abc.abstractmethod
async def process(self, message: str, query: core_entities.Query) -> list[MessageComponent]: async def process(self, message: str, query: core_entities.Query) -> list[platform_message.MessageComponent]:
"""处理长文本 """处理长文本
在 platform.json 中配置 long-text-process 字段,只要 文本长度超过了 threshold 就会调用此方法 在 platform.json 中配置 long-text-process 字段,只要 文本长度超过了 threshold 就会调用此方法
@@ -61,6 +60,6 @@ class LongTextStrategy(metaclass=abc.ABCMeta):
query (core_entities.Query): 此次请求的上下文对象 query (core_entities.Query): 此次请求的上下文对象
Returns: Returns:
list[mirai.models.messages.MessageComponent]: 转换后的 YiriMirai 消息组件列表 list[platform_message.MessageComponent]: 转换后的 平台 消息组件列表
""" """
return [] return []

View File

@@ -2,10 +2,11 @@ from __future__ import annotations
import asyncio import asyncio
import mirai
from ..core import entities from ..core import entities
from ..platform import adapter as msadapter from ..platform import adapter as msadapter
from ..platform.types import message as platform_message
from ..platform.types import events as platform_events
class QueryPool: class QueryPool:
@@ -30,8 +31,8 @@ class QueryPool:
launcher_type: entities.LauncherTypes, launcher_type: entities.LauncherTypes,
launcher_id: int, launcher_id: int,
sender_id: int, sender_id: int,
message_event: mirai.MessageEvent, message_event: platform_events.MessageEvent,
message_chain: mirai.MessageChain, message_chain: platform_message.MessageChain,
adapter: msadapter.MessageSourceAdapter adapter: msadapter.MessageSourceAdapter
) -> entities.Query: ) -> entities.Query:
async with self.condition: async with self.condition:

View File

@@ -1,11 +1,11 @@
from __future__ import annotations from __future__ import annotations
import mirai
from .. import stage, entities, stagemgr from .. import stage, entities, stagemgr
from ...core import entities as core_entities from ...core import entities as core_entities
from ...provider import entities as llm_entities from ...provider import entities as llm_entities
from ...plugin import events from ...plugin import events
from ...platform.types import message as platform_message
@stage.stage_class("PreProcessor") @stage.stage_class("PreProcessor")
@@ -55,11 +55,11 @@ class PreProcessor(stage.PipelineStage):
content_list = [] content_list = []
for me in query.message_chain: for me in query.message_chain:
if isinstance(me, mirai.Plain): if isinstance(me, platform_message.Plain):
content_list.append( content_list.append(
llm_entities.ContentElement.from_text(me.text) llm_entities.ContentElement.from_text(me.text)
) )
elif isinstance(me, mirai.Image): elif isinstance(me, platform_message.Image):
if self.ap.provider_cfg.data['enable-vision'] and query.use_model.vision_supported: if self.ap.provider_cfg.data['enable-vision'] and query.use_model.vision_supported:
if me.url is not None: if me.url is not None:
content_list.append( content_list.append(

View File

@@ -5,7 +5,6 @@ import time
import traceback import traceback
import json import json
import mirai
from .. import handler from .. import handler
from ... import entities from ... import entities
@@ -13,6 +12,8 @@ from ....core import entities as core_entities
from ....provider import entities as llm_entities, runnermgr from ....provider import entities as llm_entities, runnermgr
from ....plugin import events from ....plugin import events
from ....platform.types import message as platform_message
class ChatMessageHandler(handler.MessageHandler): class ChatMessageHandler(handler.MessageHandler):
@@ -40,7 +41,7 @@ class ChatMessageHandler(handler.MessageHandler):
if event_ctx.is_prevented_default(): if event_ctx.is_prevented_default():
if event_ctx.event.reply is not None: if event_ctx.event.reply is not None:
mc = mirai.MessageChain(event_ctx.event.reply) mc = platform_message.MessageChain(event_ctx.event.reply)
query.resp_messages.append(mc) query.resp_messages.append(mc)

View File

@@ -1,13 +1,13 @@
from __future__ import annotations from __future__ import annotations
import typing import typing
import mirai
from .. import handler from .. import handler
from ... import entities from ... import entities
from ....core import entities as core_entities from ....core import entities as core_entities
from ....provider import entities as llm_entities from ....provider import entities as llm_entities
from ....plugin import events from ....plugin import events
from ....platform.types import message as platform_message
class CommandHandler(handler.MessageHandler): class CommandHandler(handler.MessageHandler):
@@ -46,7 +46,7 @@ class CommandHandler(handler.MessageHandler):
if event_ctx.is_prevented_default(): if event_ctx.is_prevented_default():
if event_ctx.event.reply is not None: if event_ctx.event.reply is not None:
mc = mirai.MessageChain(event_ctx.event.reply) mc = platform_message.MessageChain(event_ctx.event.reply)
query.resp_messages.append(mc) query.resp_messages.append(mc)
@@ -63,8 +63,8 @@ class CommandHandler(handler.MessageHandler):
else: else:
if event_ctx.event.alter is not None: if event_ctx.event.alter is not None:
query.message_chain = mirai.MessageChain([ query.message_chain = platform_message.MessageChain([
mirai.Plain(event_ctx.event.alter) platform_message.Plain(event_ctx.event.alter)
]) ])
session = await self.ap.sess_mgr.get_session(query) session = await self.ap.sess_mgr.get_session(query)

View File

@@ -3,7 +3,6 @@ from __future__ import annotations
import random import random
import asyncio import asyncio
import mirai
from ...core import app from ...core import app
@@ -20,7 +19,10 @@ class SendResponseBackStage(stage.PipelineStage):
async def process(self, query: core_entities.Query, stage_inst_name: str) -> entities.StageProcessResult: async def process(self, query: core_entities.Query, stage_inst_name: str) -> entities.StageProcessResult:
"""处理 """处理
""" """
random_delay = random.uniform(*self.ap.platform_cfg.data['force-delay'])
random_range = (self.ap.platform_cfg.data['force-delay']['min'], self.ap.platform_cfg.data['force-delay']['max'])
random_delay = random.uniform(*random_range)
self.ap.logger.debug( self.ap.logger.debug(
"根据规则强制延迟回复: %s s", "根据规则强制延迟回复: %s s",

View File

@@ -1,9 +1,10 @@
import pydantic import pydantic.v1 as pydantic
import mirai
from ...platform.types import message as platform_message
class RuleJudgeResult(pydantic.BaseModel): class RuleJudgeResult(pydantic.BaseModel):
matching: bool = False matching: bool = False
replacement: mirai.MessageChain = None replacement: platform_message.MessageChain = None

View File

@@ -1,6 +1,5 @@
from __future__ import annotations from __future__ import annotations
import mirai
from ...core import app from ...core import app
from . import entities as rule_entities, rule from . import entities as rule_entities, rule

View File

@@ -2,11 +2,11 @@ from __future__ import annotations
import abc import abc
import typing import typing
import mirai
from ...core import app, entities as core_entities from ...core import app, entities as core_entities
from . import entities from . import entities
from ...platform.types import message as platform_message
preregisetered_rules: list[typing.Type[GroupRespondRule]] = [] preregisetered_rules: list[typing.Type[GroupRespondRule]] = []
@@ -35,7 +35,7 @@ class GroupRespondRule(metaclass=abc.ABCMeta):
async def match( async def match(
self, self,
message_text: str, message_text: str,
message_chain: mirai.MessageChain, message_chain: platform_message.MessageChain,
rule_dict: dict, rule_dict: dict,
query: core_entities.Query query: core_entities.Query
) -> entities.RuleJudgeResult: ) -> entities.RuleJudgeResult:

View File

@@ -1,10 +1,10 @@
from __future__ import annotations from __future__ import annotations
import mirai
from .. import rule as rule_model from .. import rule as rule_model
from .. import entities from .. import entities
from ....core import entities as core_entities from ....core import entities as core_entities
from ....platform.types import message as platform_message
@rule_model.rule_class("at-bot") @rule_model.rule_class("at-bot")
@@ -13,16 +13,16 @@ class AtBotRule(rule_model.GroupRespondRule):
async def match( async def match(
self, self,
message_text: str, message_text: str,
message_chain: mirai.MessageChain, message_chain: platform_message.MessageChain,
rule_dict: dict, rule_dict: dict,
query: core_entities.Query query: core_entities.Query
) -> entities.RuleJudgeResult: ) -> entities.RuleJudgeResult:
if message_chain.has(mirai.At(query.adapter.bot_account_id)) and rule_dict['at']: if message_chain.has(platform_message.At(query.adapter.bot_account_id)) and rule_dict['at']:
message_chain.remove(mirai.At(query.adapter.bot_account_id)) message_chain.remove(platform_message.At(query.adapter.bot_account_id))
if message_chain.has(mirai.At(query.adapter.bot_account_id)): # 回复消息时会at两次检查并删除重复的 if message_chain.has(platform_message.At(query.adapter.bot_account_id)): # 回复消息时会at两次检查并删除重复的
message_chain.remove(mirai.At(query.adapter.bot_account_id)) message_chain.remove(platform_message.At(query.adapter.bot_account_id))
return entities.RuleJudgeResult( return entities.RuleJudgeResult(
matching=True, matching=True,

View File

@@ -1,8 +1,8 @@
import mirai
from .. import rule as rule_model from .. import rule as rule_model
from .. import entities from .. import entities
from ....core import entities as core_entities from ....core import entities as core_entities
from ....platform.types import message as platform_message
@rule_model.rule_class("prefix") @rule_model.rule_class("prefix")
@@ -11,7 +11,7 @@ class PrefixRule(rule_model.GroupRespondRule):
async def match( async def match(
self, self,
message_text: str, message_text: str,
message_chain: mirai.MessageChain, message_chain: platform_message.MessageChain,
rule_dict: dict, rule_dict: dict,
query: core_entities.Query query: core_entities.Query
) -> entities.RuleJudgeResult: ) -> entities.RuleJudgeResult:
@@ -22,7 +22,7 @@ class PrefixRule(rule_model.GroupRespondRule):
# 查找第一个plain元素 # 查找第一个plain元素
for me in message_chain: for me in message_chain:
if isinstance(me, mirai.Plain): if isinstance(me, platform_message.Plain):
me.text = me.text[len(prefix):] me.text = me.text[len(prefix):]
return entities.RuleJudgeResult( return entities.RuleJudgeResult(

View File

@@ -1,10 +1,10 @@
import random import random
import mirai
from .. import rule as rule_model from .. import rule as rule_model
from .. import entities from .. import entities
from ....core import entities as core_entities from ....core import entities as core_entities
from ....platform.types import message as platform_message
@rule_model.rule_class("random") @rule_model.rule_class("random")
@@ -13,7 +13,7 @@ class RandomRespRule(rule_model.GroupRespondRule):
async def match( async def match(
self, self,
message_text: str, message_text: str,
message_chain: mirai.MessageChain, message_chain: platform_message.MessageChain,
rule_dict: dict, rule_dict: dict,
query: core_entities.Query query: core_entities.Query
) -> entities.RuleJudgeResult: ) -> entities.RuleJudgeResult:

View File

@@ -1,10 +1,10 @@
import re import re
import mirai
from .. import rule as rule_model from .. import rule as rule_model
from .. import entities from .. import entities
from ....core import entities as core_entities from ....core import entities as core_entities
from ....platform.types import message as platform_message
@rule_model.rule_class("regexp") @rule_model.rule_class("regexp")
@@ -13,7 +13,7 @@ class RegExpRule(rule_model.GroupRespondRule):
async def match( async def match(
self, self,
message_text: str, message_text: str,
message_chain: mirai.MessageChain, message_chain: platform_message.MessageChain,
rule_dict: dict, rule_dict: dict,
query: core_entities.Query query: core_entities.Query
) -> entities.RuleJudgeResult: ) -> entities.RuleJudgeResult:

View File

@@ -1,7 +1,5 @@
from __future__ import annotations from __future__ import annotations
import pydantic
from ..core import app from ..core import app
from . import stage from . import stage
from .resprule import resprule from .resprule import resprule

View File

@@ -2,7 +2,6 @@ from __future__ import annotations
import typing import typing
import mirai
from ...core import app, entities as core_entities from ...core import app, entities as core_entities
from .. import entities from .. import entities
@@ -10,6 +9,7 @@ from .. import stage, entities, stagemgr
from ...core import entities as core_entities from ...core import entities as core_entities
from ...config import manager as cfg_mgr from ...config import manager as cfg_mgr
from ...plugin import events from ...plugin import events
from ...platform.types import message as platform_message
@stage.stage_class("ResponseWrapper") @stage.stage_class("ResponseWrapper")
@@ -34,7 +34,7 @@ class ResponseWrapper(stage.PipelineStage):
""" """
# 如果 resp_messages[-1] 已经是 MessageChain 了 # 如果 resp_messages[-1] 已经是 MessageChain 了
if isinstance(query.resp_messages[-1], mirai.MessageChain): if isinstance(query.resp_messages[-1], platform_message.MessageChain):
query.resp_message_chain.append(query.resp_messages[-1]) query.resp_message_chain.append(query.resp_messages[-1])
yield entities.StageProcessResult( yield entities.StageProcessResult(
@@ -45,19 +45,14 @@ class ResponseWrapper(stage.PipelineStage):
else: else:
if query.resp_messages[-1].role == 'command': if query.resp_messages[-1].role == 'command':
# query.resp_message_chain.append(mirai.MessageChain("[bot] "+query.resp_messages[-1].content)) query.resp_message_chain.append(query.resp_messages[-1].get_content_platform_message_chain(prefix_text='[bot] '))
query.resp_message_chain.append(query.resp_messages[-1].get_content_mirai_message_chain(prefix_text='[bot] '))
yield entities.StageProcessResult( yield entities.StageProcessResult(
result_type=entities.ResultType.CONTINUE, result_type=entities.ResultType.CONTINUE,
new_query=query new_query=query
) )
elif query.resp_messages[-1].role == 'plugin': elif query.resp_messages[-1].role == 'plugin':
# if not isinstance(query.resp_messages[-1].content, mirai.MessageChain): query.resp_message_chain.append(query.resp_messages[-1].get_content_platform_message_chain())
# query.resp_message_chain.append(mirai.MessageChain(query.resp_messages[-1].content))
# else:
# query.resp_message_chain.append(query.resp_messages[-1].content)
query.resp_message_chain.append(query.resp_messages[-1].get_content_mirai_message_chain())
yield entities.StageProcessResult( yield entities.StageProcessResult(
result_type=entities.ResultType.CONTINUE, result_type=entities.ResultType.CONTINUE,
@@ -72,7 +67,7 @@ class ResponseWrapper(stage.PipelineStage):
reply_text = '' reply_text = ''
if result.content: # 有内容 if result.content: # 有内容
reply_text = str(result.get_content_mirai_message_chain()) reply_text = str(result.get_content_platform_message_chain())
# ============= 触发插件事件 =============== # ============= 触发插件事件 ===============
event_ctx = await self.ap.plugin_mgr.emit_event( event_ctx = await self.ap.plugin_mgr.emit_event(
@@ -96,11 +91,11 @@ class ResponseWrapper(stage.PipelineStage):
else: else:
if event_ctx.event.reply is not None: if event_ctx.event.reply is not None:
query.resp_message_chain.append(mirai.MessageChain(event_ctx.event.reply)) query.resp_message_chain.append(platform_message.MessageChain(event_ctx.event.reply))
else: else:
query.resp_message_chain.append(result.get_content_mirai_message_chain()) query.resp_message_chain.append(result.get_content_platform_message_chain())
yield entities.StageProcessResult( yield entities.StageProcessResult(
result_type=entities.ResultType.CONTINUE, result_type=entities.ResultType.CONTINUE,
@@ -113,7 +108,7 @@ class ResponseWrapper(stage.PipelineStage):
reply_text = f'调用函数 {".".join(function_names)}...' reply_text = f'调用函数 {".".join(function_names)}...'
query.resp_message_chain.append(mirai.MessageChain([mirai.Plain(reply_text)])) query.resp_message_chain.append(platform_message.MessageChain([platform_message.Plain(reply_text)]))
if self.ap.platform_cfg.data['track-function-calls']: if self.ap.platform_cfg.data['track-function-calls']:
@@ -139,11 +134,11 @@ class ResponseWrapper(stage.PipelineStage):
else: else:
if event_ctx.event.reply is not None: if event_ctx.event.reply is not None:
query.resp_message_chain.append(mirai.MessageChain(event_ctx.event.reply)) query.resp_message_chain.append(platform_message.MessageChain(event_ctx.event.reply))
else: else:
query.resp_message_chain.append(mirai.MessageChain([mirai.Plain(reply_text)])) query.resp_message_chain.append(platform_message.MessageChain([platform_message.Plain(reply_text)]))
yield entities.StageProcessResult( yield entities.StageProcessResult(
result_type=entities.ResultType.CONTINUE, result_type=entities.ResultType.CONTINUE,

View File

@@ -4,9 +4,10 @@ from __future__ import annotations
import typing import typing
import abc import abc
import mirai
from ..core import app from ..core import app
from .types import message as platform_message
from .types import events as platform_events
preregistered_adapters: list[typing.Type[MessageSourceAdapter]] = [] preregistered_adapters: list[typing.Type[MessageSourceAdapter]] = []
@@ -55,28 +56,28 @@ class MessageSourceAdapter(metaclass=abc.ABCMeta):
self, self,
target_type: str, target_type: str,
target_id: str, target_id: str,
message: mirai.MessageChain message: platform_message.MessageChain
): ):
"""主动发送消息 """主动发送消息
Args: Args:
target_type (str): 目标类型,`person`或`group` target_type (str): 目标类型,`person`或`group`
target_id (str): 目标ID target_id (str): 目标ID
message (mirai.MessageChain): YiriMirai库的消息链 message (platform.types.MessageChain): 消息链
""" """
raise NotImplementedError raise NotImplementedError
async def reply_message( async def reply_message(
self, self,
message_source: mirai.MessageEvent, message_source: platform_events.MessageEvent,
message: mirai.MessageChain, message: platform_message.MessageChain,
quote_origin: bool = False quote_origin: bool = False
): ):
"""回复消息 """回复消息
Args: Args:
message_source (mirai.MessageEvent): YiriMirai消息源事件 message_source (platform.types.MessageEvent): 消息源事件
message (mirai.MessageChain): YiriMirai库的消息链 message (platform.types.MessageChain): 消息链
quote_origin (bool, optional): 是否引用原消息. Defaults to False. quote_origin (bool, optional): 是否引用原消息. Defaults to False.
""" """
raise NotImplementedError raise NotImplementedError
@@ -87,27 +88,27 @@ class MessageSourceAdapter(metaclass=abc.ABCMeta):
def register_listener( def register_listener(
self, self,
event_type: typing.Type[mirai.Event], event_type: typing.Type[platform_message.Event],
callback: typing.Callable[[mirai.Event, MessageSourceAdapter], None] callback: typing.Callable[[platform_message.Event, MessageSourceAdapter], None]
): ):
"""注册事件监听器 """注册事件监听器
Args: Args:
event_type (typing.Type[mirai.Event]): YiriMirai事件类型 event_type (typing.Type[platform.types.Event]): 事件类型
callback (typing.Callable[[mirai.Event], None]): 回调函数,接收一个参数,为YiriMirai事件 callback (typing.Callable[[platform.types.Event], None]): 回调函数,接收一个参数,为事件
""" """
raise NotImplementedError raise NotImplementedError
def unregister_listener( def unregister_listener(
self, self,
event_type: typing.Type[mirai.Event], event_type: typing.Type[platform_message.Event],
callback: typing.Callable[[mirai.Event, MessageSourceAdapter], None] callback: typing.Callable[[platform_message.Event, MessageSourceAdapter], None]
): ):
"""注销事件监听器 """注销事件监听器
Args: Args:
event_type (typing.Type[mirai.Event]): YiriMirai事件类型 event_type (typing.Type[platform.types.Event]): 事件类型
callback (typing.Callable[[mirai.Event], None]): 回调函数,接收一个参数,为YiriMirai事件 callback (typing.Callable[[platform.types.Event], None]): 回调函数,接收一个参数,为事件
""" """
raise NotImplementedError raise NotImplementedError
@@ -127,26 +128,26 @@ class MessageSourceAdapter(metaclass=abc.ABCMeta):
class MessageConverter: class MessageConverter:
"""消息链转换器基类""" """消息链转换器基类"""
@staticmethod @staticmethod
def yiri2target(message_chain: mirai.MessageChain): def yiri2target(message_chain: platform_message.MessageChain):
"""YiriMirai消息链转换为目标消息链 """源平台消息链转换为目标平台消息链
Args: Args:
message_chain (mirai.MessageChain): YiriMirai消息链 message_chain (platform.types.MessageChain): 源平台消息链
Returns: Returns:
typing.Any: 目标消息链 typing.Any: 目标平台消息链
""" """
raise NotImplementedError raise NotImplementedError
@staticmethod @staticmethod
def target2yiri(message_chain: typing.Any) -> mirai.MessageChain: def target2yiri(message_chain: typing.Any) -> platform_message.MessageChain:
"""将目标消息链转换为YiriMirai消息链 """将目标平台消息链转换为源平台消息链
Args: Args:
message_chain (typing.Any): 目标消息链 message_chain (typing.Any): 目标平台消息链
Returns: Returns:
mirai.MessageChain: YiriMirai消息链 platform.types.MessageChain: 源平台消息链
""" """
raise NotImplementedError raise NotImplementedError
@@ -155,25 +156,25 @@ class EventConverter:
"""事件转换器基类""" """事件转换器基类"""
@staticmethod @staticmethod
def yiri2target(event: typing.Type[mirai.Event]): def yiri2target(event: typing.Type[platform_message.Event]):
"""YiriMirai事件转换为目标事件 """源平台事件转换为目标平台事件
Args: Args:
event (typing.Type[mirai.Event]): YiriMirai事件 event (typing.Type[platform.types.Event]): 源平台事件
Returns: Returns:
typing.Any: 目标事件 typing.Any: 目标平台事件
""" """
raise NotImplementedError raise NotImplementedError
@staticmethod @staticmethod
def target2yiri(event: typing.Any) -> mirai.Event: def target2yiri(event: typing.Any) -> platform_message.Event:
"""将目标事件的调用参数转换为YiriMirai的事件参数对象 """将目标平台事件的调用参数转换为源平台的事件参数对象
Args: Args:
event (typing.Any): 目标事件 event (typing.Any): 目标平台事件
Returns: Returns:
typing.Type[mirai.Event]: YiriMirai事件 typing.Type[platform.types.Event]: 源平台事件
""" """
raise NotImplementedError raise NotImplementedError

View File

@@ -2,17 +2,24 @@ from __future__ import annotations
import json import json
import os import os
import sys
import logging import logging
import asyncio import asyncio
import traceback import traceback
from mirai import At, GroupMessage, MessageEvent, StrangerMessage, \ # FriendMessage, Image, MessageChain, Plain
FriendMessage, Image, MessageChain, Plain
import mirai
from ..platform import adapter as msadapter from ..platform import adapter as msadapter
from ..core import app, entities as core_entities from ..core import app, entities as core_entities
from ..plugin import events from ..plugin import events
from .types import message as platform_message
from .types import events as platform_events
from .types import entities as platform_entities
# 处理 3.4 移除了 YiriMirai 之后,插件的兼容性问题
from . import types as mirai
sys.modules['mirai'] = mirai
# 控制QQ消息输入输出的类 # 控制QQ消息输入输出的类
class PlatformManager: class PlatformManager:
@@ -30,76 +37,40 @@ class PlatformManager:
async def initialize(self): async def initialize(self):
from .sources import yirimirai, nakuru, aiocqhttp, qqbotpy from .sources import nakuru, aiocqhttp, qqbotpy
async def on_friend_message(event: FriendMessage, adapter: msadapter.MessageSourceAdapter): async def on_friend_message(event: platform_events.FriendMessage, adapter: msadapter.MessageSourceAdapter):
event_ctx = await self.ap.plugin_mgr.emit_event( await self.ap.query_pool.add_query(
event=events.PersonMessageReceived( launcher_type=core_entities.LauncherTypes.PERSON,
launcher_type='person', launcher_id=event.sender.id,
launcher_id=event.sender.id, sender_id=event.sender.id,
sender_id=event.sender.id, message_event=event,
message_chain=event.message_chain, message_chain=event.message_chain,
query=None adapter=adapter
)
) )
if not event_ctx.is_prevented_default(): async def on_stranger_message(event: platform_events.StrangerMessage, adapter: msadapter.MessageSourceAdapter):
await self.ap.query_pool.add_query(
launcher_type=core_entities.LauncherTypes.PERSON,
launcher_id=event.sender.id,
sender_id=event.sender.id,
message_event=event,
message_chain=event.message_chain,
adapter=adapter
)
async def on_stranger_message(event: StrangerMessage, adapter: msadapter.MessageSourceAdapter):
event_ctx = await self.ap.plugin_mgr.emit_event( await self.ap.query_pool.add_query(
event=events.PersonMessageReceived( launcher_type=core_entities.LauncherTypes.PERSON,
launcher_type='person', launcher_id=event.sender.id,
launcher_id=event.sender.id, sender_id=event.sender.id,
sender_id=event.sender.id, message_event=event,
message_chain=event.message_chain, message_chain=event.message_chain,
query=None adapter=adapter
)
) )
if not event_ctx.is_prevented_default(): async def on_group_message(event: platform_events.GroupMessage, adapter: msadapter.MessageSourceAdapter):
await self.ap.query_pool.add_query( await self.ap.query_pool.add_query(
launcher_type=core_entities.LauncherTypes.PERSON, launcher_type=core_entities.LauncherTypes.GROUP,
launcher_id=event.sender.id, launcher_id=event.group.id,
sender_id=event.sender.id, sender_id=event.sender.id,
message_event=event, message_event=event,
message_chain=event.message_chain, message_chain=event.message_chain,
adapter=adapter adapter=adapter
)
async def on_group_message(event: GroupMessage, adapter: msadapter.MessageSourceAdapter):
event_ctx = await self.ap.plugin_mgr.emit_event(
event=events.GroupMessageReceived(
launcher_type='group',
launcher_id=event.group.id,
sender_id=event.sender.id,
message_chain=event.message_chain,
query=None
)
) )
if not event_ctx.is_prevented_default():
await self.ap.query_pool.add_query(
launcher_type=core_entities.LauncherTypes.GROUP,
launcher_id=event.group.id,
sender_id=event.sender.id,
message_event=event,
message_chain=event.message_chain,
adapter=adapter
)
index = 0 index = 0
@@ -127,16 +98,16 @@ class PlatformManager:
if adapter_name == 'yiri-mirai': if adapter_name == 'yiri-mirai':
adapter_inst.register_listener( adapter_inst.register_listener(
StrangerMessage, platform_events.StrangerMessage,
on_stranger_message on_stranger_message
) )
adapter_inst.register_listener( adapter_inst.register_listener(
FriendMessage, platform_events.FriendMessage,
on_friend_message on_friend_message
) )
adapter_inst.register_listener( adapter_inst.register_listener(
GroupMessage, platform_events.GroupMessage,
on_group_message on_group_message
) )
@@ -146,13 +117,13 @@ class PlatformManager:
if len(self.adapters) == 0: if len(self.adapters) == 0:
self.ap.logger.warning('未运行平台适配器,请根据文档配置并启用平台适配器。') self.ap.logger.warning('未运行平台适配器,请根据文档配置并启用平台适配器。')
async def send(self, event: mirai.MessageEvent, msg: mirai.MessageChain, adapter: msadapter.MessageSourceAdapter): async def send(self, event: platform_events.MessageEvent, msg: platform_message.MessageChain, adapter: msadapter.MessageSourceAdapter):
if self.ap.platform_cfg.data['at-sender'] and isinstance(event, GroupMessage): if self.ap.platform_cfg.data['at-sender'] and isinstance(event, platform_events.GroupMessage):
msg.insert( msg.insert(
0, 0,
At( platform_message.At(
event.sender.id event.sender.id
) )
) )
@@ -167,19 +138,30 @@ class PlatformManager:
try: try:
tasks = [] tasks = []
for adapter in self.adapters: for adapter in self.adapters:
async def exception_wrapper(adapter): async def exception_wrapper(adapter: msadapter.MessageSourceAdapter):
try: try:
await adapter.run_async() await adapter.run_async()
except Exception as e: except Exception as e:
if isinstance(e, asyncio.CancelledError):
return
self.ap.logger.error('平台适配器运行出错: ' + str(e)) self.ap.logger.error('平台适配器运行出错: ' + str(e))
self.ap.logger.debug(f"Traceback: {traceback.format_exc()}") self.ap.logger.debug(f"Traceback: {traceback.format_exc()}")
tasks.append(exception_wrapper(adapter)) tasks.append(exception_wrapper(adapter))
for task in tasks: for task in tasks:
asyncio.create_task(task) self.ap.task_mgr.create_task(
task,
kind="platform-adapter",
name=f"platform-adapter-{adapter.name}",
scopes=[core_entities.LifecycleControlScope.APPLICATION, core_entities.LifecycleControlScope.PLATFORM],
)
except Exception as e: except Exception as e:
self.ap.logger.error('平台适配器运行出错: ' + str(e)) self.ap.logger.error('平台适配器运行出错: ' + str(e))
self.ap.logger.debug(f"Traceback: {traceback.format_exc()}") self.ap.logger.debug(f"Traceback: {traceback.format_exc()}")
async def shutdown(self):
for adapter in self.adapters:
await adapter.kill()
self.ap.task_mgr.cancel_by_scope(core_entities.LifecycleControlScope.PLATFORM)

View File

@@ -5,31 +5,32 @@ import traceback
import time import time
import datetime import datetime
import mirai
import mirai.models.message as yiri_message
import aiocqhttp import aiocqhttp
from .. import adapter from .. import adapter
from ...pipeline.longtext.strategies import forward from ...pipeline.longtext.strategies import forward
from ...core import app from ...core import app
from ..types import message as platform_message
from ..types import events as platform_events
from ..types import entities as platform_entities
class AiocqhttpMessageConverter(adapter.MessageConverter): class AiocqhttpMessageConverter(adapter.MessageConverter):
@staticmethod @staticmethod
def yiri2target(message_chain: mirai.MessageChain) -> typing.Tuple[list, int, datetime.datetime]: def yiri2target(message_chain: platform_message.MessageChain) -> typing.Tuple[list, int, datetime.datetime]:
msg_list = aiocqhttp.Message() msg_list = aiocqhttp.Message()
msg_id = 0 msg_id = 0
msg_time = None msg_time = None
for msg in message_chain: for msg in message_chain:
if type(msg) is mirai.Plain: if type(msg) is platform_message.Plain:
msg_list.append(aiocqhttp.MessageSegment.text(msg.text)) msg_list.append(aiocqhttp.MessageSegment.text(msg.text))
elif type(msg) is yiri_message.Source: elif type(msg) is platform_message.Source:
msg_id = msg.id msg_id = msg.id
msg_time = msg.time msg_time = msg.time
elif type(msg) is mirai.Image: elif type(msg) is platform_message.Image:
arg = '' arg = ''
if msg.base64: if msg.base64:
arg = msg.base64 arg = msg.base64
@@ -40,13 +41,11 @@ class AiocqhttpMessageConverter(adapter.MessageConverter):
elif msg.path: elif msg.path:
arg = msg.path arg = msg.path
msg_list.append(aiocqhttp.MessageSegment.image(arg)) msg_list.append(aiocqhttp.MessageSegment.image(arg))
elif type(msg) is mirai.At: elif type(msg) is platform_message.At:
msg_list.append(aiocqhttp.MessageSegment.at(msg.target)) msg_list.append(aiocqhttp.MessageSegment.at(msg.target))
elif type(msg) is mirai.AtAll: elif type(msg) is platform_message.AtAll:
msg_list.append(aiocqhttp.MessageSegment.at("all")) msg_list.append(aiocqhttp.MessageSegment.at("all"))
elif type(msg) is mirai.Face: elif type(msg) is platform_message.Voice:
msg_list.append(aiocqhttp.MessageSegment.face(msg.face_id))
elif type(msg) is mirai.Voice:
arg = '' arg = ''
if msg.base64: if msg.base64:
arg = msg.base64 arg = msg.base64
@@ -74,25 +73,25 @@ class AiocqhttpMessageConverter(adapter.MessageConverter):
yiri_msg_list = [] yiri_msg_list = []
yiri_msg_list.append( yiri_msg_list.append(
yiri_message.Source(id=message_id, time=datetime.datetime.now()) platform_message.Source(id=message_id, time=datetime.datetime.now())
) )
for msg in message: for msg in message:
if msg.type == "at": if msg.type == "at":
if msg.data["qq"] == "all": if msg.data["qq"] == "all":
yiri_msg_list.append(yiri_message.AtAll()) yiri_msg_list.append(platform_message.AtAll())
else: else:
yiri_msg_list.append( yiri_msg_list.append(
yiri_message.At( platform_message.At(
target=msg.data["qq"], target=msg.data["qq"],
) )
) )
elif msg.type == "text": elif msg.type == "text":
yiri_msg_list.append(yiri_message.Plain(text=msg.data["text"])) yiri_msg_list.append(platform_message.Plain(text=msg.data["text"]))
elif msg.type == "image": elif msg.type == "image":
yiri_msg_list.append(yiri_message.Image(url=msg.data["url"])) yiri_msg_list.append(platform_message.Image(url=msg.data["url"]))
chain = mirai.MessageChain(yiri_msg_list) chain = platform_message.MessageChain(yiri_msg_list)
return chain return chain
@@ -100,11 +99,11 @@ class AiocqhttpMessageConverter(adapter.MessageConverter):
class AiocqhttpEventConverter(adapter.EventConverter): class AiocqhttpEventConverter(adapter.EventConverter):
@staticmethod @staticmethod
def yiri2target(event: mirai.Event, bot_account_id: int): def yiri2target(event: platform_events.Event, bot_account_id: int):
msg, msg_id, msg_time = AiocqhttpMessageConverter.yiri2target(event.message_chain) msg, msg_id, msg_time = AiocqhttpMessageConverter.yiri2target(event.message_chain)
if type(event) is mirai.GroupMessage: if type(event) is platform_events.GroupMessage:
role = "member" role = "member"
if event.sender.permission == "ADMINISTRATOR": if event.sender.permission == "ADMINISTRATOR":
@@ -140,7 +139,7 @@ class AiocqhttpEventConverter(adapter.EventConverter):
} }
return aiocqhttp.Event.from_payload(payload) return aiocqhttp.Event.from_payload(payload)
elif type(event) is mirai.FriendMessage: elif type(event) is platform_events.FriendMessage:
payload = { payload = {
"post_type": "message", "post_type": "message",
@@ -173,19 +172,20 @@ class AiocqhttpEventConverter(adapter.EventConverter):
if event.message_type == "group": if event.message_type == "group":
permission = "MEMBER" permission = "MEMBER"
if event.sender["role"] == "admin": if "role" in event.sender:
permission = "ADMINISTRATOR" if event.sender["role"] == "admin":
elif event.sender["role"] == "owner": permission = "ADMINISTRATOR"
permission = "OWNER" elif event.sender["role"] == "owner":
converted_event = mirai.GroupMessage( permission = "OWNER"
sender=mirai.models.entities.GroupMember( converted_event = platform_events.GroupMessage(
sender=platform_entities.GroupMember(
id=event.sender["user_id"], # message_seq 放哪? id=event.sender["user_id"], # message_seq 放哪?
member_name=event.sender["nickname"], member_name=event.sender["nickname"],
permission=permission, permission=permission,
group=mirai.models.entities.Group( group=platform_entities.Group(
id=event.group_id, id=event.group_id,
name=event.sender["nickname"], name=event.sender["nickname"],
permission=mirai.models.entities.Permission.Member, permission=platform_entities.Permission.Member,
), ),
special_title=event.sender["title"] if "title" in event.sender else "", special_title=event.sender["title"] if "title" in event.sender else "",
join_timestamp=0, join_timestamp=0,
@@ -197,8 +197,8 @@ class AiocqhttpEventConverter(adapter.EventConverter):
) )
return converted_event return converted_event
elif event.message_type == "private": elif event.message_type == "private":
return mirai.FriendMessage( return platform_events.FriendMessage(
sender=mirai.models.entities.Friend( sender=platform_entities.Friend(
id=event.sender["user_id"], id=event.sender["user_id"],
nickname=event.sender["nickname"], nickname=event.sender["nickname"],
remark="", remark="",
@@ -240,7 +240,7 @@ class AiocqhttpAdapter(adapter.MessageSourceAdapter):
self.bot = aiocqhttp.CQHttp() self.bot = aiocqhttp.CQHttp()
async def send_message( async def send_message(
self, target_type: str, target_id: str, message: mirai.MessageChain self, target_type: str, target_id: str, message: platform_message.MessageChain
): ):
aiocq_msg = AiocqhttpMessageConverter.yiri2target(message)[0] aiocq_msg = AiocqhttpMessageConverter.yiri2target(message)[0]
@@ -251,8 +251,8 @@ class AiocqhttpAdapter(adapter.MessageSourceAdapter):
async def reply_message( async def reply_message(
self, self,
message_source: mirai.MessageEvent, message_source: platform_events.MessageEvent,
message: mirai.MessageChain, message: platform_message.MessageChain,
quote_origin: bool = False, quote_origin: bool = False,
): ):
aiocq_event = AiocqhttpEventConverter.yiri2target(message_source, self.bot_account_id) aiocq_event = AiocqhttpEventConverter.yiri2target(message_source, self.bot_account_id)
@@ -270,8 +270,8 @@ class AiocqhttpAdapter(adapter.MessageSourceAdapter):
def register_listener( def register_listener(
self, self,
event_type: typing.Type[mirai.Event], event_type: typing.Type[platform_events.Event],
callback: typing.Callable[[mirai.Event, adapter.MessageSourceAdapter], None], callback: typing.Callable[[platform_events.Event, adapter.MessageSourceAdapter], None],
): ):
async def on_message(event: aiocqhttp.Event): async def on_message(event: aiocqhttp.Event):
self.bot_account_id = event.self_id self.bot_account_id = event.self_id
@@ -280,15 +280,15 @@ class AiocqhttpAdapter(adapter.MessageSourceAdapter):
except: except:
traceback.print_exc() traceback.print_exc()
if event_type == mirai.GroupMessage: if event_type == platform_events.GroupMessage:
self.bot.on_message("group")(on_message) self.bot.on_message("group")(on_message)
elif event_type == mirai.FriendMessage: elif event_type == platform_events.FriendMessage:
self.bot.on_message("private")(on_message) self.bot.on_message("private")(on_message)
def unregister_listener( def unregister_listener(
self, self,
event_type: typing.Type[mirai.Event], event_type: typing.Type[platform_events.Event],
callback: typing.Callable[[mirai.Event, adapter.MessageSourceAdapter], None], callback: typing.Callable[[platform_events.Event, adapter.MessageSourceAdapter], None],
): ):
return super().unregister_listener(event_type, callback) return super().unregister_listener(event_type, callback)

View File

@@ -6,26 +6,28 @@ import typing
import traceback import traceback
import logging import logging
import mirai
import nakuru import nakuru
import nakuru.entities.components as nkc import nakuru.entities.components as nkc
from .. import adapter as adapter_model from .. import adapter as adapter_model
from ...pipeline.longtext.strategies import forward from ...pipeline.longtext.strategies import forward
from ...platform.types import message as platform_message
from ...platform.types import entities as platform_entities
from ...platform.types import events as platform_events
class NakuruProjectMessageConverter(adapter_model.MessageConverter): class NakuruProjectMessageConverter(adapter_model.MessageConverter):
"""消息转换器""" """消息转换器"""
@staticmethod @staticmethod
def yiri2target(message_chain: mirai.MessageChain) -> list: def yiri2target(message_chain: platform_message.MessageChain) -> list:
msg_list = [] msg_list = []
if type(message_chain) is mirai.MessageChain: if type(message_chain) is platform_message.MessageChain:
msg_list = message_chain.__root__ msg_list = message_chain.__root__
elif type(message_chain) is list: elif type(message_chain) is list:
msg_list = message_chain msg_list = message_chain
elif type(message_chain) is str: elif type(message_chain) is str:
msg_list = [mirai.Plain(message_chain)] msg_list = [platform_message.Plain(message_chain)]
else: else:
raise Exception("Unknown message type: " + str(message_chain) + str(type(message_chain))) raise Exception("Unknown message type: " + str(message_chain) + str(type(message_chain)))
@@ -33,22 +35,20 @@ class NakuruProjectMessageConverter(adapter_model.MessageConverter):
# 遍历并转换 # 遍历并转换
for component in msg_list: for component in msg_list:
if type(component) is mirai.Plain: if type(component) is platform_message.Plain:
nakuru_msg_list.append(nkc.Plain(component.text, False)) nakuru_msg_list.append(nkc.Plain(component.text, False))
elif type(component) is mirai.Image: elif type(component) is platform_message.Image:
if component.url is not None: if component.url is not None:
nakuru_msg_list.append(nkc.Image.fromURL(component.url)) nakuru_msg_list.append(nkc.Image.fromURL(component.url))
elif component.base64 is not None: elif component.base64 is not None:
nakuru_msg_list.append(nkc.Image.fromBase64(component.base64)) nakuru_msg_list.append(nkc.Image.fromBase64(component.base64))
elif component.path is not None: elif component.path is not None:
nakuru_msg_list.append(nkc.Image.fromFileSystem(component.path)) nakuru_msg_list.append(nkc.Image.fromFileSystem(component.path))
elif type(component) is mirai.Face: elif type(component) is platform_message.At:
nakuru_msg_list.append(nkc.Face(id=component.face_id))
elif type(component) is mirai.At:
nakuru_msg_list.append(nkc.At(qq=component.target)) nakuru_msg_list.append(nkc.At(qq=component.target))
elif type(component) is mirai.AtAll: elif type(component) is platform_message.AtAll:
nakuru_msg_list.append(nkc.AtAll()) nakuru_msg_list.append(nkc.AtAll())
elif type(component) is mirai.Voice: elif type(component) is platform_message.Voice:
if component.url is not None: if component.url is not None:
nakuru_msg_list.append(nkc.Record.fromURL(component.url)) nakuru_msg_list.append(nkc.Record.fromURL(component.url))
elif component.path is not None: elif component.path is not None:
@@ -80,49 +80,47 @@ class NakuruProjectMessageConverter(adapter_model.MessageConverter):
return nakuru_msg_list return nakuru_msg_list
@staticmethod @staticmethod
def target2yiri(message_chain: typing.Any, message_id: int = -1) -> mirai.MessageChain: def target2yiri(message_chain: typing.Any, message_id: int = -1) -> platform_message.MessageChain:
"""将Yiri的消息链转换为YiriMirai的消息链""" """将Yiri的消息链转换为YiriMirai的消息链"""
assert type(message_chain) is list assert type(message_chain) is list
yiri_msg_list = [] yiri_msg_list = []
import datetime import datetime
# 添加Source组件以标记message_id等信息 # 添加Source组件以标记message_id等信息
yiri_msg_list.append(mirai.models.message.Source(id=message_id, time=datetime.datetime.now())) yiri_msg_list.append(platform_message.Source(id=message_id, time=datetime.datetime.now()))
for component in message_chain: for component in message_chain:
if type(component) is nkc.Plain: if type(component) is nkc.Plain:
yiri_msg_list.append(mirai.Plain(text=component.text)) yiri_msg_list.append(platform_message.Plain(text=component.text))
elif type(component) is nkc.Image: elif type(component) is nkc.Image:
yiri_msg_list.append(mirai.Image(url=component.url)) yiri_msg_list.append(platform_message.Image(url=component.url))
elif type(component) is nkc.Face:
yiri_msg_list.append(mirai.Face(face_id=component.id))
elif type(component) is nkc.At: elif type(component) is nkc.At:
yiri_msg_list.append(mirai.At(target=component.qq)) yiri_msg_list.append(platform_message.At(target=component.qq))
elif type(component) is nkc.AtAll: elif type(component) is nkc.AtAll:
yiri_msg_list.append(mirai.AtAll()) yiri_msg_list.append(platform_message.AtAll())
else: else:
pass pass
# logging.debug("转换后的消息链: " + str(yiri_msg_list)) # logging.debug("转换后的消息链: " + str(yiri_msg_list))
chain = mirai.MessageChain(yiri_msg_list) chain = platform_message.MessageChain(yiri_msg_list)
return chain return chain
class NakuruProjectEventConverter(adapter_model.EventConverter): class NakuruProjectEventConverter(adapter_model.EventConverter):
"""事件转换器""" """事件转换器"""
@staticmethod @staticmethod
def yiri2target(event: typing.Type[mirai.Event]): def yiri2target(event: typing.Type[platform_events.Event]):
if event is mirai.GroupMessage: if event is platform_events.GroupMessage:
return nakuru.GroupMessage return nakuru.GroupMessage
elif event is mirai.FriendMessage: elif event is platform_events.FriendMessage:
return nakuru.FriendMessage return nakuru.FriendMessage
else: else:
raise Exception("未支持转换的事件类型: " + str(event)) raise Exception("未支持转换的事件类型: " + str(event))
@staticmethod @staticmethod
def target2yiri(event: typing.Any) -> mirai.Event: def target2yiri(event: typing.Any) -> platform_events.Event:
yiri_chain = NakuruProjectMessageConverter.target2yiri(event.message, event.message_id) yiri_chain = NakuruProjectMessageConverter.target2yiri(event.message, event.message_id)
if type(event) is nakuru.FriendMessage: # 私聊消息事件 if type(event) is nakuru.FriendMessage: # 私聊消息事件
return mirai.FriendMessage( return platform_events.FriendMessage(
sender=mirai.models.entities.Friend( sender=platform_entities.Friend(
id=event.sender.user_id, id=event.sender.user_id,
nickname=event.sender.nickname, nickname=event.sender.nickname,
remark=event.sender.nickname remark=event.sender.nickname
@@ -138,16 +136,15 @@ class NakuruProjectEventConverter(adapter_model.EventConverter):
elif event.sender.role == "owner": elif event.sender.role == "owner":
permission = "OWNER" permission = "OWNER"
import mirai.models.entities as entities return platform_events.GroupMessage(
return mirai.GroupMessage( sender=platform_entities.GroupMember(
sender=mirai.models.entities.GroupMember(
id=event.sender.user_id, id=event.sender.user_id,
member_name=event.sender.nickname, member_name=event.sender.nickname,
permission=permission, permission=permission,
group=mirai.models.entities.Group( group=platform_entities.Group(
id=event.group_id, id=event.group_id,
name=event.sender.nickname, name=event.sender.nickname,
permission=entities.Permission.Member permission=platform_entities.Permission.Member
), ),
special_title=event.sender.title, special_title=event.sender.title,
join_timestamp=0, join_timestamp=0,
@@ -189,7 +186,7 @@ class NakuruProjectAdapter(adapter_model.MessageSourceAdapter):
self, self,
target_type: str, target_type: str,
target_id: str, target_id: str,
message: typing.Union[mirai.MessageChain, list], message: typing.Union[platform_message.MessageChain, list],
converted: bool = False converted: bool = False
): ):
task = None task = None
@@ -222,8 +219,8 @@ class NakuruProjectAdapter(adapter_model.MessageSourceAdapter):
async def reply_message( async def reply_message(
self, self,
message_source: mirai.MessageEvent, message_source: platform_events.MessageEvent,
message: mirai.MessageChain, message: platform_message.MessageChain,
quote_origin: bool = False quote_origin: bool = False
): ):
message = self.message_converter.yiri2target(message) message = self.message_converter.yiri2target(message)
@@ -233,14 +230,14 @@ class NakuruProjectAdapter(adapter_model.MessageSourceAdapter):
id=message_source.message_chain.message_id, id=message_source.message_chain.message_id,
) )
) )
if type(message_source) is mirai.GroupMessage: if type(message_source) is platform_events.GroupMessage:
await self.send_message( await self.send_message(
"group", "group",
message_source.sender.group.id, message_source.sender.group.id,
message, message,
converted=True converted=True
) )
elif type(message_source) is mirai.FriendMessage: elif type(message_source) is platform_events.FriendMessage:
await self.send_message( await self.send_message(
"person", "person",
message_source.sender.id, message_source.sender.id,
@@ -258,8 +255,8 @@ class NakuruProjectAdapter(adapter_model.MessageSourceAdapter):
def register_listener( def register_listener(
self, self,
event_type: typing.Type[mirai.Event], event_type: typing.Type[platform_events.Event],
callback: typing.Callable[[mirai.Event, adapter_model.MessageSourceAdapter], None] callback: typing.Callable[[platform_events.Event, adapter_model.MessageSourceAdapter], None]
): ):
try: try:
@@ -286,8 +283,8 @@ class NakuruProjectAdapter(adapter_model.MessageSourceAdapter):
def unregister_listener( def unregister_listener(
self, self,
event_type: typing.Type[mirai.Event], event_type: typing.Type[platform_events.Event],
callback: typing.Callable[[mirai.Event, adapter_model.MessageSourceAdapter], None] callback: typing.Callable[[platform_events.Event, adapter_model.MessageSourceAdapter], None]
): ):
nakuru_event_name = self.event_converter.yiri2target(event_type).__name__ nakuru_event_name = self.event_converter.yiri2target(event_type).__name__
@@ -331,5 +328,5 @@ class NakuruProjectAdapter(adapter_model.MessageSourceAdapter):
while True: while True:
await asyncio.sleep(1) await asyncio.sleep(1)
def kill(self) -> bool: async def kill(self) -> bool:
return False return False

View File

@@ -3,13 +3,9 @@ from __future__ import annotations
import logging import logging
import typing import typing
import datetime import datetime
import asyncio
import re import re
import traceback import traceback
import json
import threading
import mirai
import botpy import botpy
import botpy.message as botpy_message import botpy.message as botpy_message
import botpy.types.message as botpy_message_type import botpy.types.message as botpy_message_type
@@ -18,15 +14,20 @@ from .. import adapter as adapter_model
from ...pipeline.longtext.strategies import forward from ...pipeline.longtext.strategies import forward
from ...core import app from ...core import app
from ...config import manager as cfg_mgr from ...config import manager as cfg_mgr
from ...platform.types import entities as platform_entities
from ...platform.types import events as platform_events
from ...platform.types import message as platform_message
class OfficialGroupMessage(mirai.GroupMessage): class OfficialGroupMessage(platform_events.GroupMessage):
pass pass
class OfficialFriendMessage(platform_events.FriendMessage):
pass
event_handler_mapping = { event_handler_mapping = {
mirai.GroupMessage: ["on_at_message_create", "on_group_at_message_create"], platform_events.GroupMessage: ["on_at_message_create", "on_group_at_message_create"],
mirai.FriendMessage: ["on_direct_message_create"], platform_events.FriendMessage: ["on_direct_message_create", "on_c2c_message_create"],
} }
@@ -122,16 +123,16 @@ class OfficialMessageConverter(adapter_model.MessageConverter):
"""QQ 官方消息转换器""" """QQ 官方消息转换器"""
@staticmethod @staticmethod
def yiri2target(message_chain: mirai.MessageChain): def yiri2target(message_chain: platform_message.MessageChain):
"""将 YiriMirai 的消息链转换为 QQ 官方消息""" """将 YiriMirai 的消息链转换为 QQ 官方消息"""
msg_list = [] msg_list = []
if type(message_chain) is mirai.MessageChain: if type(message_chain) is platform_message.MessageChain:
msg_list = message_chain.__root__ msg_list = message_chain.__root__
elif type(message_chain) is list: elif type(message_chain) is list:
msg_list = message_chain msg_list = message_chain
elif type(message_chain) is str: elif type(message_chain) is str:
msg_list = [mirai.Plain(text=message_chain)] msg_list = [platform_message.Plain(text=message_chain)]
else: else:
raise Exception( raise Exception(
"Unknown message type: " + str(message_chain) + str(type(message_chain)) "Unknown message type: " + str(message_chain) + str(type(message_chain))
@@ -152,22 +153,22 @@ class OfficialMessageConverter(adapter_model.MessageConverter):
# 遍历并转换 # 遍历并转换
for component in msg_list: for component in msg_list:
if type(component) is mirai.Plain: if type(component) is platform_message.Plain:
offcial_messages.append({"type": "text", "content": component.text}) offcial_messages.append({"type": "text", "content": component.text})
elif type(component) is mirai.Image: elif type(component) is platform_message.Image:
if component.url is not None: if component.url is not None:
offcial_messages.append({"type": "image", "content": component.url}) offcial_messages.append({"type": "image", "content": component.url})
elif component.path is not None: elif component.path is not None:
offcial_messages.append( offcial_messages.append(
{"type": "file_image", "content": component.path} {"type": "file_image", "content": component.path}
) )
elif type(component) is mirai.At: elif type(component) is platform_message.At:
offcial_messages.append({"type": "at", "content": ""}) offcial_messages.append({"type": "at", "content": ""})
elif type(component) is mirai.AtAll: elif type(component) is platform_message.AtAll:
print( print(
"上层组件要求发送 AtAll 消息,但 QQ 官方 API 不支持此消息类型,忽略此消息。" "上层组件要求发送 AtAll 消息,但 QQ 官方 API 不支持此消息类型,忽略此消息。"
) )
elif type(component) is mirai.Voice: elif type(component) is platform_message.Voice:
print( print(
"上层组件要求发送 Voice 消息,但 QQ 官方 API 不支持此消息类型,忽略此消息。" "上层组件要求发送 Voice 消息,但 QQ 官方 API 不支持此消息类型,忽略此消息。"
) )
@@ -193,32 +194,32 @@ class OfficialMessageConverter(adapter_model.MessageConverter):
@staticmethod @staticmethod
def extract_message_chain_from_obj( def extract_message_chain_from_obj(
message: typing.Union[botpy_message.Message, botpy_message.DirectMessage], message: typing.Union[botpy_message.Message, botpy_message.DirectMessage, botpy_message.GroupMessage, botpy_message.C2CMessage],
message_id: str = None, message_id: str = None,
bot_account_id: int = 0, bot_account_id: int = 0,
) -> mirai.MessageChain: ) -> platform_message.MessageChain:
yiri_msg_list = [] yiri_msg_list = []
# 存id # 存id
yiri_msg_list.append( yiri_msg_list.append(
mirai.models.message.Source( platform_message.Source(
id=save_msg_id(message_id), time=datetime.datetime.now() id=save_msg_id(message_id), time=datetime.datetime.now()
) )
) )
if type(message) is not botpy_message.DirectMessage: if type(message) not in [botpy_message.DirectMessage, botpy_message.C2CMessage]:
yiri_msg_list.append(mirai.At(target=bot_account_id)) yiri_msg_list.append(platform_message.At(target=bot_account_id))
if hasattr(message, "mentions"): if hasattr(message, "mentions"):
for mention in message.mentions: for mention in message.mentions:
if mention.bot: if mention.bot:
continue continue
yiri_msg_list.append(mirai.At(target=mention.id)) yiri_msg_list.append(platform_message.At(target=mention.id))
for attachment in message.attachments: for attachment in message.attachments:
if attachment.content_type.startswith("image"): if attachment.content_type.startswith("image"):
yiri_msg_list.append(mirai.Image(url=attachment.url)) yiri_msg_list.append(platform_message.Image(url=attachment.url))
else: else:
logging.warning( logging.warning(
"不支持的附件类型:" + attachment.content_type + ",忽略此附件。" "不支持的附件类型:" + attachment.content_type + ",忽略此附件。"
@@ -226,9 +227,9 @@ class OfficialMessageConverter(adapter_model.MessageConverter):
content = re.sub(r"<@!\d+>", "", str(message.content)) content = re.sub(r"<@!\d+>", "", str(message.content))
if content.strip() != "": if content.strip() != "":
yiri_msg_list.append(mirai.Plain(text=content)) yiri_msg_list.append(platform_message.Plain(text=content))
chain = mirai.MessageChain(yiri_msg_list) chain = platform_message.MessageChain(yiri_msg_list)
return chain return chain
@@ -243,10 +244,10 @@ class OfficialEventConverter(adapter_model.EventConverter):
self.member_openid_mapping = member_openid_mapping self.member_openid_mapping = member_openid_mapping
self.group_openid_mapping = group_openid_mapping self.group_openid_mapping = group_openid_mapping
def yiri2target(self, event: typing.Type[mirai.Event]): def yiri2target(self, event: typing.Type[platform_events.Event]):
if event == mirai.GroupMessage: if event == platform_events.GroupMessage:
return botpy_message.Message return botpy_message.Message
elif event == mirai.FriendMessage: elif event == platform_events.FriendMessage:
return botpy_message.DirectMessage return botpy_message.DirectMessage
else: else:
raise Exception( raise Exception(
@@ -255,9 +256,8 @@ class OfficialEventConverter(adapter_model.EventConverter):
def target2yiri( def target2yiri(
self, self,
event: typing.Union[botpy_message.Message, botpy_message.DirectMessage] event: typing.Union[botpy_message.Message, botpy_message.DirectMessage, botpy_message.GroupMessage, botpy_message.C2CMessage],
) -> mirai.Event: ) -> platform_events.Event:
import mirai.models.entities as mirai_entities
if type(event) == botpy_message.Message: # 频道内,转群聊事件 if type(event) == botpy_message.Message: # 频道内,转群聊事件
permission = "MEMBER" permission = "MEMBER"
@@ -267,15 +267,15 @@ class OfficialEventConverter(adapter_model.EventConverter):
elif "4" in event.member.roles: elif "4" in event.member.roles:
permission = "OWNER" permission = "OWNER"
return mirai.GroupMessage( return platform_events.GroupMessage(
sender=mirai_entities.GroupMember( sender=platform_entities.GroupMember(
id=event.author.id, id=event.author.id,
member_name=event.author.username, member_name=event.author.username,
permission=permission, permission=permission,
group=mirai_entities.Group( group=platform_entities.Group(
id=event.channel_id, id=event.channel_id,
name=event.author.username, name=event.author.username,
permission=mirai_entities.Permission.Member, permission=platform_entities.Permission.Member,
), ),
special_title="", special_title="",
join_timestamp=int( join_timestamp=int(
@@ -295,9 +295,9 @@ class OfficialEventConverter(adapter_model.EventConverter):
).timestamp() ).timestamp()
), ),
) )
elif type(event) == botpy_message.DirectMessage: # 私聊,转私聊事件 elif type(event) == botpy_message.DirectMessage: # 频道私聊,转私聊事件
return mirai.FriendMessage( return platform_events.FriendMessage(
sender=mirai_entities.Friend( sender=platform_entities.Friend(
id=event.guild_id, id=event.guild_id,
nickname=event.author.username, nickname=event.author.username,
remark=event.author.username, remark=event.author.username,
@@ -311,19 +311,19 @@ class OfficialEventConverter(adapter_model.EventConverter):
).timestamp() ).timestamp()
), ),
) )
elif type(event) == botpy_message.GroupMessage: elif type(event) == botpy_message.GroupMessage: # 群聊,转群聊事件
replacing_member_id = self.member_openid_mapping.save_openid(event.author.member_openid) replacing_member_id = self.member_openid_mapping.save_openid(event.author.member_openid)
return OfficialGroupMessage( return OfficialGroupMessage(
sender=mirai_entities.GroupMember( sender=platform_entities.GroupMember(
id=replacing_member_id, id=replacing_member_id,
member_name=replacing_member_id, member_name=replacing_member_id,
permission="MEMBER", permission="MEMBER",
group=mirai_entities.Group( group=platform_entities.Group(
id=self.group_openid_mapping.save_openid(event.group_openid), id=self.group_openid_mapping.save_openid(event.group_openid),
name=replacing_member_id, name=replacing_member_id,
permission=mirai_entities.Permission.Member, permission=platform_entities.Permission.Member,
), ),
special_title="", special_title="",
join_timestamp=int(0), join_timestamp=int(0),
@@ -339,6 +339,25 @@ class OfficialEventConverter(adapter_model.EventConverter):
).timestamp() ).timestamp()
), ),
) )
elif type(event) == botpy_message.C2CMessage: # 私聊,转私聊事件
user_id_alter = self.member_openid_mapping.save_openid(event.author.user_openid) # 实测这里的user_openid与group的member_openid是一样的
return OfficialFriendMessage(
sender=platform_entities.Friend(
id=user_id_alter,
nickname=user_id_alter,
remark=user_id_alter,
),
message_chain=OfficialMessageConverter.extract_message_chain_from_obj(
event, event.id
),
time=int(
datetime.datetime.strptime(
event.timestamp, "%Y-%m-%dT%H:%M:%S%z"
).timestamp()
),
)
@adapter_model.adapter_class("qq-botpy") @adapter_model.adapter_class("qq-botpy")
@@ -368,6 +387,7 @@ class OfficialAdapter(adapter_model.MessageSourceAdapter):
group_openid_mapping: OpenIDMapping[str, int] = None group_openid_mapping: OpenIDMapping[str, int] = None
group_msg_seq = None group_msg_seq = None
c2c_msg_seq = None
def __init__(self, cfg: dict, ap: app.Application): def __init__(self, cfg: dict, ap: app.Application):
"""初始化适配器""" """初始化适配器"""
@@ -375,6 +395,7 @@ class OfficialAdapter(adapter_model.MessageSourceAdapter):
self.ap = ap self.ap = ap
self.group_msg_seq = 1 self.group_msg_seq = 1
self.c2c_msg_seq = 1
switchs = {} switchs = {}
@@ -388,7 +409,7 @@ class OfficialAdapter(adapter_model.MessageSourceAdapter):
self.bot = botpy.Client(intents=intents) self.bot = botpy.Client(intents=intents)
async def send_message( async def send_message(
self, target_type: str, target_id: str, message: mirai.MessageChain self, target_type: str, target_id: str, message: platform_message.MessageChain
): ):
message_list = self.message_converter.yiri2target(message) message_list = self.message_converter.yiri2target(message)
@@ -415,8 +436,8 @@ class OfficialAdapter(adapter_model.MessageSourceAdapter):
async def reply_message( async def reply_message(
self, self,
message_source: mirai.MessageEvent, message_source: platform_events.MessageEvent,
message: mirai.MessageChain, message: platform_message.MessageChain,
quote_origin: bool = False, quote_origin: bool = False,
): ):
@@ -441,40 +462,80 @@ class OfficialAdapter(adapter_model.MessageSourceAdapter):
] ]
) )
if type(message_source) == mirai.GroupMessage: if type(message_source) == platform_events.GroupMessage:
args["channel_id"] = str(message_source.sender.group.id) args["channel_id"] = str(message_source.sender.group.id)
args["msg_id"] = cached_message_ids[ args["msg_id"] = cached_message_ids[
str(message_source.message_chain.message_id) str(message_source.message_chain.message_id)
] ]
await self.bot.api.post_message(**args) await self.bot.api.post_message(**args)
elif type(message_source) == mirai.FriendMessage: elif type(message_source) == platform_events.FriendMessage:
args["guild_id"] = str(message_source.sender.id) args["guild_id"] = str(message_source.sender.id)
args["msg_id"] = cached_message_ids[ args["msg_id"] = cached_message_ids[
str(message_source.message_chain.message_id) str(message_source.message_chain.message_id)
] ]
await self.bot.api.post_dms(**args) await self.bot.api.post_dms(**args)
elif type(message_source) == OfficialGroupMessage: elif type(message_source) == OfficialGroupMessage:
if "image" in args or "file_image" in args:
if "file_image" in args: # 暂不支持发送文件图片
continue continue
args["group_openid"] = self.group_openid_mapping.getkey( args["group_openid"] = self.group_openid_mapping.getkey(
message_source.sender.group.id message_source.sender.group.id
) )
if "image" in args:
uploadMedia = await self.bot.api.post_group_file(
group_openid=args["group_openid"],
file_type=1,
url=str(args['image'])
)
del args['image']
args['media'] = uploadMedia
args['msg_type'] = 7
args["msg_id"] = cached_message_ids[ args["msg_id"] = cached_message_ids[
str(message_source.message_chain.message_id) str(message_source.message_chain.message_id)
] ]
args["msg_seq"] = self.group_msg_seq args["msg_seq"] = self.group_msg_seq
self.group_msg_seq += 1 self.group_msg_seq += 1
await self.bot.api.post_group_message(**args) await self.bot.api.post_group_message(**args)
elif type(message_source) == OfficialFriendMessage:
if "file_image" in args:
continue
args["openid"] = self.member_openid_mapping.getkey(
message_source.sender.id
)
if "image" in args:
uploadMedia = await self.bot.api.post_c2c_file(
openid=args["openid"],
file_type=1,
url=str(args['image'])
)
del args['image']
args['media'] = uploadMedia
args['msg_type'] = 7
args["msg_id"] = cached_message_ids[
str(message_source.message_chain.message_id)
]
args["msg_seq"] = self.c2c_msg_seq
self.c2c_msg_seq += 1
await self.bot.api.post_c2c_message(**args)
async def is_muted(self, group_id: int) -> bool: async def is_muted(self, group_id: int) -> bool:
return False return False
def register_listener( def register_listener(
self, self,
event_type: typing.Type[mirai.Event], event_type: typing.Type[platform_events.Event],
callback: typing.Callable[ callback: typing.Callable[
[mirai.Event, adapter_model.MessageSourceAdapter], None [platform_events.Event, adapter_model.MessageSourceAdapter], None
], ],
): ):
@@ -498,9 +559,9 @@ class OfficialAdapter(adapter_model.MessageSourceAdapter):
def unregister_listener( def unregister_listener(
self, self,
event_type: typing.Type[mirai.Event], event_type: typing.Type[platform_events.Event],
callback: typing.Callable[ callback: typing.Callable[
[mirai.Event, adapter_model.MessageSourceAdapter], None [platform_events.Event, adapter_model.MessageSourceAdapter], None
], ],
): ):
delattr(self.bot, event_handler_mapping[event_type]) delattr(self.bot, event_handler_mapping[event_type])
@@ -524,8 +585,12 @@ class OfficialAdapter(adapter_model.MessageSourceAdapter):
self.member_openid_mapping, self.group_openid_mapping self.member_openid_mapping, self.group_openid_mapping
) )
self.ap.logger.info("运行 QQ 官方适配器") self.cfg['ret_coro'] = True
await self.bot.start(**self.cfg)
def kill(self) -> bool: self.ap.logger.info("运行 QQ 官方适配器")
return False await (await self.bot.start(**self.cfg))
async def kill(self) -> bool:
if not self.bot.is_closed():
await self.bot.close()
return True

View File

@@ -1,124 +0,0 @@
import asyncio
import typing
import mirai
import mirai.models.bus
from mirai.bot import MiraiRunner
from .. import adapter as adapter_model
from ...core import app
@adapter_model.adapter_class("yiri-mirai")
class YiriMiraiAdapter(adapter_model.MessageSourceAdapter):
"""YiriMirai适配器"""
bot: mirai.Mirai
def __init__(self, config: dict, ap: app.Application):
"""初始化YiriMirai的对象"""
self.ap = ap
self.config = config
if 'adapter' not in config or \
config['adapter'] == 'WebSocketAdapter':
self.bot = mirai.Mirai(
qq=config['qq'],
adapter=mirai.WebSocketAdapter(
host=config['host'],
port=config['port'],
verify_key=config['verifyKey']
)
)
elif config['adapter'] == 'HTTPAdapter':
self.bot = mirai.Mirai(
qq=config['qq'],
adapter=mirai.HTTPAdapter(
host=config['host'],
port=config['port'],
verify_key=config['verifyKey']
)
)
else:
raise Exception('Unknown adapter for YiriMirai: ' + config['adapter'])
async def send_message(
self,
target_type: str,
target_id: str,
message: mirai.MessageChain
):
"""发送消息
Args:
target_type (str): 目标类型,`person`或`group`
target_id (str): 目标ID
message (mirai.MessageChain): YiriMirai库的消息链
"""
task = None
if target_type == 'person':
task = self.bot.send_friend_message(int(target_id), message)
elif target_type == 'group':
task = self.bot.send_group_message(int(target_id), message)
else:
raise Exception('Unknown target type: ' + target_type)
await task
async def reply_message(
self,
message_source: mirai.MessageEvent,
message: mirai.MessageChain,
quote_origin: bool = False
):
"""回复消息
Args:
message_source (mirai.MessageEvent): YiriMirai消息源事件
message (mirai.MessageChain): YiriMirai库的消息链
quote_origin (bool, optional): 是否引用原消息. Defaults to False.
"""
await self.bot.send(message_source, message, quote_origin)
async def is_muted(self, group_id: int) -> bool:
result = await self.bot.member_info(target=group_id, member_id=self.bot.qq).get()
if result.mute_time_remaining > 0:
return True
return False
def register_listener(
self,
event_type: typing.Type[mirai.Event],
callback: typing.Callable[[mirai.Event, adapter_model.MessageSourceAdapter], None]
):
"""注册事件监听器
Args:
event_type (typing.Type[mirai.Event]): YiriMirai事件类型
callback (typing.Callable[[mirai.Event], None]): 回调函数接收一个参数为YiriMirai事件
"""
async def wrapper(event: mirai.Event):
await callback(event, self)
self.bot.on(event_type)(wrapper)
def unregister_listener(
self,
event_type: typing.Type[mirai.Event],
callback: typing.Callable[[mirai.Event, adapter_model.MessageSourceAdapter], None]
):
"""注销事件监听器
Args:
event_type (typing.Type[mirai.Event]): YiriMirai事件类型
callback (typing.Callable[[mirai.Event], None]): 回调函数接收一个参数为YiriMirai事件
"""
assert isinstance(self.bot, mirai.Mirai)
bus = self.bot.bus
assert isinstance(bus, mirai.models.bus.ModelEventBus)
bus.unsubscribe(event_type, callback)
async def run_async(self):
self.bot_account_id = self.bot.qq
return await MiraiRunner(self.bot)._run()
async def kill(self) -> bool:
return False

View File

@@ -0,0 +1,3 @@
from .entities import *
from .events import *
from .message import *

105
pkg/platform/types/base.py Normal file
View File

@@ -0,0 +1,105 @@
from typing import Dict, List, Type
import pydantic.v1.main as pdm
from pydantic.v1 import BaseModel
class PlatformMetaclass(pdm.ModelMetaclass):
"""此类是平台中使用的 pydantic 模型的元类的基类。"""
def to_camel(name: str) -> str:
"""将下划线命名风格转换为小驼峰命名。"""
if name[:2] == '__': # 不处理双下划线开头的特殊命名。
return name
name_parts = name.split('_')
return ''.join(name_parts[:1] + [x.title() for x in name_parts[1:]])
class PlatformBaseModel(BaseModel, metaclass=PlatformMetaclass):
"""模型基类。
启用了三项配置:
1. 允许解析时传入额外的值,并将额外值保存在模型中。
2. 允许通过别名访问字段。
3. 自动生成小驼峰风格的别名。
"""
def __init__(self, *args, **kwargs):
""""""
super().__init__(*args, **kwargs)
def __repr__(self) -> str:
return self.__class__.__name__ + '(' + ', '.join(
(f'{k}={repr(v)}' for k, v in self.__dict__.items() if v)
) + ')'
class Config:
extra = 'allow'
allow_population_by_field_name = True
alias_generator = to_camel
class PlatformIndexedMetaclass(PlatformMetaclass):
"""可以通过子类名获取子类的类的元类。"""
__indexedbases__: List[Type['PlatformIndexedModel']] = []
__indexedmodel__ = None
def __new__(cls, name, bases, attrs, **kwargs):
new_cls = super().__new__(cls, name, bases, attrs, **kwargs)
# 第一类PlatformIndexedModel
if name == 'PlatformIndexedModel':
cls.__indexedmodel__ = new_cls
new_cls.__indexes__ = {}
return new_cls
# 第二类PlatformIndexedModel 的直接子类,这些是可以通过子类名获取子类的类。
if cls.__indexedmodel__ in bases:
cls.__indexedbases__.append(new_cls)
new_cls.__indexes__ = {}
return new_cls
# 第三类PlatformIndexedModel 的直接子类的子类,这些添加到直接子类的索引中。
for base in cls.__indexedbases__:
if issubclass(new_cls, base):
base.__indexes__[name] = new_cls
return new_cls
def __getitem__(cls, name):
return cls.get_subtype(name)
class PlatformIndexedModel(PlatformBaseModel, metaclass=PlatformIndexedMetaclass):
"""可以通过子类名获取子类的类。"""
__indexes__: Dict[str, Type['PlatformIndexedModel']]
@classmethod
def get_subtype(cls, name: str) -> Type['PlatformIndexedModel']:
"""根据类名称,获取相应的子类类型。
Args:
name: 类名称。
Returns:
Type['PlatformIndexedModel']: 子类类型。
"""
try:
type_ = cls.__indexes__.get(name)
if not (type_ and issubclass(type_, cls)):
raise ValueError(f'`{name}` 不是 `{cls.__name__}` 的子类!')
return type_
except AttributeError as e:
raise ValueError(f'`{name}` 不是 `{cls.__name__}` 的子类!') from None
@classmethod
def parse_subtype(cls, obj: dict) -> 'PlatformIndexedModel':
"""通过字典,构造对应的模型对象。
Args:
obj: 一个字典,包含了模型对象的属性。
Returns:
PlatformIndexedModel: 构造的对象。
"""
if cls in PlatformIndexedModel.__subclasses__():
ModelType = cls.get_subtype(obj['type'])
return ModelType.parse_obj(obj)
return super().parse_obj(obj)

View File

@@ -0,0 +1,143 @@
# -*- coding: utf-8 -*-
"""
此模块提供实体和配置项模型。
"""
import abc
from datetime import datetime
from enum import Enum
import typing
import pydantic.v1 as pydantic
class Entity(pydantic.BaseModel):
"""实体,表示一个用户或群。"""
id: int
"""QQ 号或群号。"""
@abc.abstractmethod
def get_avatar_url(self) -> str:
"""头像图片链接。"""
@abc.abstractmethod
def get_name(self) -> str:
"""名称。"""
class Friend(Entity):
"""好友。"""
id: int
"""QQ 号。"""
nickname: typing.Optional[str]
"""昵称。"""
remark: typing.Optional[str]
"""备注。"""
def get_avatar_url(self) -> str:
return f'http://q4.qlogo.cn/g?b=qq&nk={self.id}&s=140'
def get_name(self) -> str:
return self.nickname or self.remark or ''
class Permission(str, Enum):
"""群成员身份权限。"""
Member = "MEMBER"
"""成员。"""
Administrator = "ADMINISTRATOR"
"""管理员。"""
Owner = "OWNER"
"""群主。"""
def __repr__(self) -> str:
return repr(self.value)
class Group(Entity):
"""群。"""
id: int
"""群号。"""
name: str
"""群名称。"""
permission: Permission
"""Bot 在群中的权限。"""
def get_avatar_url(self) -> str:
return f'https://p.qlogo.cn/gh/{self.id}/{self.id}/'
def get_name(self) -> str:
return self.name
class GroupMember(Entity):
"""群成员。"""
id: int
"""QQ 号。"""
member_name: str
"""群成员名称。"""
permission: Permission
"""Bot 在群中的权限。"""
group: Group
"""群。"""
special_title: str = ''
"""群头衔。"""
join_timestamp: datetime = datetime.utcfromtimestamp(0)
"""加入群的时间。"""
last_speak_timestamp: datetime = datetime.utcfromtimestamp(0)
"""最后一次发言的时间。"""
mute_time_remaining: int = 0
"""禁言剩余时间。"""
def get_avatar_url(self) -> str:
return f'http://q4.qlogo.cn/g?b=qq&nk={self.id}&s=140'
def get_name(self) -> str:
return self.member_name
class Client(Entity):
"""来自其他客户端的用户。"""
id: int
"""识别 id。"""
platform: str
"""来源平台。"""
def get_avatar_url(self) -> str:
raise NotImplementedError
def get_name(self) -> str:
return self.platform
class Subject(pydantic.BaseModel):
"""另一种实体类型表示。"""
id: int
"""QQ 号或群号。"""
kind: typing.Literal['Friend', 'Group', 'Stranger']
"""类型。"""
class Config(pydantic.BaseModel):
"""配置项类型。"""
def modify(self, **kwargs) -> 'Config':
"""修改部分设置。"""
for k, v in kwargs.items():
if k in self.__fields__:
setattr(self, k, v)
else:
raise ValueError(f'未知配置项: {k}')
return self
class GroupConfigModel(Config):
"""群配置。"""
name: str
"""群名称。"""
confess_talk: bool
"""是否允许坦白说。"""
allow_member_invite: bool
"""是否允许成员邀请好友入群。"""
auto_approve: bool
"""是否开启自动审批入群。"""
anonymous_chat: bool
"""是否开启匿名聊天。"""
announcement: str = ''
"""群公告。"""
class MemberInfoModel(Config, GroupMember):
"""群成员信息。"""

View File

@@ -0,0 +1,124 @@
# -*- coding: utf-8 -*-
"""
此模块提供事件模型。
"""
from datetime import datetime
from enum import Enum
import typing
import pydantic.v1 as pydantic
from . import entities as platform_entities
from . import message as platform_message
class Event(pydantic.BaseModel):
"""事件基类。
Args:
type: 事件名。
"""
type: str
"""事件名。"""
def __repr__(self):
return self.__class__.__name__ + '(' + ', '.join(
(
f'{k}={repr(v)}'
for k, v in self.__dict__.items() if k != 'type' and v
)
) + ')'
@classmethod
def parse_subtype(cls, obj: dict) -> 'Event':
try:
return typing.cast(Event, super().parse_subtype(obj))
except ValueError:
return Event(type=obj['type'])
@classmethod
def get_subtype(cls, name: str) -> typing.Type['Event']:
try:
return typing.cast(typing.Type[Event], super().get_subtype(name))
except ValueError:
return Event
###############################
# Bot Event
class BotEvent(Event):
"""Bot 自身事件。
Args:
type: 事件名。
qq: Bot 的 QQ 号。
"""
type: str
"""事件名。"""
qq: int
"""Bot 的 QQ 号。"""
###############################
# Message Event
class MessageEvent(Event):
"""消息事件。
Args:
type: 事件名。
message_chain: 消息内容。
"""
type: str
"""事件名。"""
message_chain: platform_message.MessageChain
"""消息内容。"""
class FriendMessage(MessageEvent):
"""好友消息。
Args:
type: 事件名。
sender: 发送消息的好友。
message_chain: 消息内容。
"""
type: str = 'FriendMessage'
"""事件名。"""
sender: platform_entities.Friend
"""发送消息的好友。"""
message_chain: platform_message.MessageChain
"""消息内容。"""
class GroupMessage(MessageEvent):
"""群消息。
Args:
type: 事件名。
sender: 发送消息的群成员。
message_chain: 消息内容。
"""
type: str = 'GroupMessage'
"""事件名。"""
sender: platform_entities.GroupMember
"""发送消息的群成员。"""
message_chain: platform_message.MessageChain
"""消息内容。"""
@property
def group(self) -> platform_entities.Group:
return self.sender.group
class StrangerMessage(MessageEvent):
"""陌生人消息。
Args:
type: 事件名。
sender: 发送消息的人。
message_chain: 消息内容。
"""
type: str = 'StrangerMessage'
"""事件名。"""
sender: platform_entities.Friend
"""发送消息的人。"""
message_chain: platform_message.MessageChain
"""消息内容。"""

View File

@@ -0,0 +1,816 @@
import itertools
import logging
from datetime import datetime
from enum import Enum
from pathlib import Path
import typing
import pydantic.v1 as pydantic
from . import entities as platform_entities
from .base import PlatformBaseModel, PlatformIndexedMetaclass, PlatformIndexedModel
logger = logging.getLogger(__name__)
class MessageComponentMetaclass(PlatformIndexedMetaclass):
"""消息组件元类。"""
__message_component__ = None
def __new__(cls, name, bases, attrs, **kwargs):
new_cls = super().__new__(cls, name, bases, attrs, **kwargs)
if name == 'MessageComponent':
cls.__message_component__ = new_cls
if not cls.__message_component__:
return new_cls
for base in bases:
if issubclass(base, cls.__message_component__):
# 获取字段名
if hasattr(new_cls, '__fields__'):
# 忽略 type 字段
new_cls.__parameter_names__ = list(new_cls.__fields__)[1:]
else:
new_cls.__parameter_names__ = []
break
return new_cls
class MessageComponent(PlatformIndexedModel, metaclass=MessageComponentMetaclass):
"""消息组件。"""
type: str
"""消息组件类型。"""
def __str__(self):
return ''
def __repr__(self):
return self.__class__.__name__ + '(' + ', '.join(
(
f'{k}={repr(v)}'
for k, v in self.__dict__.items() if k != 'type' and v
)
) + ')'
def __init__(self, *args, **kwargs):
# 解析参数列表,将位置参数转化为具名参数
parameter_names = self.__parameter_names__
if len(args) > len(parameter_names):
raise TypeError(
f'`{self.type}`需要{len(parameter_names)}个参数,但传入了{len(args)}个。'
)
for name, value in zip(parameter_names, args):
if name in kwargs:
raise TypeError(f'在 `{self.type}` 中,具名参数 `{name}` 与位置参数重复。')
kwargs[name] = value
super().__init__(**kwargs)
TMessageComponent = typing.TypeVar('TMessageComponent', bound=MessageComponent)
class MessageChain(PlatformBaseModel):
"""消息链。
一个构造消息链的例子:
```py
message_chain = MessageChain([
AtAll(),
Plain("Hello World!"),
])
```
`Plain` 可以省略。
```py
message_chain = MessageChain([
AtAll(),
"Hello World!",
])
```
在调用 API 时,参数中需要 MessageChain 的,也可以使用 `List[MessageComponent]` 代替。
例如,以下两种写法是等价的:
```py
await bot.send_friend_message(12345678, [
Plain("Hello World!")
])
```
```py
await bot.send_friend_message(12345678, MessageChain([
Plain("Hello World!")
]))
```
可以使用 `in` 运算检查消息链中:
1. 是否有某个消息组件。
2. 是否有某个类型的消息组件。
```py
if AtAll in message_chain:
print('AtAll')
if At(bot.qq) in message_chain:
print('At Me')
```
消息链对索引操作进行了增强。以消息组件类型为索引,获取消息链中的全部该类型的消息组件。
```py
plain_list = message_chain[Plain]
'[Plain("Hello World!")]'
```
可以用加号连接两个消息链。
```py
MessageChain(['Hello World!']) + MessageChain(['Goodbye World!'])
# 返回 MessageChain([Plain("Hello World!"), Plain("Goodbye World!")])
```
"""
__root__: typing.List[MessageComponent]
@staticmethod
def _parse_message_chain(msg_chain: typing.Iterable):
result = []
for msg in msg_chain:
if isinstance(msg, dict):
result.append(MessageComponent.parse_subtype(msg))
elif isinstance(msg, MessageComponent):
result.append(msg)
elif isinstance(msg, str):
result.append(Plain(msg))
else:
raise TypeError(
f"消息链中元素需为 dict 或 str 或 MessageComponent当前类型{type(msg)}"
)
return result
@pydantic.validator('__root__', always=True, pre=True)
def _parse_component(cls, msg_chain):
if isinstance(msg_chain, (str, MessageComponent)):
msg_chain = [msg_chain]
if not msg_chain:
msg_chain = []
return cls._parse_message_chain(msg_chain)
@classmethod
def parse_obj(cls, msg_chain: typing.Iterable):
"""通过列表形式的消息链,构造对应的 `MessageChain` 对象。
Args:
msg_chain: 列表形式的消息链。
"""
result = cls._parse_message_chain(msg_chain)
return cls(__root__=result)
def __init__(self, __root__: typing.Iterable[MessageComponent] = None):
super().__init__(__root__=__root__)
def __str__(self):
return "".join(str(component) for component in self.__root__)
def __repr__(self):
return f'{self.__class__.__name__}({self.__root__!r})'
def __iter__(self):
yield from self.__root__
def get_first(self,
t: typing.Type[TMessageComponent]) -> typing.Optional[TMessageComponent]:
"""获取消息链中第一个符合类型的消息组件。"""
for component in self:
if isinstance(component, t):
return component
return None
@typing.overload
def __getitem__(self, index: int) -> MessageComponent:
...
@typing.overload
def __getitem__(self, index: slice) -> typing.List[MessageComponent]:
...
@typing.overload
def __getitem__(self,
index: typing.Type[TMessageComponent]) -> typing.List[TMessageComponent]:
...
@typing.overload
def __getitem__(
self, index: typing.Tuple[typing.Type[TMessageComponent], int]
) -> typing.List[TMessageComponent]:
...
def __getitem__(
self, index: typing.Union[int, slice, typing.Type[TMessageComponent],
typing.Tuple[typing.Type[TMessageComponent], int]]
) -> typing.Union[MessageComponent, typing.List[MessageComponent],
typing.List[TMessageComponent]]:
return self.get(index)
def __setitem__(
self, key: typing.Union[int, slice],
value: typing.Union[MessageComponent, str, typing.Iterable[typing.Union[MessageComponent,
str]]]
):
if isinstance(value, str):
value = Plain(value)
if isinstance(value, typing.Iterable):
value = (Plain(c) if isinstance(c, str) else c for c in value)
self.__root__[key] = value # type: ignore
def __delitem__(self, key: typing.Union[int, slice]):
del self.__root__[key]
def __reversed__(self) -> typing.Iterable[MessageComponent]:
return reversed(self.__root__)
def has(
self, sub: typing.Union[MessageComponent, typing.Type[MessageComponent],
'MessageChain', str]
) -> bool:
"""判断消息链中:
1. 是否有某个消息组件。
2. 是否有某个类型的消息组件。
Args:
sub (`Union[MessageComponent, Type[MessageComponent], 'MessageChain', str]`):
若为 `MessageComponent`,则判断该组件是否在消息链中。
若为 `Type[MessageComponent]`,则判断该组件类型是否在消息链中。
Returns:
bool: 是否找到。
"""
if isinstance(sub, type): # 检测消息链中是否有某种类型的对象
for i in self:
if type(i) is sub:
return True
return False
if isinstance(sub, MessageComponent): # 检查消息链中是否有某个组件
for i in self:
if i == sub:
return True
return False
raise TypeError(f"类型不匹配,当前类型:{type(sub)}")
def __contains__(self, sub) -> bool:
return self.has(sub)
def __ge__(self, other):
return other in self
def __len__(self) -> int:
return len(self.__root__)
def __add__(
self, other: typing.Union['MessageChain', MessageComponent, str]
) -> 'MessageChain':
if isinstance(other, MessageChain):
return self.__class__(self.__root__ + other.__root__)
if isinstance(other, str):
return self.__class__(self.__root__ + [Plain(other)])
if isinstance(other, MessageComponent):
return self.__class__(self.__root__ + [other])
return NotImplemented
def __radd__(self, other: typing.Union[MessageComponent, str]) -> 'MessageChain':
if isinstance(other, MessageComponent):
return self.__class__([other] + self.__root__)
if isinstance(other, str):
return self.__class__(
[typing.cast(MessageComponent, Plain(other))] + self.__root__
)
return NotImplemented
def __mul__(self, other: int):
if isinstance(other, int):
return self.__class__(self.__root__ * other)
return NotImplemented
def __rmul__(self, other: int):
return self.__mul__(other)
def __iadd__(self, other: typing.Iterable[typing.Union[MessageComponent, str]]):
self.extend(other)
def __imul__(self, other: int):
if isinstance(other, int):
self.__root__ *= other
return NotImplemented
def index(
self,
x: typing.Union[MessageComponent, typing.Type[MessageComponent]],
i: int = 0,
j: int = -1
) -> int:
"""返回 x 在消息链中首次出现项的索引号(索引号在 i 或其后且在 j 之前)。
Args:
x (`Union[MessageComponent, Type[MessageComponent]]`):
要查找的消息元素或消息元素类型。
i: 从哪个位置开始查找。
j: 查找到哪个位置结束。
Returns:
int: 如果找到,则返回索引号。
Raises:
ValueError: 没有找到。
TypeError: 类型不匹配。
"""
if isinstance(x, type):
l = len(self)
if i < 0:
i += l
if i < 0:
i = 0
if j < 0:
j += l
if j > l:
j = l
for index in range(i, j):
if type(self[index]) is x:
return index
raise ValueError("消息链中不存在该类型的组件。")
if isinstance(x, MessageComponent):
return self.__root__.index(x, i, j)
raise TypeError(f"类型不匹配,当前类型:{type(x)}")
def count(self, x: typing.Union[MessageComponent, typing.Type[MessageComponent]]) -> int:
"""返回消息链中 x 出现的次数。
Args:
x (`Union[MessageComponent, Type[MessageComponent]]`):
要查找的消息元素或消息元素类型。
Returns:
int: 次数。
"""
if isinstance(x, type):
return sum(1 for i in self if type(i) is x)
if isinstance(x, MessageComponent):
return self.__root__.count(x)
raise TypeError(f"类型不匹配,当前类型:{type(x)}")
def extend(self, x: typing.Iterable[typing.Union[MessageComponent, str]]):
"""将另一个消息链中的元素添加到消息链末尾。
Args:
x: 另一个消息链,也可为消息元素或字符串元素的序列。
"""
self.__root__.extend(Plain(c) if isinstance(c, str) else c for c in x)
def append(self, x: typing.Union[MessageComponent, str]):
"""将一个消息元素或字符串元素添加到消息链末尾。
Args:
x: 消息元素或字符串元素。
"""
self.__root__.append(Plain(x) if isinstance(x, str) else x)
def insert(self, i: int, x: typing.Union[MessageComponent, str]):
"""将一个消息元素或字符串添加到消息链中指定位置。
Args:
i: 插入位置。
x: 消息元素或字符串元素。
"""
self.__root__.insert(i, Plain(x) if isinstance(x, str) else x)
def pop(self, i: int = -1) -> MessageComponent:
"""从消息链中移除并返回指定位置的元素。
Args:
i: 移除位置。默认为末尾。
Returns:
MessageComponent: 移除的元素。
"""
return self.__root__.pop(i)
def remove(self, x: typing.Union[MessageComponent, typing.Type[MessageComponent]]):
"""从消息链中移除指定元素或指定类型的一个元素。
Args:
x: 指定的元素或元素类型。
"""
if isinstance(x, type):
self.pop(self.index(x))
if isinstance(x, MessageComponent):
self.__root__.remove(x)
def exclude(
self,
x: typing.Union[MessageComponent, typing.Type[MessageComponent]],
count: int = -1
) -> 'MessageChain':
"""返回移除指定元素或指定类型的元素后剩余的消息链。
Args:
x: 指定的元素或元素类型。
count: 至多移除的数量。默认为全部移除。
Returns:
MessageChain: 剩余的消息链。
"""
def _exclude():
nonlocal count
x_is_type = isinstance(x, type)
for c in self:
if count > 0 and ((x_is_type and type(c) is x) or c == x):
count -= 1
continue
yield c
return self.__class__(_exclude())
def reverse(self):
"""将消息链原地翻转。"""
self.__root__.reverse()
@classmethod
def join(cls, *args: typing.Iterable[typing.Union[str, MessageComponent]]):
return cls(
Plain(c) if isinstance(c, str) else c
for c in itertools.chain(*args)
)
@property
def source(self) -> typing.Optional['Source']:
"""获取消息链中的 `Source` 对象。"""
return self.get_first(Source)
@property
def message_id(self) -> int:
"""获取消息链的 message_id若无法获取返回 -1。"""
source = self.source
return source.id if source else -1
TMessage = typing.Union[MessageChain, typing.Iterable[typing.Union[MessageComponent, str]],
MessageComponent, str]
"""可以转化为 MessageChain 的类型。"""
class Source(MessageComponent):
"""源。包含消息的基本信息。"""
type: str = "Source"
"""消息组件类型。"""
id: int
"""消息的识别号用于引用回复Source 类型永远为 MessageChain 的第一个元素)。"""
time: datetime
"""消息时间。"""
class Plain(MessageComponent):
"""纯文本。"""
type: str = "Plain"
"""消息组件类型。"""
text: str
"""文字消息。"""
def __str__(self):
return self.text
def __repr__(self):
return f'Plain({self.text!r})'
class Quote(MessageComponent):
"""引用。"""
type: str = "Quote"
"""消息组件类型。"""
id: typing.Optional[int] = None
"""被引用回复的原消息的 message_id。"""
group_id: typing.Optional[int] = None
"""被引用回复的原消息所接收的群号当为好友消息时为0。"""
sender_id: typing.Optional[int] = None
"""被引用回复的原消息的发送者的QQ号。"""
target_id: typing.Optional[int] = None
"""被引用回复的原消息的接收者者的QQ号或群号"""
origin: MessageChain
"""被引用回复的原消息的消息链对象。"""
@pydantic.validator("origin", always=True, pre=True)
def origin_formater(cls, v):
return MessageChain.parse_obj(v)
class At(MessageComponent):
"""At某人。"""
type: str = "At"
"""消息组件类型。"""
target: int
"""群员 QQ 号。"""
display: typing.Optional[str] = None
"""At时显示的文字发送消息时无效自动使用群名片。"""
def __eq__(self, other):
return isinstance(other, At) and self.target == other.target
def __str__(self):
return f"@{self.display or self.target}"
class AtAll(MessageComponent):
"""At全体。"""
type: str = "AtAll"
"""消息组件类型。"""
def __str__(self):
return "@全体成员"
class Image(MessageComponent):
"""图片。"""
type: str = "Image"
"""消息组件类型。"""
image_id: typing.Optional[str] = None
"""图片的 image_id群图片与好友图片格式不同。不为空时将忽略 url 属性。"""
url: typing.Optional[pydantic.HttpUrl] = None
"""图片的 URL发送时可作网络图片的链接接收时为腾讯图片服务器的链接可用于图片下载。"""
path: typing.Union[str, Path, None] = None
"""图片的路径,发送本地图片。"""
base64: typing.Optional[str] = None
"""图片的 Base64 编码。"""
def __eq__(self, other):
return isinstance(
other, Image
) and self.type == other.type and self.uuid == other.uuid
def __str__(self):
return '[图片]'
@pydantic.validator('path')
def validate_path(cls, path: typing.Union[str, Path, None]):
"""修复 path 参数的行为,使之相对于 LangBot 的启动路径。"""
if path:
try:
return str(Path(path).resolve(strict=True))
except FileNotFoundError:
raise ValueError(f"无效路径:{path}")
else:
return path
@property
def uuid(self):
image_id = self.image_id
if image_id[0] == '{': # 群图片
image_id = image_id[1:37]
elif image_id[0] == '/': # 好友图片
image_id = image_id[1:]
return image_id
async def download(
self,
filename: typing.Union[str, Path, None] = None,
directory: typing.Union[str, Path, None] = None,
determine_type: bool = True
):
"""下载图片到本地。
Args:
filename: 下载到本地的文件路径。与 `directory` 二选一。
directory: 下载到本地的文件夹路径。与 `filename` 二选一。
determine_type: 是否自动根据图片类型确定拓展名,默认为 True。
"""
if not self.url:
logger.warning(f'图片 `{self.uuid}` 无 url 参数,下载失败。')
return
import httpx
async with httpx.AsyncClient() as client:
response = await client.get(self.url)
response.raise_for_status()
content = response.content
if filename:
path = Path(filename)
if determine_type:
import imghdr
path = path.with_suffix(
'.' + str(imghdr.what(None, content))
)
path.parent.mkdir(parents=True, exist_ok=True)
elif directory:
import imghdr
path = Path(directory)
path.mkdir(parents=True, exist_ok=True)
path = path / f'{self.uuid}.{imghdr.what(None, content)}'
else:
raise ValueError("请指定文件路径或文件夹路径!")
import aiofiles
async with aiofiles.open(path, 'wb') as f:
await f.write(content)
return path
@classmethod
async def from_local(
cls,
filename: typing.Union[str, Path, None] = None,
content: typing.Optional[bytes] = None,
) -> "Image":
"""从本地文件路径加载图片,以 base64 的形式传递。
Args:
filename: 从本地文件路径加载图片,与 `content` 二选一。
content: 从本地文件内容加载图片,与 `filename` 二选一。
Returns:
Image: 图片对象。
"""
if content:
pass
elif filename:
path = Path(filename)
import aiofiles
async with aiofiles.open(path, 'rb') as f:
content = await f.read()
else:
raise ValueError("请指定图片路径或图片内容!")
import base64
img = cls(base64=base64.b64encode(content).decode())
return img
@classmethod
def from_unsafe_path(cls, path: typing.Union[str, Path]) -> "Image":
"""从不安全的路径加载图片。
Args:
path: 从不安全的路径加载图片。
Returns:
Image: 图片对象。
"""
return cls.construct(path=str(path))
class Unknown(MessageComponent):
"""未知。"""
type: str = "Unknown"
"""消息组件类型。"""
text: str
"""文本。"""
class Voice(MessageComponent):
"""语音。"""
type: str = "Voice"
"""消息组件类型。"""
voice_id: typing.Optional[str] = None
"""语音的 voice_id不为空时将忽略 url 属性。"""
url: typing.Optional[str] = None
"""语音的 URL发送时可作网络语音的链接接收时为腾讯语音服务器的链接可用于语音下载。"""
path: typing.Optional[str] = None
"""语音的路径,发送本地语音。"""
base64: typing.Optional[str] = None
"""语音的 Base64 编码。"""
length: typing.Optional[int] = None
"""语音的长度,单位为秒。"""
@pydantic.validator('path')
def validate_path(cls, path: typing.Optional[str]):
"""修复 path 参数的行为,使之相对于 LangBot 的启动路径。"""
if path:
try:
return str(Path(path).resolve(strict=True))
except FileNotFoundError:
raise ValueError(f"无效路径:{path}")
else:
return path
def __str__(self):
return '[语音]'
async def download(
self,
filename: typing.Union[str, Path, None] = None,
directory: typing.Union[str, Path, None] = None
):
"""下载语音到本地。
语音采用 silk v3 格式silk 格式的编码解码请使用 [graiax-silkcoder](https://pypi.org/project/graiax-silkcoder/)。
Args:
filename: 下载到本地的文件路径。与 `directory` 二选一。
directory: 下载到本地的文件夹路径。与 `filename` 二选一。
"""
if not self.url:
logger.warning(f'语音 `{self.voice_id}` 无 url 参数,下载失败。')
return
import httpx
async with httpx.AsyncClient() as client:
response = await client.get(self.url)
response.raise_for_status()
content = response.content
if filename:
path = Path(filename)
path.parent.mkdir(parents=True, exist_ok=True)
elif directory:
path = Path(directory)
path.mkdir(parents=True, exist_ok=True)
path = path / f'{self.voice_id}.silk'
else:
raise ValueError("请指定文件路径或文件夹路径!")
import aiofiles
async with aiofiles.open(path, 'wb') as f:
await f.write(content)
@classmethod
async def from_local(
cls,
filename: typing.Union[str, Path, None] = None,
content: typing.Optional[bytes] = None,
) -> "Voice":
"""从本地文件路径加载语音,以 base64 的形式传递。
Args:
filename: 从本地文件路径加载语音,与 `content` 二选一。
content: 从本地文件内容加载语音,与 `filename` 二选一。
"""
if content:
pass
if filename:
path = Path(filename)
import aiofiles
async with aiofiles.open(path, 'rb') as f:
content = await f.read()
else:
raise ValueError("请指定语音路径或语音内容!")
import base64
img = cls(base64=base64.b64encode(content).decode())
return img
class ForwardMessageNode(pydantic.BaseModel):
"""合并转发中的一条消息。"""
sender_id: typing.Optional[int] = None
"""发送人QQ号。"""
sender_name: typing.Optional[str] = None
"""显示名称。"""
message_chain: typing.Optional[MessageChain] = None
"""消息内容。"""
message_id: typing.Optional[int] = None
"""消息的 message_id可以只使用此属性从缓存中读取消息内容。"""
time: typing.Optional[datetime] = None
"""发送时间。"""
@pydantic.validator('message_chain', check_fields=False)
def _validate_message_chain(cls, value: typing.Union[MessageChain, list]):
if isinstance(value, list):
return MessageChain.parse_obj(value)
return value
@classmethod
def create(
cls, sender: typing.Union[platform_entities.Friend, platform_entities.GroupMember], message: MessageChain
) -> 'ForwardMessageNode':
"""从消息链生成转发消息。
Args:
sender: 发送人。
message: 消息内容。
Returns:
ForwardMessageNode: 生成的一条消息。
"""
return ForwardMessageNode(
sender_id=sender.id,
sender_name=sender.get_name(),
message_chain=message
)
class Forward(MessageComponent):
"""合并转发。"""
type: str = "Forward"
"""消息组件类型。"""
node_list: typing.List[ForwardMessageNode]
"""转发消息节点列表。"""
def __init__(self, *args, **kwargs):
if len(args) == 1:
self.node_list = args[0]
super().__init__(**kwargs)
super().__init__(*args, **kwargs)
def __str__(self):
return '[聊天记录]'
class File(MessageComponent):
"""文件。"""
type: str = "File"
"""消息组件类型。"""
id: str
"""文件识别 ID。"""
name: str
"""文件名称。"""
size: int
"""文件大小。"""
def __str__(self):
return f'[文件]{self.name}'

View File

@@ -2,12 +2,13 @@ from __future__ import annotations
import typing import typing
import abc import abc
import pydantic import pydantic.v1 as pydantic
import mirai import enum
from . import events from . import events
from ..provider.tools import entities as tools_entities from ..provider.tools import entities as tools_entities
from ..core import app from ..core import app
from ..platform.types import message as platform_message
def register( def register(
@@ -85,15 +86,24 @@ class BasePlugin(metaclass=abc.ABCMeta):
"""应用程序对象""" """应用程序对象"""
def __init__(self, host: APIHost): def __init__(self, host: APIHost):
"""初始化阶段被调用"""
self.host = host self.host = host
async def initialize(self): async def initialize(self):
"""初始化插件""" """初始化阶段被调用"""
pass
async def destroy(self):
"""释放/禁用插件时被调用"""
pass
def __del__(self):
"""释放/禁用插件时被调用"""
pass pass
class APIHost: class APIHost:
"""QChatGPT API 宿主""" """LangBot API 宿主"""
ap: app.Application ap: app.Application
@@ -126,7 +136,7 @@ class APIHost:
if self.ap.ver_mgr.compare_version_str(qchatgpt_version, ge) < 0 or \ if self.ap.ver_mgr.compare_version_str(qchatgpt_version, ge) < 0 or \
(self.ap.ver_mgr.compare_version_str(qchatgpt_version, le) > 0): (self.ap.ver_mgr.compare_version_str(qchatgpt_version, le) > 0):
raise Exception("QChatGPT 版本不满足要求,某些功能(可能是由插件提供的)无法正常使用。(要求版本:{}-{},但当前版本:{}".format(ge, le, qchatgpt_version)) raise Exception("LangBot 版本不满足要求,某些功能(可能是由插件提供的)无法正常使用。(要求版本:{}-{},但当前版本:{}".format(ge, le, qchatgpt_version))
return True return True
@@ -174,11 +184,11 @@ class EventContext:
self.__return_value__[key] = [] self.__return_value__[key] = []
self.__return_value__[key].append(ret) self.__return_value__[key].append(ret)
async def reply(self, message_chain: mirai.MessageChain): async def reply(self, message_chain: platform_message.MessageChain):
"""回复此次消息请求 """回复此次消息请求
Args: Args:
message_chain (mirai.MessageChain): YiriMirai库的消息链,若用户使用的不是 YiriMirai 适配器,程序也能自动转换为目标消息链 message_chain (platform.types.MessageChain): 源平台的消息链,若用户使用的不是源平台适配器,程序也能自动转换为目标平台消息链
""" """
await self.host.ap.platform_mgr.send( await self.host.ap.platform_mgr.send(
event=self.event.query.message_event, event=self.event.query.message_event,
@@ -190,14 +200,14 @@ class EventContext:
self, self,
target_type: str, target_type: str,
target_id: str, target_id: str,
message: mirai.MessageChain message: platform_message.MessageChain
): ):
"""主动发送消息 """主动发送消息
Args: Args:
target_type (str): 目标类型,`person`或`group` target_type (str): 目标类型,`person`或`group`
target_id (str): 目标ID target_id (str): 目标ID
message (mirai.MessageChain): YiriMirai库的消息链,若用户使用的不是 YiriMirai 适配器,程序也能自动转换为目标消息链 message (platform.types.MessageChain): 源平台的消息链,若用户使用的不是源平台适配器,程序也能自动转换为目标平台消息链
""" """
await self.event.query.adapter.send_message( await self.event.query.adapter.send_message(
target_type=target_type, target_type=target_type,
@@ -247,6 +257,16 @@ class EventContext:
EventContext.eid += 1 EventContext.eid += 1
class RuntimeContainerStatus(enum.Enum):
"""插件容器状态"""
MOUNTED = "mounted"
"""已加载进内存,所有位于运行时记录中的 RuntimeContainer 至少是这个状态"""
INITIALIZED = "initialized"
"""已初始化"""
class RuntimeContainer(pydantic.BaseModel): class RuntimeContainer(pydantic.BaseModel):
"""运行时的插件容器 """运行时的插件容器
@@ -294,6 +314,9 @@ class RuntimeContainer(pydantic.BaseModel):
content_functions: list[tools_entities.LLMFunction] = [] content_functions: list[tools_entities.LLMFunction] = []
"""内容函数""" """内容函数"""
status: RuntimeContainerStatus = RuntimeContainerStatus.MOUNTED
"""插件状态"""
class Config: class Config:
arbitrary_types_allowed = True arbitrary_types_allowed = True
@@ -318,5 +341,30 @@ class RuntimeContainer(pydantic.BaseModel):
self.priority = setting['priority'] self.priority = setting['priority']
self.enabled = setting['enabled'] self.enabled = setting['enabled']
for function in self.content_functions: def model_dump(self, *args, **kwargs):
function.enable = self.enabled return {
'name': self.plugin_name,
'description': self.plugin_description,
'version': self.plugin_version,
'author': self.plugin_author,
'source': self.plugin_source,
'main_file': self.main_file,
'pkg_path': self.pkg_path,
'enabled': self.enabled,
'priority': self.priority,
'event_handlers': {
event_name.__name__: handler.__name__
for event_name, handler in self.event_handlers.items()
},
'content_functions': [
{
'name': function.name,
'human_desc': function.human_desc,
'description': function.description,
'parameters': function.parameters,
'func': function.func.__name__,
}
for function in self.content_functions
],
'status': self.status.value,
}

View File

@@ -2,11 +2,11 @@ from __future__ import annotations
import typing import typing
import pydantic import pydantic.v1 as pydantic
import mirai
from ..core import entities as core_entities from ..core import entities as core_entities
from ..provider import entities as llm_entities from ..provider import entities as llm_entities
from ..platform.types import message as platform_message
class BaseEventModel(pydantic.BaseModel): class BaseEventModel(pydantic.BaseModel):
@@ -31,7 +31,7 @@ class PersonMessageReceived(BaseEventModel):
sender_id: int sender_id: int
"""发送者ID(QQ号)""" """发送者ID(QQ号)"""
message_chain: mirai.MessageChain message_chain: platform_message.MessageChain
class GroupMessageReceived(BaseEventModel): class GroupMessageReceived(BaseEventModel):
@@ -43,7 +43,7 @@ class GroupMessageReceived(BaseEventModel):
sender_id: int sender_id: int
message_chain: mirai.MessageChain message_chain: platform_message.MessageChain
class PersonNormalMessageReceived(BaseEventModel): class PersonNormalMessageReceived(BaseEventModel):

View File

@@ -3,7 +3,7 @@ from __future__ import annotations
import typing import typing
import abc import abc
from ..core import app from ..core import app, taskmgr
class PluginInstaller(metaclass=abc.ABCMeta): class PluginInstaller(metaclass=abc.ABCMeta):
@@ -21,6 +21,7 @@ class PluginInstaller(metaclass=abc.ABCMeta):
async def install_plugin( async def install_plugin(
self, self,
plugin_source: str, plugin_source: str,
task_context: taskmgr.TaskContext = taskmgr.TaskContext.placeholder(),
): ):
"""安装插件 """安装插件
""" """
@@ -30,6 +31,7 @@ class PluginInstaller(metaclass=abc.ABCMeta):
async def uninstall_plugin( async def uninstall_plugin(
self, self,
plugin_name: str, plugin_name: str,
task_context: taskmgr.TaskContext = taskmgr.TaskContext.placeholder(),
): ):
"""卸载插件 """卸载插件
""" """
@@ -40,6 +42,7 @@ class PluginInstaller(metaclass=abc.ABCMeta):
self, self,
plugin_name: str, plugin_name: str,
plugin_source: str=None, plugin_source: str=None,
task_context: taskmgr.TaskContext = taskmgr.TaskContext.placeholder(),
): ):
"""更新插件 """更新插件
""" """

View File

@@ -5,10 +5,14 @@ import os
import shutil import shutil
import zipfile import zipfile
import requests import aiohttp
import aiofiles
import aiofiles.os as aiofiles_os
import aioshutil
from .. import installer, errors from .. import installer, errors
from ...utils import pkgmgr from ...utils import pkgmgr
from ...core import taskmgr
class GitHubRepoInstaller(installer.PluginInstaller): class GitHubRepoInstaller(installer.PluginInstaller):
@@ -28,65 +32,65 @@ class GitHubRepoInstaller(installer.PluginInstaller):
return repo[0].split("/") return repo[0].split("/")
else: else:
return None return None
async def download_plugin_source_code(self, repo_url: str, target_path: str, task_context: taskmgr.TaskContext = taskmgr.TaskContext.placeholder()) -> str:
"""下载插件源码(全异步)"""
async def download_plugin_source_code(self, repo_url: str, target_path: str) -> str:
"""下载插件源码"""
# 检查源类型
# 提取 username/repo , 正则表达式 # 提取 username/repo , 正则表达式
repo = self.get_github_plugin_repo_label(repo_url) repo = self.get_github_plugin_repo_label(repo_url)
target_path += repo[1] target_path += repo[1]
if repo is not None: # github if repo is None:
self.ap.logger.debug("正在下载源码...")
zipball_url = f"https://api.github.com/repos/{'/'.join(repo)}/zipball/HEAD"
zip_resp = requests.get(
url=zipball_url, proxies=self.ap.proxy_mgr.get_forward_proxies(), stream=True
)
if zip_resp.status_code != 200:
raise Exception("下载源码失败: {}".format(zip_resp.text))
if os.path.exists("temp/" + target_path):
shutil.rmtree("temp/" + target_path)
if os.path.exists(target_path):
shutil.rmtree(target_path)
os.makedirs("temp/" + target_path)
with open("temp/" + target_path + "/source.zip", "wb") as f:
for chunk in zip_resp.iter_content(chunk_size=1024):
if chunk:
f.write(chunk)
self.ap.logger.debug("解压中...")
with zipfile.ZipFile("temp/" + target_path + "/source.zip", "r") as zip_ref:
zip_ref.extractall("temp/" + target_path)
os.remove("temp/" + target_path + "/source.zip")
# 目标是 username-repo-hash , 用正则表达式提取完整的文件夹名,复制到 plugins/repo
import glob
# 获取解压后的文件夹名
unzip_dir = glob.glob("temp/" + target_path + "/*")[0]
# 复制到 plugins/repo
shutil.copytree(unzip_dir, target_path + "/")
# 删除解压后的文件夹
shutil.rmtree(unzip_dir)
self.ap.logger.debug("源码下载完成。")
else:
raise errors.PluginInstallerError('仅支持GitHub仓库地址') raise errors.PluginInstallerError('仅支持GitHub仓库地址')
self.ap.logger.debug("正在下载源码...")
task_context.trace("下载源码...", "download-plugin-source-code")
zipball_url = f"https://api.github.com/repos/{'/'.join(repo)}/zipball/HEAD"
zip_resp: bytes = None
async with aiohttp.ClientSession(trust_env=True) as session:
async with session.get(
url=zipball_url,
timeout=aiohttp.ClientTimeout(total=300)
) as resp:
if resp.status != 200:
raise errors.PluginInstallerError(f"下载源码失败: {resp.text}")
zip_resp = await resp.read()
if await aiofiles_os.path.exists("temp/" + target_path):
await aioshutil.rmtree("temp/" + target_path)
if await aiofiles_os.path.exists(target_path):
await aioshutil.rmtree(target_path)
await aiofiles_os.makedirs("temp/" + target_path)
async with aiofiles.open("temp/" + target_path + "/source.zip", "wb") as f:
await f.write(zip_resp)
self.ap.logger.debug("解压中...")
task_context.trace("解压中...", "unzip-plugin-source-code")
with zipfile.ZipFile("temp/" + target_path + "/source.zip", "r") as zip_ref:
zip_ref.extractall("temp/" + target_path)
await aiofiles_os.remove("temp/" + target_path + "/source.zip")
import glob
unzip_dir = glob.glob("temp/" + target_path + "/*")[0]
await aioshutil.copytree(unzip_dir, target_path + "/")
await aioshutil.rmtree(unzip_dir)
self.ap.logger.debug("源码下载完成。")
return repo[1] return repo[1]
async def install_requirements(self, path: str): async def install_requirements(self, path: str):
if os.path.exists(path + "/requirements.txt"): if os.path.exists(path + "/requirements.txt"):
pkgmgr.install_requirements(path + "/requirements.txt") pkgmgr.install_requirements(path + "/requirements.txt")
@@ -94,13 +98,20 @@ class GitHubRepoInstaller(installer.PluginInstaller):
async def install_plugin( async def install_plugin(
self, self,
plugin_source: str, plugin_source: str,
task_context: taskmgr.TaskContext = taskmgr.TaskContext.placeholder(),
): ):
"""安装插件 """安装插件
""" """
repo_label = await self.download_plugin_source_code(plugin_source, "plugins/") task_context.trace("下载插件源码...", "install-plugin")
repo_label = await self.download_plugin_source_code(plugin_source, "plugins/", task_context)
task_context.trace("安装插件依赖...", "install-plugin")
await self.install_requirements("plugins/" + repo_label) await self.install_requirements("plugins/" + repo_label)
task_context.trace("完成.", "install-plugin")
await self.ap.plugin_mgr.setting.record_installed_plugin_source( await self.ap.plugin_mgr.setting.record_installed_plugin_source(
"plugins/"+repo_label+'/', plugin_source "plugins/"+repo_label+'/', plugin_source
) )
@@ -108,6 +119,7 @@ class GitHubRepoInstaller(installer.PluginInstaller):
async def uninstall_plugin( async def uninstall_plugin(
self, self,
plugin_name: str, plugin_name: str,
task_context: taskmgr.TaskContext = taskmgr.TaskContext.placeholder(),
): ):
"""卸载插件 """卸载插件
""" """
@@ -116,15 +128,20 @@ class GitHubRepoInstaller(installer.PluginInstaller):
if plugin_container is None: if plugin_container is None:
raise errors.PluginInstallerError('插件不存在或未成功加载') raise errors.PluginInstallerError('插件不存在或未成功加载')
else: else:
shutil.rmtree(plugin_container.pkg_path) task_context.trace("删除插件目录...", "uninstall-plugin")
await aioshutil.rmtree(plugin_container.pkg_path)
task_context.trace("完成, 重新加载以生效.", "uninstall-plugin")
async def update_plugin( async def update_plugin(
self, self,
plugin_name: str, plugin_name: str,
plugin_source: str=None, plugin_source: str=None,
task_context: taskmgr.TaskContext = taskmgr.TaskContext.placeholder(),
): ):
"""更新插件 """更新插件
""" """
task_context.trace("更新插件...", "update-plugin")
plugin_container = self.ap.plugin_mgr.get_plugin_by_name(plugin_name) plugin_container = self.ap.plugin_mgr.get_plugin_by_name(plugin_name)
if plugin_container is None: if plugin_container is None:
@@ -133,7 +150,9 @@ class GitHubRepoInstaller(installer.PluginInstaller):
if plugin_container.plugin_source: if plugin_container.plugin_source:
plugin_source = plugin_container.plugin_source plugin_source = plugin_container.plugin_source
await self.install_plugin(plugin_source) task_context.trace("转交安装任务.", "update-plugin")
await self.install_plugin(plugin_source, task_context)
else: else:
raise errors.PluginInstallerError('插件无源码信息,无法更新') raise errors.PluginInstallerError('插件无源码信息,无法更新')

View File

@@ -13,13 +13,16 @@ class PluginLoader(metaclass=abc.ABCMeta):
ap: app.Application ap: app.Application
plugins: list[context.RuntimeContainer]
def __init__(self, ap: app.Application): def __init__(self, ap: app.Application):
self.ap = ap self.ap = ap
self.plugins = []
async def initialize(self): async def initialize(self):
pass pass
@abc.abstractmethod @abc.abstractmethod
async def load_plugins(self) -> list[context.RuntimeContainer]: async def load_plugins(self):
pass pass

View File

@@ -5,7 +5,7 @@ import pkgutil
import importlib import importlib
import traceback import traceback
from .. import loader, events, context, models, host from .. import loader, events, context, models
from ...core import entities as core_entities from ...core import entities as core_entities
from ...provider.tools import entities as tools_entities from ...provider.tools import entities as tools_entities
from ...utils import funcschema from ...utils import funcschema
@@ -20,7 +20,14 @@ class PluginLoader(loader.PluginLoader):
_current_container: context.RuntimeContainer = None _current_container: context.RuntimeContainer = None
containers: list[context.RuntimeContainer] = [] plugins: list[context.RuntimeContainer] = []
def __init__(self, ap):
self.ap = ap
self.plugins = []
self._current_pkg_path = ''
self._current_module_path = ''
self._current_container = None
async def initialize(self): async def initialize(self):
"""初始化""" """初始化"""
@@ -77,8 +84,10 @@ class PluginLoader(loader.PluginLoader):
} }
# 把 ctx.event 所有的属性都放到 args 里 # 把 ctx.event 所有的属性都放到 args 里
for k, v in ctx.event.dict().items(): # for k, v in ctx.event.dict().items():
args[k] = v # args[k] = v
for attr_name in ctx.event.__dict__.keys():
args[attr_name] = getattr(ctx.event, attr_name)
func(plugin, **args) func(plugin, **args)
@@ -113,7 +122,6 @@ class PluginLoader(loader.PluginLoader):
name=function_name, name=function_name,
human_desc='', human_desc='',
description=function_schema['description'], description=function_schema['description'],
enable=True,
parameters=function_schema['parameters'], parameters=function_schema['parameters'],
func=handler, func=handler,
) )
@@ -153,7 +161,6 @@ class PluginLoader(loader.PluginLoader):
name=function_name, name=function_name,
human_desc='', human_desc='',
description=function_schema['description'], description=function_schema['description'],
enable=True,
parameters=function_schema['parameters'], parameters=function_schema['parameters'],
func=func, func=func,
) )
@@ -189,15 +196,13 @@ class PluginLoader(loader.PluginLoader):
importlib.import_module(module.__name__ + "." + item.name) importlib.import_module(module.__name__ + "." + item.name)
if self._current_container is not None: if self._current_container is not None:
self.containers.append(self._current_container) self.plugins.append(self._current_container)
self.ap.logger.debug(f'插件 {self._current_container} 已加载') self.ap.logger.debug(f'插件 {self._current_container} 已加载')
except: except:
self.ap.logger.error(f'加载插件模块 {prefix + item.name} 时发生错误') self.ap.logger.error(f'加载插件模块 {prefix + item.name} 时发生错误')
traceback.print_exc() traceback.print_exc()
async def load_plugins(self) -> list[context.RuntimeContainer]: async def load_plugins(self):
"""加载插件 """加载插件
""" """
await self._walk_plugin_path(__import__("plugins", fromlist=[""])) await self._walk_plugin_path(__import__("plugins", fromlist=[""]))
return self.containers

View File

@@ -3,7 +3,7 @@ from __future__ import annotations
import typing import typing
import traceback import traceback
from ..core import app from ..core import app, taskmgr
from . import context, loader, events, installer, setting, models from . import context, loader, events, installer, setting, models
from .loaders import classic from .loaders import classic
from .installers import github from .installers import github
@@ -22,7 +22,22 @@ class PluginManager:
api_host: context.APIHost api_host: context.APIHost
plugins: list[context.RuntimeContainer] def plugins(
self,
enabled: bool=None,
status: context.RuntimeContainerStatus=None,
) -> list[context.RuntimeContainer]:
"""获取插件列表
"""
plugins = self.loader.plugins
if enabled is not None:
plugins = [plugin for plugin in plugins if plugin.enabled == enabled]
if status is not None:
plugins = [plugin for plugin in plugins if plugin.status == status]
return plugins
def __init__(self, ap: app.Application): def __init__(self, ap: app.Application):
self.ap = ap self.ap = ap
@@ -30,7 +45,6 @@ class PluginManager:
self.installer = github.GitHubRepoInstaller(ap) self.installer = github.GitHubRepoInstaller(ap)
self.setting = setting.SettingManager(ap) self.setting = setting.SettingManager(ap)
self.api_host = context.APIHost(ap) self.api_host = context.APIHost(ap)
self.plugins = []
async def initialize(self): async def initialize(self):
await self.loader.initialize() await self.loader.initialize()
@@ -41,32 +55,66 @@ class PluginManager:
setattr(models, 'require_ver', self.api_host.require_ver) setattr(models, 'require_ver', self.api_host.require_ver)
async def load_plugins(self): async def load_plugins(self):
self.plugins = await self.loader.load_plugins() await self.loader.load_plugins()
await self.setting.sync_setting(self.plugins) await self.setting.sync_setting(self.loader.plugins)
# 按优先级倒序 # 按优先级倒序
self.plugins.sort(key=lambda x: x.priority, reverse=True) self.loader.plugins.sort(key=lambda x: x.priority, reverse=True)
self.ap.logger.debug(f'优先级排序后的插件列表 {self.loader.plugins}')
async def initialize_plugin(self, plugin: context.RuntimeContainer):
self.ap.logger.debug(f'初始化插件 {plugin.plugin_name}')
plugin.plugin_inst = plugin.plugin_class(self.api_host)
plugin.plugin_inst.ap = self.ap
plugin.plugin_inst.host = self.api_host
await plugin.plugin_inst.initialize()
plugin.status = context.RuntimeContainerStatus.INITIALIZED
async def initialize_plugins(self): async def initialize_plugins(self):
for plugin in self.plugins: for plugin in self.plugins():
if not plugin.enabled:
self.ap.logger.debug(f'插件 {plugin.plugin_name} 未启用,跳过初始化')
continue
try: try:
plugin.plugin_inst = plugin.plugin_class(self.api_host) await self.initialize_plugin(plugin)
plugin.plugin_inst.ap = self.ap
plugin.plugin_inst.host = self.api_host
await plugin.plugin_inst.initialize()
except Exception as e: except Exception as e:
self.ap.logger.error(f'插件 {plugin.plugin_name} 初始化失败: {e}') self.ap.logger.error(f'插件 {plugin.plugin_name} 初始化失败: {e}')
self.ap.logger.exception(e) self.ap.logger.exception(e)
continue continue
async def destroy_plugin(self, plugin: context.RuntimeContainer):
if plugin.status != context.RuntimeContainerStatus.INITIALIZED:
return
self.ap.logger.debug(f'释放插件 {plugin.plugin_name}')
plugin.plugin_inst.__del__()
await plugin.plugin_inst.destroy()
plugin.plugin_inst = None
plugin.status = context.RuntimeContainerStatus.MOUNTED
async def destroy_plugins(self):
for plugin in self.plugins():
if plugin.status != context.RuntimeContainerStatus.INITIALIZED:
self.ap.logger.debug(f'插件 {plugin.plugin_name} 未初始化,跳过释放')
continue
try:
await self.destroy_plugin(plugin)
except Exception as e:
self.ap.logger.error(f'插件 {plugin.plugin_name} 释放失败: {e}')
self.ap.logger.exception(e)
continue
async def install_plugin( async def install_plugin(
self, self,
plugin_source: str, plugin_source: str,
task_context: taskmgr.TaskContext = taskmgr.TaskContext.placeholder(),
): ):
"""安装插件 """安装插件
""" """
await self.installer.install_plugin(plugin_source) await self.installer.install_plugin(plugin_source, task_context)
await self.ap.ctr_mgr.plugin.post_install_record( await self.ap.ctr_mgr.plugin.post_install_record(
{ {
@@ -77,16 +125,25 @@ class PluginManager:
} }
) )
task_context.trace('重载插件..', 'reload-plugin')
await self.ap.reload(scope='plugin')
async def uninstall_plugin( async def uninstall_plugin(
self, self,
plugin_name: str, plugin_name: str,
task_context: taskmgr.TaskContext = taskmgr.TaskContext.placeholder(),
): ):
"""卸载插件 """卸载插件
""" """
await self.installer.uninstall_plugin(plugin_name)
plugin_container = self.get_plugin_by_name(plugin_name) plugin_container = self.get_plugin_by_name(plugin_name)
if plugin_container is None:
raise ValueError(f'插件 {plugin_name} 不存在')
await self.destroy_plugin(plugin_container)
await self.installer.uninstall_plugin(plugin_name, task_context)
await self.ap.ctr_mgr.plugin.post_remove_record( await self.ap.ctr_mgr.plugin.post_remove_record(
{ {
"name": plugin_name, "name": plugin_name,
@@ -96,14 +153,18 @@ class PluginManager:
} }
) )
task_context.trace('重载插件..', 'reload-plugin')
await self.ap.reload(scope='plugin')
async def update_plugin( async def update_plugin(
self, self,
plugin_name: str, plugin_name: str,
plugin_source: str=None, plugin_source: str=None,
task_context: taskmgr.TaskContext = taskmgr.TaskContext.placeholder(),
): ):
"""更新插件 """更新插件
""" """
await self.installer.update_plugin(plugin_name, plugin_source) await self.installer.update_plugin(plugin_name, plugin_source, task_context)
plugin_container = self.get_plugin_by_name(plugin_name) plugin_container = self.get_plugin_by_name(plugin_name)
@@ -118,10 +179,13 @@ class PluginManager:
new_version="HEAD" new_version="HEAD"
) )
task_context.trace('重载插件..', 'reload-plugin')
await self.ap.reload(scope='plugin')
def get_plugin_by_name(self, plugin_name: str) -> context.RuntimeContainer: def get_plugin_by_name(self, plugin_name: str) -> context.RuntimeContainer:
"""通过插件名获取插件 """通过插件名获取插件
""" """
for plugin in self.plugins: for plugin in self.plugins():
if plugin.plugin_name == plugin_name: if plugin.plugin_name == plugin_name:
return plugin return plugin
return None return None
@@ -137,30 +201,32 @@ class PluginManager:
emitted_plugins: list[context.RuntimeContainer] = [] emitted_plugins: list[context.RuntimeContainer] = []
for plugin in self.plugins: for plugin in self.plugins(
if plugin.enabled: enabled=True,
if event.__class__ in plugin.event_handlers: status=context.RuntimeContainerStatus.INITIALIZED
self.ap.logger.debug(f'插件 {plugin.plugin_name} 触发事件 {event.__class__.__name__}') ):
if event.__class__ in plugin.event_handlers:
is_prevented_default_before_call = ctx.is_prevented_default() self.ap.logger.debug(f'插件 {plugin.plugin_name} 处理事件 {event.__class__.__name__}')
is_prevented_default_before_call = ctx.is_prevented_default()
try: try:
await plugin.event_handlers[event.__class__]( await plugin.event_handlers[event.__class__](
plugin.plugin_inst, plugin.plugin_inst,
ctx ctx
) )
except Exception as e: except Exception as e:
self.ap.logger.error(f'插件 {plugin.plugin_name} 触发事件 {event.__class__.__name__} 时发生错误: {e}') self.ap.logger.error(f'插件 {plugin.plugin_name} 处理事件 {event.__class__.__name__} 时发生错误: {e}')
self.ap.logger.debug(f"Traceback: {traceback.format_exc()}") self.ap.logger.debug(f"Traceback: {traceback.format_exc()}")
emitted_plugins.append(plugin) emitted_plugins.append(plugin)
if not is_prevented_default_before_call and ctx.is_prevented_default(): if not is_prevented_default_before_call and ctx.is_prevented_default():
self.ap.logger.debug(f'插件 {plugin.plugin_name} 阻止了默认行为执行') self.ap.logger.debug(f'插件 {plugin.plugin_name} 阻止了默认行为执行')
if ctx.is_prevented_postorder(): if ctx.is_prevented_postorder():
self.ap.logger.debug(f'插件 {plugin.plugin_name} 阻止了后序插件的执行') self.ap.logger.debug(f'插件 {plugin.plugin_name} 阻止了后序插件的执行')
break break
for key in ctx.__return_value__.keys(): for key in ctx.__return_value__.keys():
if hasattr(ctx.event, key): if hasattr(ctx.event, key):
@@ -184,3 +250,41 @@ class PluginManager:
) )
return ctx return ctx
async def update_plugin_switch(self, plugin_name: str, new_status: bool):
if self.get_plugin_by_name(plugin_name) is not None:
for plugin in self.plugins():
if plugin.plugin_name == plugin_name:
if plugin.enabled == new_status:
return False
# 初始化/释放插件
if new_status:
await self.initialize_plugin(plugin)
else:
await self.destroy_plugin(plugin)
plugin.enabled = new_status
await self.setting.dump_container_setting(self.loader.plugins)
break
return True
else:
return False
async def reorder_plugins(self, plugins: list[dict]):
for plugin in plugins:
plugin_name = plugin.get('name')
plugin_priority = plugin.get('priority')
for plugin in self.loader.plugins:
if plugin.plugin_name == plugin_name:
plugin.priority = plugin_priority
break
self.loader.plugins.sort(key=lambda x: x.priority, reverse=True)
await self.setting.dump_container_setting(self.loader.plugins)

View File

@@ -45,6 +45,7 @@ class SettingManager:
for plugin_container in plugin_containers: for plugin_container in plugin_containers:
if plugin_container.plugin_name == value['name']: if plugin_container.plugin_name == value['name']:
plugin_container.set_from_setting_dict(value) plugin_container.set_from_setting_dict(value)
break
self.settings.data = { self.settings.data = {
'plugins': [ 'plugins': [

View File

@@ -2,9 +2,10 @@ from __future__ import annotations
import typing import typing
import enum import enum
import pydantic import pydantic.v1 as pydantic
import mirai
from ..platform.types import message as platform_message
class FunctionCall(pydantic.BaseModel): class FunctionCall(pydantic.BaseModel):
@@ -73,14 +74,14 @@ class Message(pydantic.BaseModel):
def readable_str(self) -> str: def readable_str(self) -> str:
if self.content is not None: if self.content is not None:
return str(self.role) + ": " + str(self.get_content_mirai_message_chain()) return str(self.role) + ": " + str(self.get_content_platform_message_chain())
elif self.tool_calls is not None: elif self.tool_calls is not None:
return f'调用工具: {self.tool_calls[0].id}' return f'调用工具: {self.tool_calls[0].id}'
else: else:
return '未知消息' return '未知消息'
def get_content_mirai_message_chain(self, prefix_text: str="") -> mirai.MessageChain | None: def get_content_platform_message_chain(self, prefix_text: str="") -> platform_message.MessageChain | None:
"""将内容转换为 Mirai MessageChain 对象 """将内容转换为平台消息 MessageChain 对象
Args: Args:
prefix_text (str): 首个文字组件的前缀文本 prefix_text (str): 首个文字组件的前缀文本
@@ -89,15 +90,15 @@ class Message(pydantic.BaseModel):
if self.content is None: if self.content is None:
return None return None
elif isinstance(self.content, str): elif isinstance(self.content, str):
return mirai.MessageChain([mirai.Plain(prefix_text+self.content)]) return platform_message.MessageChain([platform_message.Plain(prefix_text+self.content)])
elif isinstance(self.content, list): elif isinstance(self.content, list):
mc = [] mc = []
for ce in self.content: for ce in self.content:
if ce.type == 'text': if ce.type == 'text':
mc.append(mirai.Plain(ce.text)) mc.append(platform_message.Plain(ce.text))
elif ce.type == 'image_url': elif ce.type == 'image_url':
if ce.image_url.url.startswith("http"): if ce.image_url.url.startswith("http"):
mc.append(mirai.Image(url=ce.image_url.url)) mc.append(platform_message.Image(url=ce.image_url.url))
else: # base64 else: # base64
b64_str = ce.image_url.url b64_str = ce.image_url.url
@@ -105,15 +106,15 @@ class Message(pydantic.BaseModel):
if b64_str.startswith("data:"): if b64_str.startswith("data:"):
b64_str = b64_str.split(",")[1] b64_str = b64_str.split(",")[1]
mc.append(mirai.Image(base64=b64_str)) mc.append(platform_message.Image(base64=b64_str))
# 找第一个文字组件 # 找第一个文字组件
if prefix_text: if prefix_text:
for i, c in enumerate(mc): for i, c in enumerate(mc):
if isinstance(c, mirai.Plain): if isinstance(c, platform_message.Plain):
mc[i] = mirai.Plain(prefix_text+c.text) mc[i] = platform_message.Plain(prefix_text+c.text)
break break
else: else:
mc.insert(0, mirai.Plain(prefix_text)) mc.insert(0, platform_message.Plain(prefix_text))
return mirai.MessageChain(mc) return platform_message.MessageChain(mc)

View File

@@ -2,9 +2,9 @@ from __future__ import annotations
import typing import typing
import pydantic import pydantic.v1 as pydantic
from . import api from . import requester
from . import token from . import token
@@ -17,7 +17,7 @@ class LLMModelInfo(pydantic.BaseModel):
token_mgr: token.TokenManager token_mgr: token.TokenManager
requester: api.LLMAPIRequester requester: requester.LLMAPIRequester
tool_call_supported: typing.Optional[bool] = False tool_call_supported: typing.Optional[bool] = False

View File

@@ -2,11 +2,11 @@ from __future__ import annotations
import aiohttp import aiohttp
from . import entities from . import entities, requester
from ...core import app from ...core import app
from . import token, api from . import token
from .apis import chatcmpl, anthropicmsgs, moonshotchatcmpl, deepseekchatcmpl, ollamachat from .requesters import chatcmpl, anthropicmsgs, moonshotchatcmpl, deepseekchatcmpl, ollamachat, giteeaichatcmpl
FETCH_MODEL_LIST_URL = "https://api.qchatgpt.rockchin.top/api/v2/fetch/model_list" FETCH_MODEL_LIST_URL = "https://api.qchatgpt.rockchin.top/api/v2/fetch/model_list"
@@ -18,7 +18,7 @@ class ModelManager:
model_list: list[entities.LLMModelInfo] model_list: list[entities.LLMModelInfo]
requesters: dict[str, api.LLMAPIRequester] requesters: dict[str, requester.LLMAPIRequester]
token_mgrs: dict[str, token.TokenManager] token_mgrs: dict[str, token.TokenManager]
@@ -42,7 +42,7 @@ class ModelManager:
for k, v in self.ap.provider_cfg.data['keys'].items(): for k, v in self.ap.provider_cfg.data['keys'].items():
self.token_mgrs[k] = token.TokenManager(k, v) self.token_mgrs[k] = token.TokenManager(k, v)
for api_cls in api.preregistered_requesters: for api_cls in requester.preregistered_requesters:
api_inst = api_cls(self.ap) api_inst = api_cls(self.ap)
await api_inst.initialize() await api_inst.initialize()
self.requesters[api_inst.name] = api_inst self.requesters[api_inst.name] = api_inst
@@ -94,7 +94,7 @@ class ModelManager:
model_name = model.get('model_name', default_model_info.model_name) model_name = model.get('model_name', default_model_info.model_name)
token_mgr = self.token_mgrs[model['token_mgr']] if 'token_mgr' in model else default_model_info.token_mgr token_mgr = self.token_mgrs[model['token_mgr']] if 'token_mgr' in model else default_model_info.token_mgr
requester = self.requesters[model['requester']] if 'requester' in model else default_model_info.requester req = self.requesters[model['requester']] if 'requester' in model else default_model_info.requester
tool_call_supported = model.get('tool_call_supported', default_model_info.tool_call_supported) tool_call_supported = model.get('tool_call_supported', default_model_info.tool_call_supported)
vision_supported = model.get('vision_supported', default_model_info.vision_supported) vision_supported = model.get('vision_supported', default_model_info.vision_supported)
@@ -102,7 +102,7 @@ class ModelManager:
name=model['name'], name=model['name'],
model_name=model_name, model_name=model_name,
token_mgr=token_mgr, token_mgr=token_mgr,
requester=requester, requester=req,
tool_call_supported=tool_call_supported, tool_call_supported=tool_call_supported,
vision_supported=vision_supported vision_supported=vision_supported
) )

Some files were not shown because too many files have changed in this diff Show More