Compare commits

..

5 Commits

Author SHA1 Message Date
Junyan Qin
89b25b8985 chore: release v4.2.2 2025-08-29 17:01:26 +08:00
devin-ai-integration[bot]
46b4482a7d feat: add GitHub star count component to sidebar (#1636)
* feat: add GitHub star count component to sidebar

- Add GitHub star component to sidebar bottom section
- Fetch star count from space.langbot.app API
- Display star count with proper internationalization
- Open GitHub repository in new tab when clicked
- Follow existing sidebar styling patterns

Co-Authored-By: Rock <rockchinq@gmail.com>

* perf: ui

* chore: remove githubStars text

---------

Co-authored-by: Devin AI <158243242+devin-ai-integration[bot]@users.noreply.github.com>
Co-authored-by: Rock <rockchinq@gmail.com>
2025-08-28 21:04:36 +08:00
Bruce
d9fa1cbb06 perf: add cmd enable config & fix announce request timeout & fix send card with disconnect ai platform (#1633)
* add cmd config && fix bugs

* perf: use `get`

* update bansess fix block match rule

* perf: comment for access-control session str

---------

Co-authored-by: Junyan Qin <rockchinq@gmail.com>
2025-08-28 12:59:50 +08:00
Bruce
8858f432b5 fix dingtalk message sender id & update dingtalk streaming card without content (#1630) 2025-08-27 18:09:30 +08:00
Junyan Qin
8f5ec48522 doc: update shengsuanyun comment 2025-08-26 16:00:48 +08:00
197 changed files with 6700 additions and 5633 deletions

View File

@@ -107,9 +107,9 @@ docker compose up -d
| [Anthropic](https://www.anthropic.com/) | ✅ | |
| [xAI](https://x.ai/) | ✅ | |
| [智谱AI](https://open.bigmodel.cn/) | ✅ | |
| [胜算云](https://www.shengsuanyun.com/?from=CH_KYIPP758) | ✅ | 全球大模型都可调用(友情推荐) |
| [优云智算](https://www.compshare.cn/?ytag=GPU_YY-gh_langbot) | ✅ | 大模型和 GPU 资源平台 |
| [PPIO](https://ppinfra.com/user/register?invited_by=QJKFYD&utm_source=github_langbot) | ✅ | 大模型和 GPU 资源平台 |
| [胜算云](https://www.shengsuanyun.com/?from=CH_KYIPP758) | ✅ | 大模型和 GPU 资源平台 |
| [302.AI](https://share.302.ai/SuTG99) | ✅ | 大模型聚合平台 |
| [Google Gemini](https://aistudio.google.com/prompts/new_chat) | ✅ | |
| [Dify](https://dify.ai) | ✅ | LLMOps 平台 |

View File

@@ -9,6 +9,7 @@ spec:
components:
ComponentTemplate:
fromFiles:
- pkg/platform/adapter.yaml
- pkg/provider/modelmgr/requester.yaml
MessagePlatformAdapter:
fromDirs:

View File

@@ -1,21 +1,6 @@
version: "3"
services:
langbot_plugin_runtime:
image: rockchin/langbot:latest
container_name: langbot_plugin_runtime
volumes:
- ./data/plugins:/app/data/plugins
ports:
- 5401:5401
restart: on-failure
environment:
- TZ=Asia/Shanghai
command: ["uv", "run", "-m", "langbot_plugin.cli.__init__", "rt"]
networks:
- langbot_network
langbot:
image: rockchin/langbot:latest
container_name: langbot
@@ -26,11 +11,6 @@ services:
environment:
- TZ=Asia/Shanghai
ports:
- 5300:5300 # For web ui
- 2280-2290:2280-2290 # For platform webhook
networks:
- langbot_network
networks:
langbot_network:
driver: bridge
- 5300:5300 # 供 WebUI 使用
- 2280-2290:2280-2290 # 供消息平台适配器方向连接
# 根据具体环境配置网络

View File

@@ -3,7 +3,7 @@ from quart import request
import httpx
from quart import Quart
from typing import Callable, Dict, Any
import langbot_plugin.api.entities.builtin.platform.events as platform_events
from pkg.platform.types import events as platform_events
from .qqofficialevent import QQOfficialEvent
import json
import traceback

View File

@@ -4,7 +4,7 @@ from quart import Quart, jsonify, request
from slack_sdk.web.async_client import AsyncWebClient
from .slackevent import SlackEvent
from typing import Callable
import langbot_plugin.api.entities.builtin.platform.events as platform_events
from pkg.platform.types import events as platform_events
class SlackClient:

View File

@@ -8,7 +8,7 @@ from quart import Quart
import xml.etree.ElementTree as ET
from typing import Callable, Dict, Any
from .wecomevent import WecomEvent
import langbot_plugin.api.entities.builtin.platform.message as platform_message
from pkg.platform.types import message as platform_message
import aiofiles

View File

@@ -8,7 +8,7 @@ from quart import Quart
import xml.etree.ElementTree as ET
from typing import Callable
from .wecomcsevent import WecomCSEvent
import langbot_plugin.api.entities.builtin.platform.message as platform_message
from pkg.platform.types import message as platform_message
import aiofiles

16
main.py
View File

@@ -19,14 +19,8 @@ asciiart = r"""
async def main_entry(loop: asyncio.AbstractEventLoop):
parser = argparse.ArgumentParser(description='LangBot')
parser.add_argument('--skip-plugin-deps-check', action='store_true', help='跳过插件依赖项检查', default=False)
parser.add_argument('--standalone-runtime', action='store_true', help='使用独立插件运行时', default=False)
args = parser.parse_args()
if args.standalone_runtime:
from pkg.utils import platform
platform.standalone_runtime = True
print(asciiart)
import sys
@@ -53,13 +47,13 @@ async def main_entry(loop: asyncio.AbstractEventLoop):
if not args.skip_plugin_deps_check:
await deps.precheck_plugin_deps()
# # 检查pydantic版本如果没有 pydantic.v1则把 pydantic 映射为 v1
# import pydantic.version
# 检查pydantic版本如果没有 pydantic.v1则把 pydantic 映射为 v1
import pydantic.version
# if pydantic.version.VERSION < '2.0':
# import pydantic
if pydantic.version.VERSION < '2.0':
import pydantic
# sys.modules['pydantic.v1'] = pydantic
sys.modules['pydantic.v1'] = pydantic
# 检查配置文件

View File

@@ -44,9 +44,9 @@ class WebChatDebugRouterGroup(group.RouterGroup):
'Content-Type': 'text/event-stream',
'Transfer-Encoding': 'chunked',
'Cache-Control': 'no-cache',
'Connection': 'keep-alive',
'Connection': 'keep-alive'
}
return quart.Response(stream_generator(generator), mimetype='text/event-stream', headers=headers)
return quart.Response(stream_generator(generator), mimetype='text/event-stream',headers=headers)
else: # non-stream
result = None

View File

@@ -1,11 +1,10 @@
from __future__ import annotations
import base64
import quart
from .....core import taskmgr
from .. import group
from langbot_plugin.runtime.plugin.mgr import PluginInstallSource
@group.group_class('plugins', '/api/v1/plugins')
@@ -13,22 +12,35 @@ class PluginsRouterGroup(group.RouterGroup):
async def initialize(self) -> None:
@self.route('', methods=['GET'], auth_type=group.AuthType.USER_TOKEN)
async def _() -> str:
plugins = await self.ap.plugin_connector.list_plugins()
plugins = self.ap.plugin_mgr.plugins()
return self.success(data={'plugins': plugins})
plugins_data = [plugin.model_dump() for plugin in plugins]
return self.success(data={'plugins': plugins_data})
@self.route(
'/<author>/<plugin_name>/upgrade',
'/<author>/<plugin_name>/toggle',
methods=['PUT'],
auth_type=group.AuthType.USER_TOKEN,
)
async def _(author: str, plugin_name: str) -> str:
data = await quart.request.json
target_enabled = data.get('target_enabled')
await self.ap.plugin_mgr.update_plugin_switch(plugin_name, target_enabled)
return self.success()
@self.route(
'/<author>/<plugin_name>/update',
methods=['POST'],
auth_type=group.AuthType.USER_TOKEN,
)
async def _(author: str, plugin_name: str) -> str:
ctx = taskmgr.TaskContext.new()
wrapper = self.ap.task_mgr.create_user_task(
self.ap.plugin_connector.upgrade_plugin(author, plugin_name, task_context=ctx),
self.ap.plugin_mgr.update_plugin(plugin_name, task_context=ctx),
kind='plugin-operation',
name=f'plugin-upgrade-{plugin_name}',
label=f'Upgrading plugin {plugin_name}',
name=f'plugin-update-{plugin_name}',
label=f'Updating plugin {plugin_name}',
context=ctx,
)
return self.success(data={'task_id': wrapper.id})
@@ -40,14 +52,14 @@ class PluginsRouterGroup(group.RouterGroup):
)
async def _(author: str, plugin_name: str) -> str:
if quart.request.method == 'GET':
plugin = await self.ap.plugin_connector.get_plugin_info(author, plugin_name)
plugin = self.ap.plugin_mgr.get_plugin(author, plugin_name)
if plugin is None:
return self.http_status(404, -1, 'plugin not found')
return self.success(data={'plugin': plugin})
return self.success(data={'plugin': plugin.model_dump()})
elif quart.request.method == 'DELETE':
ctx = taskmgr.TaskContext.new()
wrapper = self.ap.task_mgr.create_user_task(
self.ap.plugin_connector.delete_plugin(author, plugin_name, task_context=ctx),
self.ap.plugin_mgr.uninstall_plugin(plugin_name, task_context=ctx),
kind='plugin-operation',
name=f'plugin-remove-{plugin_name}',
label=f'Removing plugin {plugin_name}',
@@ -62,19 +74,24 @@ class PluginsRouterGroup(group.RouterGroup):
auth_type=group.AuthType.USER_TOKEN,
)
async def _(author: str, plugin_name: str) -> quart.Response:
plugin = await self.ap.plugin_connector.get_plugin_info(author, plugin_name)
plugin = self.ap.plugin_mgr.get_plugin(author, plugin_name)
if plugin is None:
return self.http_status(404, -1, 'plugin not found')
if quart.request.method == 'GET':
return self.success(data={'config': plugin['plugin_config']})
return self.success(data={'config': plugin.plugin_config})
elif quart.request.method == 'PUT':
data = await quart.request.json
await self.ap.plugin_connector.set_plugin_config(author, plugin_name, data)
await self.ap.plugin_mgr.set_plugin_config(plugin, data)
return self.success(data={})
@self.route('/reorder', methods=['PUT'], auth_type=group.AuthType.USER_TOKEN)
async def _() -> str:
data = await quart.request.json
await self.ap.plugin_mgr.reorder_plugins(data.get('plugins'))
return self.success()
@self.route('/install/github', methods=['POST'], auth_type=group.AuthType.USER_TOKEN)
async def _() -> str:
data = await quart.request.json
@@ -85,47 +102,7 @@ class PluginsRouterGroup(group.RouterGroup):
self.ap.plugin_mgr.install_plugin(data['source'], task_context=ctx),
kind='plugin-operation',
name='plugin-install-github',
label=f'Installing plugin from github ...{short_source_str}',
context=ctx,
)
return self.success(data={'task_id': wrapper.id})
@self.route('/install/marketplace', methods=['POST'], auth_type=group.AuthType.USER_TOKEN)
async def _() -> str:
data = await quart.request.json
ctx = taskmgr.TaskContext.new()
wrapper = self.ap.task_mgr.create_user_task(
self.ap.plugin_connector.install_plugin(PluginInstallSource.MARKETPLACE, data, task_context=ctx),
kind='plugin-operation',
name='plugin-install-marketplace',
label=f'Installing plugin from marketplace ...{data}',
context=ctx,
)
return self.success(data={'task_id': wrapper.id})
@self.route('/install/local', methods=['POST'], auth_type=group.AuthType.USER_TOKEN)
async def _() -> str:
file = (await quart.request.files).get('file')
if file is None:
return self.http_status(400, -1, 'file is required')
file_bytes = file.read()
file_base64 = base64.b64encode(file_bytes).decode('utf-8')
data = {
'plugin_file': file_base64,
}
ctx = taskmgr.TaskContext.new()
wrapper = self.ap.task_mgr.create_user_task(
self.ap.plugin_connector.install_plugin(PluginInstallSource.LOCAL, data, task_context=ctx),
kind='plugin-operation',
name='plugin-install-local',
label=f'Installing plugin from local ...{file.filename}',
label=f'Installing plugin ...{short_source_str}',
context=ctx,
)

View File

@@ -14,11 +14,6 @@ class SystemRouterGroup(group.RouterGroup):
'version': constants.semantic_version,
'debug': constants.debug_mode,
'enabled_platform_count': len(self.ap.platform_mgr.get_running_adapters()),
'cloud_service_url': (
self.ap.instance_config.data['plugin']['cloud_service_url']
if 'cloud_service_url' in self.ap.instance_config.data['plugin']
else 'https://space.langbot.app'
),
}
)
@@ -40,7 +35,16 @@ class SystemRouterGroup(group.RouterGroup):
return self.success(data=task.to_dict())
@self.route('/debug/exec', methods=['POST'], auth_type=group.AuthType.USER_TOKEN)
@self.route('/reload', methods=['POST'], auth_type=group.AuthType.USER_TOKEN)
async def _() -> str:
json_data = await quart.request.json
scope = json_data.get('scope')
await self.ap.reload(scope=scope)
return self.success()
@self.route('/_debug/exec', methods=['POST'], auth_type=group.AuthType.USER_TOKEN)
async def _() -> str:
if not constants.debug_mode:
return self.http_status(403, 403, 'Forbidden')
@@ -50,39 +54,3 @@ class SystemRouterGroup(group.RouterGroup):
ap = self.ap
return self.success(data=exec(py_code, {'ap': ap}))
@self.route('/debug/tools/call', methods=['POST'], auth_type=group.AuthType.USER_TOKEN)
async def _() -> str:
if not constants.debug_mode:
return self.http_status(403, 403, 'Forbidden')
data = await quart.request.json
return self.success(
data=await self.ap.tool_mgr.execute_func_call(data['tool_name'], data['tool_parameters'])
)
@self.route(
'/debug/plugin/action',
methods=['POST'],
auth_type=group.AuthType.USER_TOKEN,
)
async def _() -> str:
if not constants.debug_mode:
return self.http_status(403, 403, 'Forbidden')
data = await quart.request.json
class AnoymousAction:
value = 'anonymous_action'
def __init__(self, value: str):
self.value = value
resp = await self.ap.plugin_connector.handler.call_action(
AnoymousAction(data['action']),
data['data'],
timeout=data.get('timeout', 10),
)
return self.success(data=resp)

View File

@@ -71,15 +71,15 @@ class UserRouterGroup(group.RouterGroup):
@self.route('/change-password', methods=['POST'], auth_type=group.AuthType.USER_TOKEN)
async def _(user_email: str) -> str:
json_data = await quart.request.json
current_password = json_data['current_password']
new_password = json_data['new_password']
try:
await self.ap.user_service.change_password(user_email, current_password, new_password)
except argon2.exceptions.VerifyMismatchError:
return self.http_status(400, -1, 'Current password is incorrect')
except ValueError as e:
return self.http_status(400, -1, str(e))
return self.success(data={'user': user_email})

View File

@@ -17,20 +17,16 @@ class BotService:
def __init__(self, ap: app.Application) -> None:
self.ap = ap
async def get_bots(self, include_secret: bool = True) -> list[dict]:
"""获取所有机器人"""
async def get_bots(self) -> list[dict]:
"""Get all bots"""
result = await self.ap.persistence_mgr.execute_async(sqlalchemy.select(persistence_bot.Bot))
bots = result.all()
masked_columns = []
if not include_secret:
masked_columns = ['adapter_config']
return [self.ap.persistence_mgr.serialize_model(persistence_bot.Bot, bot) for bot in bots]
return [self.ap.persistence_mgr.serialize_model(persistence_bot.Bot, bot, masked_columns) for bot in bots]
async def get_bot(self, bot_uuid: str, include_secret: bool = True) -> dict | None:
"""获取机器人"""
async def get_bot(self, bot_uuid: str) -> dict | None:
"""Get bot"""
result = await self.ap.persistence_mgr.execute_async(
sqlalchemy.select(persistence_bot.Bot).where(persistence_bot.Bot.uuid == bot_uuid)
)
@@ -40,27 +36,7 @@ class BotService:
if bot is None:
return None
masked_columns = []
if not include_secret:
masked_columns = ['adapter_config']
return self.ap.persistence_mgr.serialize_model(persistence_bot.Bot, bot, masked_columns)
async def get_runtime_bot_info(self, bot_uuid: str, include_secret: bool = True) -> dict:
"""获取机器人运行时信息"""
persistence_bot = await self.get_bot(bot_uuid, include_secret)
if persistence_bot is None:
raise Exception('Bot not found')
adapter_runtime_values = {}
runtime_bot = await self.ap.platform_mgr.get_bot_by_uuid(bot_uuid)
if runtime_bot is not None:
adapter_runtime_values['bot_account_id'] = runtime_bot.adapter.bot_account_id
persistence_bot['adapter_runtime_values'] = adapter_runtime_values
return persistence_bot
return self.ap.persistence_mgr.serialize_model(persistence_bot.Bot, bot)
async def create_bot(self, bot_data: dict) -> str:
"""Create bot"""

View File

@@ -7,7 +7,7 @@ from ....core import app
from ....entity.persistence import model as persistence_model
from ....entity.persistence import pipeline as persistence_pipeline
from ....provider.modelmgr import requester as model_requester
from langbot_plugin.api.entities.builtin.provider import message as provider_message
from ....provider import entities as llm_entities
class LLMModelsService:
@@ -16,19 +16,11 @@ class LLMModelsService:
def __init__(self, ap: app.Application) -> None:
self.ap = ap
async def get_llm_models(self, include_secret: bool = True) -> list[dict]:
async def get_llm_models(self) -> list[dict]:
result = await self.ap.persistence_mgr.execute_async(sqlalchemy.select(persistence_model.LLMModel))
models = result.all()
masked_columns = []
if not include_secret:
masked_columns = ['api_keys']
return [
self.ap.persistence_mgr.serialize_model(persistence_model.LLMModel, model, masked_columns)
for model in models
]
return [self.ap.persistence_mgr.serialize_model(persistence_model.LLMModel, model) for model in models]
async def create_llm_model(self, model_data: dict) -> str:
model_data['uuid'] = str(uuid.uuid4())
@@ -107,7 +99,7 @@ class LLMModelsService:
await runtime_llm_model.requester.invoke_llm(
query=None,
model=runtime_llm_model,
messages=[provider_message.Message(role='user', content='Hello, world!')],
messages=[llm_entities.Message(role='user', content='Hello, world!')],
funcs=[],
extra_args=model_data.get('extra_args', {}),
)

View File

@@ -85,15 +85,15 @@ class UserService:
async def change_password(self, user_email: str, current_password: str, new_password: str) -> None:
ph = argon2.PasswordHasher()
user_obj = await self.get_user_by_email(user_email)
if user_obj is None:
raise ValueError('User not found')
ph.verify(user_obj.password, current_password)
hashed_password = ph.hash(new_password)
await self.ap.persistence_mgr.execute_async(
sqlalchemy.update(user.User).where(user.User.user == user_email).values(password=hashed_password)
)

View File

@@ -2,12 +2,9 @@ from __future__ import annotations
import typing
from ..core import app
from . import operator
from ..core import app, entities as core_entities
from . import entities, operator, errors
from ..utils import importutil
import langbot_plugin.api.entities.builtin.provider.session as provider_session
import langbot_plugin.api.entities.builtin.pipeline.query as pipeline_query
from langbot_plugin.api.entities.builtin.command import context as command_context, errors as command_errors
# 引入所有算子以便注册
from . import operators
@@ -16,11 +13,13 @@ importutil.import_modules_in_pkg(operators)
class CommandManager:
"""命令管理器"""
ap: app.Application
cmd_list: list[operator.CommandOperator]
"""
Runtime command list, flat storage, each object contains a reference to the corresponding child node
运行时命令列表,扁平存储,各个对象包含对应的子节点引用
"""
def __init__(self, ap: app.Application):
@@ -56,28 +55,43 @@ class CommandManager:
async def _execute(
self,
context: command_context.ExecuteContext,
context: entities.ExecuteContext,
operator_list: list[operator.CommandOperator],
operator: operator.CommandOperator = None,
) -> typing.AsyncGenerator[command_context.CommandReturn, None]:
) -> typing.AsyncGenerator[entities.CommandReturn, None]:
"""执行命令"""
command_list = await self.ap.plugin_connector.list_commands()
found = False
if len(context.crt_params) > 0: # 查找下一个参数是否对应此节点的某个子节点名
for oper in operator_list:
if (context.crt_params[0] == oper.name or context.crt_params[0] in oper.alias) and (
oper.parent_class is None or oper.parent_class == operator.__class__
):
found = True
for command in command_list:
if command.metadata.name == context.command:
async for ret in self.ap.plugin_connector.execute_command(context):
yield ret
break
else:
yield command_context.CommandReturn(error=command_errors.CommandNotFoundError(context.command))
context.crt_command = context.crt_params[0]
context.crt_params = context.crt_params[1:]
async for ret in self._execute(context, oper.children, oper):
yield ret
break
if not found: # 如果下一个参数未在此节点的子节点中找到,则执行此节点或者报错
if operator is None:
yield entities.CommandReturn(error=errors.CommandNotFoundError(context.crt_params[0]))
else:
if operator.lowest_privilege > context.privilege:
yield entities.CommandReturn(error=errors.CommandPrivilegeError(operator.name))
else:
async for ret in operator.execute(context):
yield ret
async def execute(
self,
command_text: str,
query: pipeline_query.Query,
session: provider_session.Session,
) -> typing.AsyncGenerator[command_context.CommandReturn, None]:
query: core_entities.Query,
session: core_entities.Session,
) -> typing.AsyncGenerator[entities.CommandReturn, None]:
"""执行命令"""
privilege = 1
@@ -85,8 +99,8 @@ class CommandManager:
if f'{query.launcher_type.value}_{query.launcher_id}' in self.ap.instance_config.data['admins']:
privilege = 2
ctx = command_context.ExecuteContext(
query_id=query.query_id,
ctx = entities.ExecuteContext(
query=query,
session=session,
command_text=command_text,
command='',
@@ -96,9 +110,5 @@ class CommandManager:
privilege=privilege,
)
ctx.command = ctx.params[0]
ctx.shift()
async for ret in self._execute(ctx, self.cmd_list):
yield ret

74
pkg/command/entities.py Normal file
View File

@@ -0,0 +1,74 @@
from __future__ import annotations
import typing
import pydantic.v1 as pydantic
from ..core import entities as core_entities
from . import errors
from ..platform.types import message as platform_message
class CommandReturn(pydantic.BaseModel):
"""命令返回值"""
text: typing.Optional[str] = None
"""文本
"""
image: typing.Optional[platform_message.Image] = None
"""弃用"""
image_url: typing.Optional[str] = None
"""图片链接
"""
error: typing.Optional[errors.CommandError] = None
"""错误
"""
class Config:
arbitrary_types_allowed = True
class ExecuteContext(pydantic.BaseModel):
"""单次命令执行上下文"""
query: core_entities.Query
"""本次消息的请求对象"""
session: core_entities.Session
"""本次消息所属的会话对象"""
command_text: str
"""命令完整文本"""
command: str
"""命令名称"""
crt_command: str
"""当前命令
多级命令中crt_command为当前命令command为根命令。
例如:!plugin on Webwlkr
处理到plugin时command为plugincrt_command为plugin
处理到on时command为plugincrt_command为on
"""
params: list[str]
"""命令参数
整个命令以空格分割后的参数列表
"""
crt_params: list[str]
"""当前命令参数
多级命令中crt_params为当前命令参数params为根命令参数。
例如:!plugin on Webwlkr
处理到plugin时params为['on', 'Webwlkr']crt_params为['on', 'Webwlkr']
处理到on时params为['on', 'Webwlkr']crt_params为['Webwlkr']
"""
privilege: int
"""发起人权限"""

26
pkg/command/errors.py Normal file
View File

@@ -0,0 +1,26 @@
class CommandError(Exception):
def __init__(self, message: str = None):
self.message = message
def __str__(self):
return self.message
class CommandNotFoundError(CommandError):
def __init__(self, message: str = None):
super().__init__('未知命令: ' + message)
class CommandPrivilegeError(CommandError):
def __init__(self, message: str = None):
super().__init__('权限不足: ' + message)
class ParamNotEnoughError(CommandError):
def __init__(self, message: str = None):
super().__init__('参数不足: ' + message)
class CommandOperationError(CommandError):
def __init__(self, message: str = None):
super().__init__('操作失败: ' + message)

View File

@@ -4,7 +4,7 @@ import typing
import abc
from ..core import app
from langbot_plugin.api.entities.builtin.command import context as command_context
from . import entities
preregistered_operators: list[typing.Type[CommandOperator]] = []
@@ -95,18 +95,16 @@ class CommandOperator(metaclass=abc.ABCMeta):
pass
@abc.abstractmethod
async def execute(
self, context: command_context.ExecuteContext
) -> typing.AsyncGenerator[command_context.CommandReturn, None]:
async def execute(self, context: entities.ExecuteContext) -> typing.AsyncGenerator[entities.CommandReturn, None]:
"""实现此方法以执行命令
支持多次yield以返回多个结果。
例如:一个安装插件的命令,可能会有下载、解压、安装等多个步骤,每个步骤都可以返回一个结果。
Args:
context (command_context.ExecuteContext): 命令执行上下文
context (entities.ExecuteContext): 命令执行上下文
Yields:
command_context.CommandReturn: 命令返回封装
entities.CommandReturn: 命令返回封装
"""
pass

View File

@@ -2,17 +2,14 @@ from __future__ import annotations
import typing
from .. import operator
from langbot_plugin.api.entities.builtin.command import context as command_context, errors as command_errors
from .. import operator, entities, errors
@operator.operator_class(name='cmd', help='显示命令列表', usage='!cmd\n!cmd <命令名称>')
class CmdOperator(operator.CommandOperator):
"""命令列表"""
async def execute(
self, context: command_context.ExecuteContext
) -> typing.AsyncGenerator[command_context.CommandReturn, None]:
async def execute(self, context: entities.ExecuteContext) -> typing.AsyncGenerator[entities.CommandReturn, None]:
"""执行"""
if len(context.crt_params) == 0:
reply_str = '当前所有命令: \n\n'
@@ -23,7 +20,7 @@ class CmdOperator(operator.CommandOperator):
reply_str += '\n使用 !cmd <命令名称> 查看命令的详细帮助'
yield command_context.CommandReturn(text=reply_str.strip())
yield entities.CommandReturn(text=reply_str.strip())
else:
cmd_name = context.crt_params[0]
@@ -36,9 +33,9 @@ class CmdOperator(operator.CommandOperator):
break
if cmd is None:
yield command_context.CommandReturn(error=command_errors.CommandNotFoundError(cmd_name))
yield entities.CommandReturn(error=errors.CommandNotFoundError(cmd_name))
else:
reply_str = f'{cmd.name}: {cmd.help}\n\n'
reply_str += f'使用方法: \n{cmd.usage}'
yield command_context.CommandReturn(text=reply_str.strip())
yield entities.CommandReturn(text=reply_str.strip())

View File

@@ -2,26 +2,23 @@ from __future__ import annotations
import typing
from .. import operator
from langbot_plugin.api.entities.builtin.command import context as command_context, errors as command_errors
from .. import operator, entities, errors
@operator.operator_class(name='del', help='删除当前会话的历史记录', usage='!del <序号>\n!del all')
class DelOperator(operator.CommandOperator):
async def execute(
self, context: command_context.ExecuteContext
) -> typing.AsyncGenerator[command_context.CommandReturn, None]:
async def execute(self, context: entities.ExecuteContext) -> typing.AsyncGenerator[entities.CommandReturn, None]:
if context.session.conversations:
delete_index = 0
if len(context.crt_params) > 0:
try:
delete_index = int(context.crt_params[0])
except Exception:
yield command_context.CommandReturn(error=command_errors.CommandOperationError('索引必须是整数'))
yield entities.CommandReturn(error=errors.CommandOperationError('索引必须是整数'))
return
if delete_index < 0 or delete_index >= len(context.session.conversations):
yield command_context.CommandReturn(error=command_errors.CommandOperationError('索引超出范围'))
yield entities.CommandReturn(error=errors.CommandOperationError('索引超出范围'))
return
# 倒序
@@ -32,17 +29,15 @@ class DelOperator(operator.CommandOperator):
del context.session.conversations[to_delete_index]
yield command_context.CommandReturn(text=f'已删除对话: {delete_index}')
yield entities.CommandReturn(text=f'已删除对话: {delete_index}')
else:
yield command_context.CommandReturn(error=command_errors.CommandOperationError('当前没有对话'))
yield entities.CommandReturn(error=errors.CommandOperationError('当前没有对话'))
@operator.operator_class(name='all', help='删除此会话的所有历史记录', parent_class=DelOperator)
class DelAllOperator(operator.CommandOperator):
async def execute(
self, context: command_context.ExecuteContext
) -> typing.AsyncGenerator[command_context.CommandReturn, None]:
async def execute(self, context: entities.ExecuteContext) -> typing.AsyncGenerator[entities.CommandReturn, None]:
context.session.conversations = []
context.session.using_conversation = None
yield command_context.CommandReturn(text='已删除所有对话')
yield entities.CommandReturn(text='已删除所有对话')

View File

@@ -1,20 +1,19 @@
from __future__ import annotations
from typing import AsyncGenerator
from .. import operator
from langbot_plugin.api.entities.builtin.command import context as command_context
from .. import operator, entities
@operator.operator_class(name='func', help='查看所有已注册的内容函数', usage='!func')
class FuncOperator(operator.CommandOperator):
async def execute(
self, context: command_context.ExecuteContext
) -> AsyncGenerator[command_context.CommandReturn, None]:
async def execute(self, context: entities.ExecuteContext) -> AsyncGenerator[entities.CommandReturn, None]:
reply_str = '当前已启用的内容函数: \n\n'
index = 1
all_functions = await self.ap.tool_mgr.get_all_tools()
all_functions = await self.ap.tool_mgr.get_all_functions(
plugin_enabled=True,
)
for func in all_functions:
reply_str += '{}. {}:\n{}\n\n'.format(
@@ -24,4 +23,4 @@ class FuncOperator(operator.CommandOperator):
)
index += 1
yield command_context.CommandReturn(text=reply_str)
yield entities.CommandReturn(text=reply_str)

View File

@@ -2,17 +2,14 @@ from __future__ import annotations
import typing
from .. import operator
from langbot_plugin.api.entities.builtin.command import context as command_context
from .. import operator, entities
@operator.operator_class(name='help', help='显示帮助', usage='!help\n!help <命令名称>')
class HelpOperator(operator.CommandOperator):
async def execute(
self, context: command_context.ExecuteContext
) -> typing.AsyncGenerator[command_context.CommandReturn, None]:
async def execute(self, context: entities.ExecuteContext) -> typing.AsyncGenerator[entities.CommandReturn, None]:
help = 'LangBot - 大语言模型原生即时通信机器人平台\n链接https://langbot.app'
help += '\n发送命令 !cmd 可查看命令列表'
yield command_context.CommandReturn(text=help)
yield entities.CommandReturn(text=help)

View File

@@ -3,31 +3,26 @@ from __future__ import annotations
import typing
from .. import operator
from langbot_plugin.api.entities.builtin.command import context as command_context, errors as command_errors
from .. import operator, entities, errors
@operator.operator_class(name='last', help='切换到前一个对话', usage='!last')
class LastOperator(operator.CommandOperator):
async def execute(
self, context: command_context.ExecuteContext
) -> typing.AsyncGenerator[command_context.CommandReturn, None]:
async def execute(self, context: entities.ExecuteContext) -> typing.AsyncGenerator[entities.CommandReturn, None]:
if context.session.conversations:
# 找到当前会话的上一个会话
for index in range(len(context.session.conversations) - 1, -1, -1):
if context.session.conversations[index] == context.session.using_conversation:
if index == 0:
yield command_context.CommandReturn(
error=command_errors.CommandOperationError('已经是第一个对话了')
)
yield entities.CommandReturn(error=errors.CommandOperationError('已经是第一个对话了'))
return
else:
context.session.using_conversation = context.session.conversations[index - 1]
time_str = context.session.using_conversation.create_time.strftime('%Y-%m-%d %H:%M:%S')
yield command_context.CommandReturn(
yield entities.CommandReturn(
text=f'已切换到上一个对话: {index} {time_str}: {context.session.using_conversation.messages[0].readable_str()}'
)
return
else:
yield command_context.CommandReturn(error=command_errors.CommandOperationError('当前没有对话'))
yield entities.CommandReturn(error=errors.CommandOperationError('当前没有对话'))

View File

@@ -2,22 +2,19 @@ from __future__ import annotations
import typing
from .. import operator
from langbot_plugin.api.entities.builtin.command import context as command_context, errors as command_errors
from .. import operator, entities, errors
@operator.operator_class(name='list', help='列出此会话中的所有历史对话', usage='!list\n!list <页码>')
class ListOperator(operator.CommandOperator):
async def execute(
self, context: command_context.ExecuteContext
) -> typing.AsyncGenerator[command_context.CommandReturn, None]:
async def execute(self, context: entities.ExecuteContext) -> typing.AsyncGenerator[entities.CommandReturn, None]:
page = 0
if len(context.crt_params) > 0:
try:
page = int(context.crt_params[0] - 1)
except Exception:
yield command_context.CommandReturn(error=command_errors.CommandOperationError('页码应为整数'))
yield entities.CommandReturn(error=errors.CommandOperationError('页码应为整数'))
return
record_per_page = 10
@@ -48,4 +45,4 @@ class ListOperator(operator.CommandOperator):
else:
content += f'\n当前会话: {using_conv_index} {context.session.using_conversation.create_time.strftime("%Y-%m-%d %H:%M:%S")}: {context.session.using_conversation.messages[0].readable_str() if len(context.session.using_conversation.messages) > 0 else "无内容"}'
yield command_context.CommandReturn(text=f'{page + 1} 页 (时间倒序):\n{content}')
yield entities.CommandReturn(text=f'{page + 1} 页 (时间倒序):\n{content}')

View File

@@ -2,31 +2,26 @@ from __future__ import annotations
import typing
from .. import operator
from langbot_plugin.api.entities.builtin.command import context as command_context, errors as command_errors
from .. import operator, entities, errors
@operator.operator_class(name='next', help='切换到后一个对话', usage='!next')
class NextOperator(operator.CommandOperator):
async def execute(
self, context: command_context.ExecuteContext
) -> typing.AsyncGenerator[command_context.CommandReturn, None]:
async def execute(self, context: entities.ExecuteContext) -> typing.AsyncGenerator[entities.CommandReturn, None]:
if context.session.conversations:
# 找到当前会话的下一个会话
for index in range(len(context.session.conversations)):
if context.session.conversations[index] == context.session.using_conversation:
if index == len(context.session.conversations) - 1:
yield command_context.CommandReturn(
error=command_errors.CommandOperationError('已经是最后一个对话了')
)
yield entities.CommandReturn(error=errors.CommandOperationError('已经是最后一个对话了'))
return
else:
context.session.using_conversation = context.session.conversations[index + 1]
time_str = context.session.using_conversation.create_time.strftime('%Y-%m-%d %H:%M:%S')
yield command_context.CommandReturn(
yield entities.CommandReturn(
text=f'已切换到后一个对话: {index} {time_str}: {context.session.using_conversation.messages[0].content}'
)
return
else:
yield command_context.CommandReturn(error=command_errors.CommandOperationError('当前没有对话'))
yield entities.CommandReturn(error=errors.CommandOperationError('当前没有对话'))

View File

@@ -2,8 +2,7 @@ from __future__ import annotations
import typing
import traceback
from .. import operator
from langbot_plugin.api.entities.builtin.command import context as command_context, errors as command_errors
from .. import operator, entities, errors
@operator.operator_class(
@@ -12,9 +11,7 @@ from langbot_plugin.api.entities.builtin.command import context as command_conte
usage='!plugin\n!plugin get <插件仓库地址>\n!plugin update\n!plugin del <插件名>\n!plugin on <插件名>\n!plugin off <插件名>',
)
class PluginOperator(operator.CommandOperator):
async def execute(
self, context: command_context.ExecuteContext
) -> typing.AsyncGenerator[command_context.CommandReturn, None]:
async def execute(self, context: entities.ExecuteContext) -> typing.AsyncGenerator[entities.CommandReturn, None]:
plugin_list = self.ap.plugin_mgr.plugins()
reply_str = '所有插件({}):\n'.format(len(plugin_list))
idx = 0
@@ -30,36 +27,32 @@ class PluginOperator(operator.CommandOperator):
idx += 1
yield command_context.CommandReturn(text=reply_str)
yield entities.CommandReturn(text=reply_str)
@operator.operator_class(name='get', help='安装插件', privilege=2, parent_class=PluginOperator)
class PluginGetOperator(operator.CommandOperator):
async def execute(
self, context: command_context.ExecuteContext
) -> typing.AsyncGenerator[command_context.CommandReturn, None]:
async def execute(self, context: entities.ExecuteContext) -> typing.AsyncGenerator[entities.CommandReturn, None]:
if len(context.crt_params) == 0:
yield command_context.CommandReturn(error=command_errors.ParamNotEnoughError('请提供插件仓库地址'))
yield entities.CommandReturn(error=errors.ParamNotEnoughError('请提供插件仓库地址'))
else:
repo = context.crt_params[0]
yield command_context.CommandReturn(text='正在安装插件...')
yield entities.CommandReturn(text='正在安装插件...')
try:
await self.ap.plugin_mgr.install_plugin(repo)
yield command_context.CommandReturn(text='插件安装成功,请重启程序以加载插件')
yield entities.CommandReturn(text='插件安装成功,请重启程序以加载插件')
except Exception as e:
traceback.print_exc()
yield command_context.CommandReturn(error=command_errors.CommandError('插件安装失败: ' + str(e)))
yield entities.CommandReturn(error=errors.CommandError('插件安装失败: ' + str(e)))
@operator.operator_class(name='update', help='更新插件', privilege=2, parent_class=PluginOperator)
class PluginUpdateOperator(operator.CommandOperator):
async def execute(
self, context: command_context.ExecuteContext
) -> typing.AsyncGenerator[command_context.CommandReturn, None]:
async def execute(self, context: entities.ExecuteContext) -> typing.AsyncGenerator[entities.CommandReturn, None]:
if len(context.crt_params) == 0:
yield command_context.CommandReturn(error=command_errors.ParamNotEnoughError('请提供插件名称'))
yield entities.CommandReturn(error=errors.ParamNotEnoughError('请提供插件名称'))
else:
plugin_name = context.crt_params[0]
@@ -67,26 +60,24 @@ class PluginUpdateOperator(operator.CommandOperator):
plugin_container = self.ap.plugin_mgr.get_plugin_by_name(plugin_name)
if plugin_container is not None:
yield command_context.CommandReturn(text='正在更新插件...')
yield entities.CommandReturn(text='正在更新插件...')
await self.ap.plugin_mgr.update_plugin(plugin_name)
yield command_context.CommandReturn(text='插件更新成功,请重启程序以加载插件')
yield entities.CommandReturn(text='插件更新成功,请重启程序以加载插件')
else:
yield command_context.CommandReturn(error=command_errors.CommandError('插件更新失败: 未找到插件'))
yield entities.CommandReturn(error=errors.CommandError('插件更新失败: 未找到插件'))
except Exception as e:
traceback.print_exc()
yield command_context.CommandReturn(error=command_errors.CommandError('插件更新失败: ' + str(e)))
yield entities.CommandReturn(error=errors.CommandError('插件更新失败: ' + str(e)))
@operator.operator_class(name='all', help='更新所有插件', privilege=2, parent_class=PluginUpdateOperator)
class PluginUpdateAllOperator(operator.CommandOperator):
async def execute(
self, context: command_context.ExecuteContext
) -> typing.AsyncGenerator[command_context.CommandReturn, None]:
async def execute(self, context: entities.ExecuteContext) -> typing.AsyncGenerator[entities.CommandReturn, None]:
try:
plugins = [p.plugin_name for p in self.ap.plugin_mgr.plugins()]
if plugins:
yield command_context.CommandReturn(text='正在更新插件...')
yield entities.CommandReturn(text='正在更新插件...')
updated = []
try:
for plugin_name in plugins:
@@ -94,22 +85,20 @@ class PluginUpdateAllOperator(operator.CommandOperator):
updated.append(plugin_name)
except Exception as e:
traceback.print_exc()
yield command_context.CommandReturn(error=command_errors.CommandError('插件更新失败: ' + str(e)))
yield command_context.CommandReturn(text='已更新插件: {}'.format(', '.join(updated)))
yield entities.CommandReturn(error=errors.CommandError('插件更新失败: ' + str(e)))
yield entities.CommandReturn(text='已更新插件: {}'.format(', '.join(updated)))
else:
yield command_context.CommandReturn(text='没有可更新的插件')
yield entities.CommandReturn(text='没有可更新的插件')
except Exception as e:
traceback.print_exc()
yield command_context.CommandReturn(error=command_errors.CommandError('插件更新失败: ' + str(e)))
yield entities.CommandReturn(error=errors.CommandError('插件更新失败: ' + str(e)))
@operator.operator_class(name='del', help='删除插件', privilege=2, parent_class=PluginOperator)
class PluginDelOperator(operator.CommandOperator):
async def execute(
self, context: command_context.ExecuteContext
) -> typing.AsyncGenerator[command_context.CommandReturn, None]:
async def execute(self, context: entities.ExecuteContext) -> typing.AsyncGenerator[entities.CommandReturn, None]:
if len(context.crt_params) == 0:
yield command_context.CommandReturn(error=command_errors.ParamNotEnoughError('请提供插件名称'))
yield entities.CommandReturn(error=errors.ParamNotEnoughError('请提供插件名称'))
else:
plugin_name = context.crt_params[0]
@@ -117,55 +106,51 @@ class PluginDelOperator(operator.CommandOperator):
plugin_container = self.ap.plugin_mgr.get_plugin_by_name(plugin_name)
if plugin_container is not None:
yield command_context.CommandReturn(text='正在删除插件...')
yield entities.CommandReturn(text='正在删除插件...')
await self.ap.plugin_mgr.uninstall_plugin(plugin_name)
yield command_context.CommandReturn(text='插件删除成功,请重启程序以加载插件')
yield entities.CommandReturn(text='插件删除成功,请重启程序以加载插件')
else:
yield command_context.CommandReturn(error=command_errors.CommandError('插件删除失败: 未找到插件'))
yield entities.CommandReturn(error=errors.CommandError('插件删除失败: 未找到插件'))
except Exception as e:
traceback.print_exc()
yield command_context.CommandReturn(error=command_errors.CommandError('插件删除失败: ' + str(e)))
yield entities.CommandReturn(error=errors.CommandError('插件删除失败: ' + str(e)))
@operator.operator_class(name='on', help='启用插件', privilege=2, parent_class=PluginOperator)
class PluginEnableOperator(operator.CommandOperator):
async def execute(
self, context: command_context.ExecuteContext
) -> typing.AsyncGenerator[command_context.CommandReturn, None]:
async def execute(self, context: entities.ExecuteContext) -> typing.AsyncGenerator[entities.CommandReturn, None]:
if len(context.crt_params) == 0:
yield command_context.CommandReturn(error=command_errors.ParamNotEnoughError('请提供插件名称'))
yield entities.CommandReturn(error=errors.ParamNotEnoughError('请提供插件名称'))
else:
plugin_name = context.crt_params[0]
try:
if await self.ap.plugin_mgr.update_plugin_switch(plugin_name, True):
yield command_context.CommandReturn(text='已启用插件: {}'.format(plugin_name))
yield entities.CommandReturn(text='已启用插件: {}'.format(plugin_name))
else:
yield command_context.CommandReturn(
error=command_errors.CommandError('插件状态修改失败: 未找到插件 {}'.format(plugin_name))
yield entities.CommandReturn(
error=errors.CommandError('插件状态修改失败: 未找到插件 {}'.format(plugin_name))
)
except Exception as e:
traceback.print_exc()
yield command_context.CommandReturn(error=command_errors.CommandError('插件状态修改失败: ' + str(e)))
yield entities.CommandReturn(error=errors.CommandError('插件状态修改失败: ' + str(e)))
@operator.operator_class(name='off', help='禁用插件', privilege=2, parent_class=PluginOperator)
class PluginDisableOperator(operator.CommandOperator):
async def execute(
self, context: command_context.ExecuteContext
) -> typing.AsyncGenerator[command_context.CommandReturn, None]:
async def execute(self, context: entities.ExecuteContext) -> typing.AsyncGenerator[entities.CommandReturn, None]:
if len(context.crt_params) == 0:
yield command_context.CommandReturn(error=command_errors.ParamNotEnoughError('请提供插件名称'))
yield entities.CommandReturn(error=errors.ParamNotEnoughError('请提供插件名称'))
else:
plugin_name = context.crt_params[0]
try:
if await self.ap.plugin_mgr.update_plugin_switch(plugin_name, False):
yield command_context.CommandReturn(text='已禁用插件: {}'.format(plugin_name))
yield entities.CommandReturn(text='已禁用插件: {}'.format(plugin_name))
else:
yield command_context.CommandReturn(
error=command_errors.CommandError('插件状态修改失败: 未找到插件 {}'.format(plugin_name))
yield entities.CommandReturn(
error=errors.CommandError('插件状态修改失败: 未找到插件 {}'.format(plugin_name))
)
except Exception as e:
traceback.print_exc()
yield command_context.CommandReturn(error=command_errors.CommandError('插件状态修改失败: ' + str(e)))
yield entities.CommandReturn(error=errors.CommandError('插件状态修改失败: ' + str(e)))

View File

@@ -2,22 +2,19 @@ from __future__ import annotations
import typing
from .. import operator
from langbot_plugin.api.entities.builtin.command import context as command_context, errors as command_errors
from .. import operator, entities, errors
@operator.operator_class(name='prompt', help='查看当前对话的前文', usage='!prompt')
class PromptOperator(operator.CommandOperator):
async def execute(
self, context: command_context.ExecuteContext
) -> typing.AsyncGenerator[command_context.CommandReturn, None]:
async def execute(self, context: entities.ExecuteContext) -> typing.AsyncGenerator[entities.CommandReturn, None]:
"""执行"""
if context.session.using_conversation is None:
yield command_context.CommandReturn(error=command_errors.CommandOperationError('当前没有对话'))
yield entities.CommandReturn(error=errors.CommandOperationError('当前没有对话'))
else:
reply_str = '当前对话所有内容:\n\n'
for msg in context.session.using_conversation.messages:
reply_str += f'{msg.role}: {msg.content}\n'
yield command_context.CommandReturn(text=reply_str)
yield entities.CommandReturn(text=reply_str)

View File

@@ -2,18 +2,15 @@ from __future__ import annotations
import typing
from .. import operator
from langbot_plugin.api.entities.builtin.command import context as command_context, errors as command_errors
from .. import operator, entities, errors
@operator.operator_class(name='resend', help='重发当前会话的最后一条消息', usage='!resend')
class ResendOperator(operator.CommandOperator):
async def execute(
self, context: command_context.ExecuteContext
) -> typing.AsyncGenerator[command_context.CommandReturn, None]:
async def execute(self, context: entities.ExecuteContext) -> typing.AsyncGenerator[entities.CommandReturn, None]:
# 回滚到最后一条用户message前
if context.session.using_conversation is None:
yield command_context.CommandReturn(error=command_errors.CommandError('当前没有对话'))
yield entities.CommandReturn(error=errors.CommandError('当前没有对话'))
else:
conv_msg = context.session.using_conversation.messages
@@ -26,4 +23,4 @@ class ResendOperator(operator.CommandOperator):
conv_msg.pop()
# 不重发了,提示用户已删除就行了
yield command_context.CommandReturn(text='已删除最后一次请求记录')
yield entities.CommandReturn(text='已删除最后一次请求记录')

View File

@@ -2,16 +2,13 @@ from __future__ import annotations
import typing
from .. import operator
from langbot_plugin.api.entities.builtin.command import context as command_context
from .. import operator, entities
@operator.operator_class(name='reset', help='重置当前会话', usage='!reset')
class ResetOperator(operator.CommandOperator):
async def execute(
self, context: command_context.ExecuteContext
) -> typing.AsyncGenerator[command_context.CommandReturn, None]:
async def execute(self, context: entities.ExecuteContext) -> typing.AsyncGenerator[entities.CommandReturn, None]:
"""执行"""
context.session.using_conversation = None
yield command_context.CommandReturn(text='已重置当前会话')
yield entities.CommandReturn(text='已重置当前会话')

View File

@@ -0,0 +1,11 @@
from __future__ import annotations
import typing
from .. import operator, entities
@operator.operator_class(name='update', help='更新程序', usage='!update', privilege=2)
class UpdateCommand(operator.CommandOperator):
async def execute(self, context: entities.ExecuteContext) -> typing.AsyncGenerator[entities.CommandReturn, None]:
yield entities.CommandReturn(text='不再支持通过命令更新,请查看 LangBot 文档。')

View File

@@ -2,15 +2,12 @@ from __future__ import annotations
import typing
from .. import operator
from langbot_plugin.api.entities.builtin.command import context as command_context
from .. import operator, entities
@operator.operator_class(name='version', help='显示版本信息', usage='!version')
class VersionCommand(operator.CommandOperator):
async def execute(
self, context: command_context.ExecuteContext
) -> typing.AsyncGenerator[command_context.CommandReturn, None]:
async def execute(self, context: entities.ExecuteContext) -> typing.AsyncGenerator[entities.CommandReturn, None]:
reply_str = f'当前版本: \n{self.ap.ver_mgr.get_current_version()}'
try:
@@ -19,4 +16,4 @@ class VersionCommand(operator.CommandOperator):
except Exception:
pass
yield command_context.CommandReturn(text=reply_str.strip())
yield entities.CommandReturn(text=reply_str.strip())

View File

@@ -3,6 +3,7 @@ from __future__ import annotations
import logging
import asyncio
import traceback
import sys
import os
from ..platform import botmgr as im_mgr
@@ -11,7 +12,7 @@ from ..provider.modelmgr import modelmgr as llm_model_mgr
from ..provider.tools import toolmgr as llm_tool_mgr
from ..config import manager as config_mgr
from ..command import cmdmgr
from ..plugin import connector as plugin_connector
from ..plugin import manager as plugin_mgr
from ..pipeline import pool
from ..pipeline import controller, pipelinemgr
from ..utils import version as version_mgr, proxy as proxy_mgr, announce as announce_mgr
@@ -79,7 +80,7 @@ class Application:
# =========================
plugin_connector: plugin_connector.PluginRuntimeConnector = None
plugin_mgr: plugin_mgr.PluginManager = None
query_pool: pool.QueryPool = None
@@ -127,7 +128,7 @@ class Application:
async def run(self):
try:
await self.plugin_connector.initialize_plugins()
await self.plugin_mgr.initialize_plugins()
# 后续可能会允许动态重启其他任务
# 故为了防止程序在非 Ctrl-C 情况下退出,这里创建一个不会结束的协程
@@ -168,9 +169,6 @@ class Application:
self.logger.error(f'Application runtime fatal exception: {e}')
self.logger.debug(f'Traceback: {traceback.format_exc()}')
def dispose(self):
self.plugin_connector.dispose()
async def print_web_access_info(self):
"""Print access webui tips"""
@@ -197,3 +195,59 @@ class Application:
""".strip()
for line in tips.split('\n'):
self.logger.info(line)
async def reload(
self,
scope: core_entities.LifecycleControlScope,
):
match scope:
case core_entities.LifecycleControlScope.PLATFORM.value:
self.logger.info('Hot reload scope=' + scope)
await self.platform_mgr.shutdown()
self.platform_mgr = im_mgr.PlatformManager(self)
await self.platform_mgr.initialize()
self.task_mgr.create_task(
self.platform_mgr.run(),
name='platform-manager',
scopes=[
core_entities.LifecycleControlScope.APPLICATION,
core_entities.LifecycleControlScope.PLATFORM,
],
)
case core_entities.LifecycleControlScope.PLUGIN.value:
self.logger.info('Hot reload scope=' + scope)
await self.plugin_mgr.destroy_plugins()
# 删除 sys.module 中所有的 plugins/* 下的模块
for mod in list(sys.modules.keys()):
if mod.startswith('plugins.'):
del sys.modules[mod]
self.plugin_mgr = plugin_mgr.PluginManager(self)
await self.plugin_mgr.initialize()
await self.plugin_mgr.initialize_plugins()
await self.plugin_mgr.load_plugins()
await self.plugin_mgr.initialize_plugins()
case core_entities.LifecycleControlScope.PROVIDER.value:
self.logger.info('Hot reload scope=' + scope)
await self.tool_mgr.shutdown()
llm_model_mgr_inst = llm_model_mgr.ModelManager(self)
await llm_model_mgr_inst.initialize()
self.model_mgr = llm_model_mgr_inst
llm_session_mgr_inst = llm_session_mgr.SessionManager(self)
await llm_session_mgr_inst.initialize()
self.sess_mgr = llm_session_mgr_inst
llm_tool_mgr_inst = llm_tool_mgr.ToolManager(self)
await llm_tool_mgr_inst.initialize()
self.tool_mgr = llm_tool_mgr_inst
case _:
pass

View File

@@ -51,8 +51,8 @@ async def main(loop: asyncio.AbstractEventLoop):
import signal
def signal_handler(sig, frame):
app_inst.dispose()
print('[Signal] Program exit.')
# ap.shutdown()
os._exit(0)
signal.signal(signal.SIGINT, signal_handler)

View File

@@ -1,6 +1,18 @@
from __future__ import annotations
import enum
import typing
import datetime
import asyncio
import pydantic.v1 as pydantic
from ..provider import entities as llm_entities
from ..provider.modelmgr import requester
from ..provider.tools import entities as tools_entities
from ..platform import adapter as msadapter
from ..platform.types import message as platform_message
from ..platform.types import events as platform_events
class LifecycleControlScope(enum.Enum):
@@ -8,3 +20,159 @@ class LifecycleControlScope(enum.Enum):
PLATFORM = 'platform'
PLUGIN = 'plugin'
PROVIDER = 'provider'
class LauncherTypes(enum.Enum):
"""一个请求的发起者类型"""
PERSON = 'person'
"""私聊"""
GROUP = 'group'
"""群聊"""
class Query(pydantic.BaseModel):
"""一次请求的信息封装"""
query_id: int
"""请求ID添加进请求池时生成"""
launcher_type: LauncherTypes
"""会话类型platform处理阶段设置"""
launcher_id: typing.Union[int, str]
"""会话IDplatform处理阶段设置"""
sender_id: typing.Union[int, str]
"""发送者IDplatform处理阶段设置"""
message_event: platform_events.MessageEvent
"""事件platform收到的原始事件"""
message_chain: platform_message.MessageChain
"""消息链platform收到的原始消息链"""
bot_uuid: typing.Optional[str] = None
"""机器人UUID。"""
pipeline_uuid: typing.Optional[str] = None
"""流水线UUID。"""
pipeline_config: typing.Optional[dict[str, typing.Any]] = None
"""流水线配置,由 Pipeline 在运行开始时设置。"""
adapter: msadapter.MessagePlatformAdapter
"""消息平台适配器对象单个app中可能启用了多个消息平台适配器此对象表明发起此query的适配器"""
session: typing.Optional[Session] = None
"""会话对象,由前置处理器阶段设置"""
messages: typing.Optional[list[llm_entities.Message]] = []
"""历史消息列表,由前置处理器阶段设置"""
prompt: typing.Optional[llm_entities.Prompt] = None
"""情景预设内容,由前置处理器阶段设置"""
user_message: typing.Optional[llm_entities.Message] = None
"""此次请求的用户消息对象,由前置处理器阶段设置"""
variables: typing.Optional[dict[str, typing.Any]] = None
"""变量由前置处理器阶段设置。在prompt中嵌入或由 Runner 传递到 LLMOps 平台。"""
use_llm_model: typing.Optional[requester.RuntimeLLMModel] = None
"""使用的对话模型,由前置处理器阶段设置"""
use_funcs: typing.Optional[list[tools_entities.LLMFunction]] = None
"""使用的函数,由前置处理器阶段设置"""
resp_messages: (
typing.Optional[list[llm_entities.Message]]
| typing.Optional[list[platform_message.MessageChain]]
| typing.Optional[list[llm_entities.MessageChunk]]
) = []
"""由Process阶段生成的回复消息对象列表"""
resp_message_chain: typing.Optional[list[platform_message.MessageChain]] = None
"""回复消息链从resp_messages包装而得"""
# ======= 内部保留 =======
current_stage: typing.Optional['pkg.pipeline.pipelinemgr.StageInstContainer'] = None
"""当前所处阶段"""
class Config:
arbitrary_types_allowed = True
# ========== 插件可调用的 API请求 API ==========
def set_variable(self, key: str, value: typing.Any):
"""设置变量"""
if self.variables is None:
self.variables = {}
self.variables[key] = value
def get_variable(self, key: str) -> typing.Any:
"""获取变量"""
if self.variables is None:
return None
return self.variables.get(key)
def get_variables(self) -> dict[str, typing.Any]:
"""获取所有变量"""
if self.variables is None:
return {}
return self.variables
class Conversation(pydantic.BaseModel):
"""对话,包含于 Session 中,一个 Session 可以有多个历史 Conversation但只有一个当前使用的 Conversation"""
prompt: llm_entities.Prompt
messages: list[llm_entities.Message]
create_time: typing.Optional[datetime.datetime] = pydantic.Field(default_factory=datetime.datetime.now)
update_time: typing.Optional[datetime.datetime] = pydantic.Field(default_factory=datetime.datetime.now)
use_llm_model: typing.Optional[requester.RuntimeLLMModel] = None
use_funcs: typing.Optional[list[tools_entities.LLMFunction]]
pipeline_uuid: str
"""流水线UUID。"""
bot_uuid: str
"""机器人UUID。"""
uuid: typing.Optional[str] = None
"""该对话的 uuid在创建时不会自动生成。而是当使用 Dify API 等由外部管理对话信息的服务时,用于绑定外部的会话。具体如何使用,取决于 Runner。"""
class Config:
arbitrary_types_allowed = True
class Session(pydantic.BaseModel):
"""会话,一个 Session 对应一个 {launcher_type.value}_{launcher_id}"""
launcher_type: LauncherTypes
launcher_id: typing.Union[int, str]
sender_id: typing.Optional[typing.Union[int, str]] = 0
use_prompt_name: typing.Optional[str] = 'default'
using_conversation: typing.Optional[Conversation] = None
conversations: typing.Optional[list[Conversation]] = pydantic.Field(default_factory=list)
create_time: typing.Optional[datetime.datetime] = pydantic.Field(default_factory=datetime.datetime.now)
update_time: typing.Optional[datetime.datetime] = pydantic.Field(default_factory=datetime.datetime.now)
semaphore: typing.Optional[asyncio.Semaphore] = None
"""当前会话的信号量,用于限制并发"""
class Config:
arbitrary_types_allowed = True

View File

@@ -1,11 +1,10 @@
from __future__ import annotations
import asyncio
from .. import stage, app
from ...utils import version, proxy, announce
from ...pipeline import pool, controller, pipelinemgr
from ...plugin import connector as plugin_connector
from ...plugin import manager as plugin_mgr
from ...command import cmdmgr
from ...provider.session import sessionmgr as llm_session_mgr
from ...provider.modelmgr import modelmgr as llm_model_mgr
@@ -63,13 +62,10 @@ class BuildAppStage(stage.BootingStage):
ap.persistence_mgr = persistence_mgr_inst
await persistence_mgr_inst.initialize()
async def runtime_disconnect_callback(connector: plugin_connector.PluginRuntimeConnector) -> None:
await asyncio.sleep(3)
await plugin_connector_inst.initialize()
plugin_connector_inst = plugin_connector.PluginRuntimeConnector(ap, runtime_disconnect_callback)
await plugin_connector_inst.initialize()
ap.plugin_connector = plugin_connector_inst
plugin_mgr_inst = plugin_mgr.PluginManager(ap)
await plugin_mgr_inst.initialize()
ap.plugin_mgr = plugin_mgr_inst
await plugin_mgr_inst.load_plugins()
cmd_mgr_inst = cmdmgr.CommandManager(ap)
await cmd_mgr_inst.initialize()

View File

@@ -1,22 +0,0 @@
import sqlalchemy
from .base import Base
class BinaryStorage(Base):
"""Current for plugin use only"""
__tablename__ = 'binary_storages'
unique_key = sqlalchemy.Column(sqlalchemy.String(255), primary_key=True)
key = sqlalchemy.Column(sqlalchemy.String(255), nullable=False)
owner_type = sqlalchemy.Column(sqlalchemy.String(255), nullable=False)
owner = sqlalchemy.Column(sqlalchemy.String(255), nullable=False)
value = sqlalchemy.Column(sqlalchemy.LargeBinary, nullable=False)
created_at = sqlalchemy.Column(sqlalchemy.DateTime, nullable=False, server_default=sqlalchemy.func.now())
updated_at = sqlalchemy.Column(
sqlalchemy.DateTime,
nullable=False,
server_default=sqlalchemy.func.now(),
onupdate=sqlalchemy.func.now(),
)

View File

@@ -13,8 +13,6 @@ class PluginSetting(Base):
enabled = sqlalchemy.Column(sqlalchemy.Boolean, nullable=False, default=True)
priority = sqlalchemy.Column(sqlalchemy.Integer, nullable=False, default=0)
config = sqlalchemy.Column(sqlalchemy.JSON, nullable=False, default=dict)
install_source = sqlalchemy.Column(sqlalchemy.String(255), nullable=False, default='github')
install_info = sqlalchemy.Column(sqlalchemy.JSON, nullable=False, default=dict)
created_at = sqlalchemy.Column(sqlalchemy.DateTime, nullable=False, server_default=sqlalchemy.func.now())
updated_at = sqlalchemy.Column(
sqlalchemy.DateTime,

View File

@@ -44,38 +44,6 @@ class PersistenceManager:
await self.create_tables()
# run migrations
database_version = await self.execute_async(
sqlalchemy.select(metadata.Metadata).where(metadata.Metadata.key == 'database_version')
)
database_version = int(database_version.fetchone()[1])
required_database_version = constants.required_database_version
if database_version < required_database_version:
migrations = migration.preregistered_db_migrations
migrations.sort(key=lambda x: x.number)
last_migration_number = database_version
for migration_cls in migrations:
migration_instance = migration_cls(self.ap)
if (
migration_instance.number > database_version
and migration_instance.number <= required_database_version
):
await migration_instance.upgrade()
await self.execute_async(
sqlalchemy.update(metadata.Metadata)
.where(metadata.Metadata.key == 'database_version')
.values({'value': str(migration_instance.number)})
)
last_migration_number = migration_instance.number
self.ap.logger.info(f'Migration {migration_instance.number} completed.')
self.ap.logger.info(f'Successfully upgraded database to version {last_migration_number}.')
async def create_tables(self):
# create tables
async with self.get_db_engine().connect() as conn:
@@ -119,6 +87,38 @@ class PersistenceManager:
# =================================
# run migrations
database_version = await self.execute_async(
sqlalchemy.select(metadata.Metadata).where(metadata.Metadata.key == 'database_version')
)
database_version = int(database_version.fetchone()[1])
required_database_version = constants.required_database_version
if database_version < required_database_version:
migrations = migration.preregistered_db_migrations
migrations.sort(key=lambda x: x.number)
last_migration_number = database_version
for migration_cls in migrations:
migration_instance = migration_cls(self.ap)
if (
migration_instance.number > database_version
and migration_instance.number <= required_database_version
):
await migration_instance.upgrade()
await self.execute_async(
sqlalchemy.update(metadata.Metadata)
.where(metadata.Metadata.key == 'database_version')
.values({'value': str(migration_instance.number)})
)
last_migration_number = migration_instance.number
self.ap.logger.info(f'Migration {migration_instance.number} completed.')
self.ap.logger.info(f'Successfully upgraded database to version {last_migration_number}.')
async def execute_async(self, *args, **kwargs) -> sqlalchemy.engine.cursor.CursorResult:
async with self.get_db_engine().connect() as conn:
result = await conn.execute(*args, **kwargs)
@@ -128,13 +128,10 @@ class PersistenceManager:
def get_db_engine(self) -> sqlalchemy_asyncio.AsyncEngine:
return self.db.get_engine()
def serialize_model(
self, model: typing.Type[sqlalchemy.Base], data: sqlalchemy.Base, masked_columns: list[str] = []
) -> dict:
def serialize_model(self, model: typing.Type[sqlalchemy.Base], data: sqlalchemy.Base) -> dict:
return {
column.name: getattr(data, column.name)
if not isinstance(getattr(data, column.name), (datetime.datetime))
else getattr(data, column.name).isoformat()
for column in model.__table__.columns
if column.name not in masked_columns
}

View File

@@ -212,6 +212,7 @@ class DBMigrateV3Config(migration.DBMigration):
self.ap.instance_config.data['api']['port'] = self.ap.system_cfg.data['http-api']['port']
self.ap.instance_config.data['command'] = {
'prefix': self.ap.command_cfg.data['command-prefix'],
'enable': self.ap.command_cfg.data['command-enable'],
'privilege': self.ap.command_cfg.data['privilege'],
}
self.ap.instance_config.data['concurrency']['pipeline'] = self.ap.system_cfg.data['pipeline-concurrency']

View File

@@ -1,20 +0,0 @@
from .. import migration
@migration.migration_class(4)
class DBMigratePluginConfig(migration.DBMigration):
"""插件配置"""
async def upgrade(self):
"""升级"""
if 'plugin' not in self.ap.instance_config.data:
self.ap.instance_config.data['plugin'] = {
'runtime_ws_url': 'ws://localhost:5400/control/ws',
}
await self.ap.instance_config.dump_config()
async def downgrade(self):
"""降级"""
pass

View File

@@ -1,32 +0,0 @@
import sqlalchemy
from .. import migration
@migration.migration_class(6)
class DBMigratePluginInstallSource(migration.DBMigration):
"""插件安装来源"""
async def upgrade(self):
"""升级"""
# 查询表结构获取所有列名(异步执行 SQL
result = await self.ap.persistence_mgr.execute_async(sqlalchemy.text('PRAGMA table_info(plugin_settings);'))
# fetchall() 是同步方法,无需 await
columns = [row[1] for row in result.fetchall()]
# 检查并添加 install_source 列
if 'install_source' not in columns:
await self.ap.persistence_mgr.execute_async(
sqlalchemy.text(
"ALTER TABLE plugin_settings ADD COLUMN install_source VARCHAR(255) NOT NULL DEFAULT 'github'"
)
)
# 检查并添加 install_info 列
if 'install_info' not in columns:
await self.ap.persistence_mgr.execute_async(
sqlalchemy.text("ALTER TABLE plugin_settings ADD COLUMN install_info JSON NOT NULL DEFAULT '{}'")
)
async def downgrade(self):
"""降级"""
pass

View File

@@ -1,7 +1,7 @@
from __future__ import annotations
from .. import stage, entities
import langbot_plugin.api.entities.builtin.pipeline.query as pipeline_query
from ...core import entities as core_entities
@stage.stage_class('BanSessionCheckStage')
@@ -14,7 +14,7 @@ class BanSessionCheckStage(stage.PipelineStage):
async def initialize(self, pipeline_config: dict):
pass
async def process(self, query: pipeline_query.Query, stage_inst_name: str) -> entities.StageProcessResult:
async def process(self, query: core_entities.Query, stage_inst_name: str) -> entities.StageProcessResult:
found = False
mode = query.pipeline_config['trigger']['access-control']['mode']
@@ -30,6 +30,10 @@ class BanSessionCheckStage(stage.PipelineStage):
if sess == f'{query.launcher_type.value}_{query.launcher_id}':
found = True
break
# 使用 *_id 来表示加白/拉黑某用户的私聊和群聊场景
if sess.startswith('*_') and (sess[2:] == query.launcher_id or sess[2:] == query.sender_id):
found = True
break
ctn = False

View File

@@ -3,11 +3,12 @@ from __future__ import annotations
from ...core import app
from .. import stage, entities
from ...core import entities as core_entities
from . import filter as filter_model, entities as filter_entities
from langbot_plugin.api.entities.builtin.provider import message as provider_message
import langbot_plugin.api.entities.builtin.platform.message as platform_message
from ...provider import entities as llm_entities
from ...platform.types import message as platform_message
from ...utils import importutil
import langbot_plugin.api.entities.builtin.pipeline.query as pipeline_query
from . import filters
importutil.import_modules_in_pkg(filters)
@@ -57,7 +58,7 @@ class ContentFilterStage(stage.PipelineStage):
async def _pre_process(
self,
message: str,
query: pipeline_query.Query,
query: core_entities.Query,
) -> entities.StageProcessResult:
"""请求llm前处理消息
只要有一个不通过就不放行,只放行 PASS 的消息
@@ -85,14 +86,14 @@ class ContentFilterStage(stage.PipelineStage):
elif result.level == filter_entities.ResultLevel.PASS: # 传到下一个
message = result.replacement
query.message_chain = platform_message.MessageChain([platform_message.Plain(text=message)])
query.message_chain = platform_message.MessageChain(platform_message.Plain(message))
return entities.StageProcessResult(result_type=entities.ResultType.CONTINUE, new_query=query)
async def _post_process(
self,
message: str,
query: pipeline_query.Query,
query: core_entities.Query,
) -> entities.StageProcessResult:
"""请求llm后处理响应
只要是 PASS 或者 MASKED 的就通过此 filter将其 replacement 设置为message进入下一个 filter
@@ -122,7 +123,7 @@ class ContentFilterStage(stage.PipelineStage):
return entities.StageProcessResult(result_type=entities.ResultType.CONTINUE, new_query=query)
async def process(self, query: pipeline_query.Query, stage_inst_name: str) -> entities.StageProcessResult:
async def process(self, query: core_entities.Query, stage_inst_name: str) -> entities.StageProcessResult:
"""处理"""
if stage_inst_name == 'PreContentFilterStage':
contain_non_text = False
@@ -141,7 +142,7 @@ class ContentFilterStage(stage.PipelineStage):
return await self._pre_process(str(query.message_chain).strip(), query)
elif stage_inst_name == 'PostContentFilterStage':
# 仅处理 query.resp_messages[-1].content 是 str 的情况
if isinstance(query.resp_messages[-1], provider_message.Message) and isinstance(
if isinstance(query.resp_messages[-1], llm_entities.Message) and isinstance(
query.resp_messages[-1].content, str
):
return await self._post_process(query.resp_messages[-1].content, query)

View File

@@ -1,6 +1,6 @@
import enum
import pydantic
import pydantic.v1 as pydantic
class ResultLevel(enum.Enum):

View File

@@ -3,9 +3,9 @@ from __future__ import annotations
import abc
import typing
from ...core import app
from ...core import app, entities as core_entities
from . import entities
import langbot_plugin.api.entities.builtin.pipeline.query as pipeline_query
preregistered_filters: list[typing.Type[ContentFilter]] = []
@@ -60,8 +60,8 @@ class ContentFilter(metaclass=abc.ABCMeta):
pass
@abc.abstractmethod
async def process(self, query: pipeline_query.Query, message: str = None, image_url=None) -> entities.FilterResult:
"""处理消息
async def process(self, query: core_entities.Query, message: str = None, image_url=None) -> entities.FilterResult:
"""Process message
It is divided into two stages, depending on the value of enable_stages.
For content filters, you do not need to consider the stage of the message, you only need to check the message content.

View File

@@ -4,7 +4,8 @@ import aiohttp
from .. import entities
from .. import filter as filter_model
import langbot_plugin.api.entities.builtin.pipeline.query as pipeline_query
from ....core import entities as core_entities
BAIDU_EXAMINE_URL = 'https://aip.baidubce.com/rest/2.0/solution/v1/text_censor/v2/user_defined?access_token={}'
BAIDU_EXAMINE_TOKEN_URL = 'https://aip.baidubce.com/oauth/2.0/token'
@@ -26,7 +27,7 @@ class BaiduCloudExamine(filter_model.ContentFilter):
) as resp:
return (await resp.json())['access_token']
async def process(self, query: pipeline_query.Query, message: str) -> entities.FilterResult:
async def process(self, query: core_entities.Query, message: str) -> entities.FilterResult:
async with aiohttp.ClientSession() as session:
async with session.post(
BAIDU_EXAMINE_URL.format(await self._get_token()),

View File

@@ -3,7 +3,7 @@ import re
from .. import filter as filter_model
from .. import entities
import langbot_plugin.api.entities.builtin.pipeline.query as pipeline_query
from ....core import entities as core_entities
@filter_model.filter_class('ban-word-filter')
@@ -13,7 +13,7 @@ class BanWordFilter(filter_model.ContentFilter):
async def initialize(self):
pass
async def process(self, query: pipeline_query.Query, message: str) -> entities.FilterResult:
async def process(self, query: core_entities.Query, message: str) -> entities.FilterResult:
found = False
for word in self.ap.sensitive_meta.data['words']:

View File

@@ -3,7 +3,7 @@ import re
from .. import entities
from .. import filter as filter_model
import langbot_plugin.api.entities.builtin.pipeline.query as pipeline_query
from ....core import entities as core_entities
@filter_model.filter_class('content-ignore')
@@ -16,7 +16,7 @@ class ContentIgnore(filter_model.ContentFilter):
entities.EnableStage.PRE,
]
async def process(self, query: pipeline_query.Query, message: str) -> entities.FilterResult:
async def process(self, query: core_entities.Query, message: str) -> entities.FilterResult:
if 'prefix' in query.pipeline_config['trigger']['ignore-rules']:
for rule in query.pipeline_config['trigger']['ignore-rules']['prefix']:
if message.startswith(rule):

View File

@@ -3,10 +3,7 @@ from __future__ import annotations
import asyncio
import traceback
from ..core import app
from ..core import entities as core_entities
import langbot_plugin.api.entities.builtin.pipeline.query as pipeline_query
from ..core import app, entities
class Controller:
@@ -25,19 +22,19 @@ class Controller:
"""事件处理循环"""
try:
while True:
selected_query: pipeline_query.Query = None
selected_query: entities.Query = None
# 取请求
async with self.ap.query_pool:
queries: list[pipeline_query.Query] = self.ap.query_pool.queries
queries: list[entities.Query] = self.ap.query_pool.queries
for query in queries:
session = await self.ap.sess_mgr.get_session(query)
self.ap.logger.debug(f'Checking query {query} session {session}')
if not session._semaphore.locked():
if not session.semaphore.locked():
selected_query = query
await session._semaphore.acquire()
await session.semaphore.acquire()
break
@@ -49,7 +46,7 @@ class Controller:
if selected_query:
async def _process_query(selected_query: pipeline_query.Query):
async def _process_query(selected_query: entities.Query):
async with self.semaphore: # 总并发上限
# find pipeline
# Here firstly find the bot, then find the pipeline, in case the bot adapter's config is not the latest one.
@@ -62,7 +59,7 @@ class Controller:
await pipeline.run(selected_query)
async with self.ap.query_pool:
(await self.ap.sess_mgr.get_session(selected_query))._semaphore.release()
(await self.ap.sess_mgr.get_session(selected_query)).semaphore.release()
# 通知其他协程,有新的请求可以处理了
self.ap.query_pool.condition.notify_all()
@@ -71,8 +68,8 @@ class Controller:
kind='query',
name=f'query-{selected_query.query_id}',
scopes=[
core_entities.LifecycleControlScope.APPLICATION,
core_entities.LifecycleControlScope.PLATFORM,
entities.LifecycleControlScope.APPLICATION,
entities.LifecycleControlScope.PLATFORM,
],
)

View File

@@ -3,10 +3,10 @@ from __future__ import annotations
import enum
import typing
import pydantic
import pydantic.v1 as pydantic
from ..platform.types import message as platform_message
import langbot_plugin.api.entities.builtin.pipeline.query as pipeline_query
import langbot_plugin.api.entities.builtin.platform.message as platform_message
from ..core import entities
class ResultType(enum.Enum):
@@ -20,7 +20,7 @@ class ResultType(enum.Enum):
class StageProcessResult(pydantic.BaseModel):
result_type: ResultType
new_query: pipeline_query.Query
new_query: entities.Query
user_notice: typing.Optional[
typing.Union[

View File

@@ -5,9 +5,10 @@ import traceback
from . import strategy
from .. import stage, entities
import langbot_plugin.api.entities.builtin.platform.message as platform_message
from ...core import entities as core_entities
from ...platform.types import message as platform_message
from ...utils import importutil
import langbot_plugin.api.entities.builtin.pipeline.query as pipeline_query
from . import strategies
importutil.import_modules_in_pkg(strategies)
@@ -66,8 +67,8 @@ class LongTextProcessStage(stage.PipelineStage):
await self.strategy_impl.initialize()
async def process(self, query: pipeline_query.Query, stage_inst_name: str) -> entities.StageProcessResult:
# 检查是否包含非 Plain 组件
async def process(self, query: core_entities.Query, stage_inst_name: str) -> entities.StageProcessResult:
# Check if it contains non-Plain components
contains_non_plain = False
for msg in query.resp_message_chain[-1]:

View File

@@ -3,9 +3,9 @@ from __future__ import annotations
from .. import strategy as strategy_model
from ....core import entities as core_entities
from ....platform.types import message as platform_message
import langbot_plugin.api.entities.builtin.pipeline.query as pipeline_query
import langbot_plugin.api.entities.builtin.platform.message as platform_message
ForwardMessageDiaplay = platform_message.ForwardMessageDiaplay
Forward = platform_message.Forward
@@ -13,7 +13,7 @@ Forward = platform_message.Forward
@strategy_model.strategy_class('forward')
class ForwardComponentStrategy(strategy_model.LongTextStrategy):
async def process(self, message: str, query: pipeline_query.Query) -> list[platform_message.MessageComponent]:
async def process(self, message: str, query: core_entities.Query) -> list[platform_message.MessageComponent]:
display = ForwardMessageDiaplay(
title='Group chat history',
brief='[Chat history]',

View File

@@ -8,10 +8,10 @@ import re
from PIL import Image, ImageDraw, ImageFont
import functools
from ....platform.types import message as platform_message
from .. import strategy as strategy_model
import langbot_plugin.api.entities.builtin.pipeline.query as pipeline_query
import langbot_plugin.api.entities.builtin.platform.message as platform_message
from ....core import entities as core_entities
@strategy_model.strategy_class('image')
@@ -27,7 +27,7 @@ class Text2ImageStrategy(strategy_model.LongTextStrategy):
encoding='utf-8',
)
async def process(self, message: str, query: pipeline_query.Query) -> list[platform_message.MessageComponent]:
async def process(self, message: str, query: core_entities.Query) -> list[platform_message.MessageComponent]:
img_path = self.text_to_image(
text_str=message,
save_as='temp/{}.png'.format(int(time.time())),
@@ -131,7 +131,7 @@ class Text2ImageStrategy(strategy_model.LongTextStrategy):
text_str: str,
save_as='temp.png',
width=800,
query: pipeline_query.Query = None,
query: core_entities.Query = None,
):
text_str = text_str.replace('\t', ' ')

View File

@@ -4,9 +4,8 @@ import typing
from ...core import app
import langbot_plugin.api.entities.builtin.platform.message as platform_message
import langbot_plugin.api.entities.builtin.pipeline.query as pipeline_query
from ...core import entities as core_entities
from ...platform.types import message as platform_message
preregistered_strategies: list[typing.Type[LongTextStrategy]] = []
@@ -50,8 +49,8 @@ class LongTextStrategy(metaclass=abc.ABCMeta):
pass
@abc.abstractmethod
async def process(self, message: str, query: pipeline_query.Query) -> list[platform_message.MessageComponent]:
"""处理长文本
async def process(self, message: str, query: core_entities.Query) -> list[platform_message.MessageComponent]:
"""Process long text
If the text length exceeds the threshold, this method will be called.

View File

@@ -1,9 +1,10 @@
from __future__ import annotations
from .. import stage, entities
from ...core import entities as core_entities
from . import truncator
from ...utils import importutil
import langbot_plugin.api.entities.builtin.pipeline.query as pipeline_query
from . import truncators
importutil.import_modules_in_pkg(truncators)
@@ -28,8 +29,8 @@ class ConversationMessageTruncator(stage.PipelineStage):
else:
raise ValueError(f'Unknown truncator: {use_method}')
async def process(self, query: pipeline_query.Query, stage_inst_name: str) -> entities.StageProcessResult:
"""处理"""
async def process(self, query: core_entities.Query, stage_inst_name: str) -> entities.StageProcessResult:
"""Process"""
query = await self.trun.truncate(query)
return entities.StageProcessResult(result_type=entities.ResultType.CONTINUE, new_query=query)

View File

@@ -3,8 +3,8 @@ from __future__ import annotations
import typing
import abc
from ...core import app
import langbot_plugin.api.entities.builtin.pipeline.query as pipeline_query
from ...core import entities as core_entities, app
preregistered_truncators: list[typing.Type[Truncator]] = []
@@ -47,7 +47,7 @@ class Truncator(abc.ABC):
pass
@abc.abstractmethod
async def truncate(self, query: pipeline_query.Query) -> pipeline_query.Query:
async def truncate(self, query: core_entities.Query) -> core_entities.Query:
"""截断
一般只需要操作query.messages也可以扩展操作query.prompt, query.user_message。

View File

@@ -1,15 +1,15 @@
from __future__ import annotations
from .. import truncator
import langbot_plugin.api.entities.builtin.pipeline.query as pipeline_query
from ....core import entities as core_entities
@truncator.truncator_class('round')
class RoundTruncator(truncator.Truncator):
"""Truncate the conversation message chain to adapt to the LLM message length limit."""
async def truncate(self, query: pipeline_query.Query) -> pipeline_query.Query:
"""截断"""
async def truncate(self, query: core_entities.Query) -> core_entities.Query:
"""Truncate"""
max_round = query.pipeline_config['ai']['local-agent']['max-round']
temp_messages = []

View File

@@ -5,18 +5,14 @@ import traceback
import sqlalchemy
from ..core import app
from ..core import app, entities
from . import entities as pipeline_entities
from ..entity.persistence import pipeline as persistence_pipeline
from . import stage
import langbot_plugin.api.entities.builtin.platform.message as platform_message
import langbot_plugin.api.entities.builtin.platform.events as platform_events
import langbot_plugin.api.entities.events as events
from ..platform.types import message as platform_message, events as platform_events
from ..plugin import events
from ..utils import importutil
import langbot_plugin.api.entities.builtin.provider.session as provider_session
import langbot_plugin.api.entities.builtin.pipeline.query as pipeline_query
from . import (
resprule,
bansess,
@@ -79,17 +75,17 @@ class RuntimePipeline:
self.pipeline_entity = pipeline_entity
self.stage_containers = stage_containers
async def run(self, query: pipeline_query.Query):
async def run(self, query: entities.Query):
query.pipeline_config = self.pipeline_entity.config
await self.process_query(query)
async def _check_output(self, query: pipeline_query.Query, result: pipeline_entities.StageProcessResult):
async def _check_output(self, query: entities.Query, result: pipeline_entities.StageProcessResult):
"""检查输出"""
if result.user_notice:
# 处理str类型
if isinstance(result.user_notice, str):
result.user_notice = platform_message.MessageChain([platform_message.Plain(text=result.user_notice)])
result.user_notice = platform_message.MessageChain(platform_message.Plain(result.user_notice))
elif isinstance(result.user_notice, list):
result.user_notice = platform_message.MessageChain(*result.user_notice)
@@ -103,7 +99,7 @@ class RuntimePipeline:
bot_message=query.resp_messages[-1],
message=result.user_notice,
quote_origin=query.pipeline_config['output']['misc']['quote-origin'],
is_final=[msg.is_final for msg in query.resp_messages][0],
is_final=[msg.is_final for msg in query.resp_messages][0]
)
else:
await query.adapter.reply_message(
@@ -121,7 +117,7 @@ class RuntimePipeline:
async def _execute_from_stage(
self,
stage_index: int,
query: pipeline_query.Query,
query: entities.Query,
):
"""从指定阶段开始执行,实现了责任链模式和基于生成器的阶段分叉功能。
@@ -148,7 +144,7 @@ class RuntimePipeline:
while i < len(self.stage_containers):
stage_container = self.stage_containers[i]
query.current_stage_name = stage_container.inst_name # 标记到 Query 对象里
query.current_stage = stage_container # 标记到 Query 对象里
result = stage_container.inst.process(query, stage_container.inst_name)
@@ -185,26 +181,26 @@ class RuntimePipeline:
i += 1
async def process_query(self, query: pipeline_query.Query):
async def process_query(self, query: entities.Query):
"""处理请求"""
try:
# ======== 触发 MessageReceived 事件 ========
event_type = (
events.PersonMessageReceived
if query.launcher_type == provider_session.LauncherTypes.PERSON
if query.launcher_type == entities.LauncherTypes.PERSON
else events.GroupMessageReceived
)
event_obj = event_type(
query=query,
launcher_type=query.launcher_type.value,
launcher_id=query.launcher_id,
sender_id=query.sender_id,
message_chain=query.message_chain,
event_ctx = await self.ap.plugin_mgr.emit_event(
event=event_type(
launcher_type=query.launcher_type.value,
launcher_id=query.launcher_id,
sender_id=query.sender_id,
message_chain=query.message_chain,
query=query,
)
)
event_ctx = await self.ap.plugin_connector.emit_event(event_obj)
if event_ctx.is_prevented_default():
return
@@ -212,12 +208,11 @@ class RuntimePipeline:
await self._execute_from_stage(0, query)
except Exception as e:
inst_name = query.current_stage_name if query.current_stage_name else 'unknown'
inst_name = query.current_stage.inst_name if query.current_stage else 'unknown'
self.ap.logger.error(f'处理请求时出错 query_id={query.query_id} stage={inst_name} : {e}')
self.ap.logger.error(f'Traceback: {traceback.format_exc()}')
finally:
self.ap.logger.debug(f'Query {query.query_id} processed')
del self.ap.query_pool.cached_queries[query.query_id]
class PipelineManager:

View File

@@ -3,11 +3,10 @@ from __future__ import annotations
import asyncio
import typing
import langbot_plugin.api.entities.builtin.platform.message as platform_message
import langbot_plugin.api.entities.builtin.platform.events as platform_events
import langbot_plugin.api.entities.builtin.provider.session as provider_session
import langbot_plugin.api.entities.builtin.pipeline.query as pipeline_query
import langbot_plugin.api.definition.abstract.platform.adapter as abstract_platform_adapter
from ..core import entities
from ..platform import adapter as msadapter
from ..platform.types import message as platform_message
from ..platform.types import events as platform_events
class QueryPool:
@@ -17,10 +16,7 @@ class QueryPool:
pool_lock: asyncio.Lock
queries: list[pipeline_query.Query]
cached_queries: dict[int, pipeline_query.Query]
"""Cached queries, used for plugin backward api call, will be removed after the query completely processed"""
queries: list[entities.Query]
condition: asyncio.Condition
@@ -28,38 +24,34 @@ class QueryPool:
self.query_id_counter = 0
self.pool_lock = asyncio.Lock()
self.queries = []
self.cached_queries = {}
self.condition = asyncio.Condition(self.pool_lock)
async def add_query(
self,
bot_uuid: str,
launcher_type: provider_session.LauncherTypes,
launcher_type: entities.LauncherTypes,
launcher_id: typing.Union[int, str],
sender_id: typing.Union[int, str],
message_event: platform_events.MessageEvent,
message_chain: platform_message.MessageChain,
adapter: abstract_platform_adapter.AbstractMessagePlatformAdapter,
adapter: msadapter.MessagePlatformAdapter,
pipeline_uuid: typing.Optional[str] = None,
) -> pipeline_query.Query:
) -> entities.Query:
async with self.condition:
query_id = self.query_id_counter
query = pipeline_query.Query(
query = entities.Query(
bot_uuid=bot_uuid,
query_id=query_id,
query_id=self.query_id_counter,
launcher_type=launcher_type,
launcher_id=launcher_id,
sender_id=sender_id,
message_event=message_event,
message_chain=message_chain,
variables={},
resp_messages=[],
resp_message_chain=[],
adapter=adapter,
pipeline_uuid=pipeline_uuid,
)
self.queries.append(query)
self.cached_queries[query_id] = query
self.query_id_counter += 1
self.condition.notify_all()

View File

@@ -3,10 +3,10 @@ from __future__ import annotations
import datetime
from .. import stage, entities
from langbot_plugin.api.entities.builtin.provider import message as provider_message
import langbot_plugin.api.entities.events as events
import langbot_plugin.api.entities.builtin.platform.message as platform_message
import langbot_plugin.api.entities.builtin.pipeline.query as pipeline_query
from ...core import entities as core_entities
from ...provider import entities as llm_entities
from ...plugin import events
from ...platform.types import message as platform_message
@stage.stage_class('PreProcessor')
@@ -26,7 +26,7 @@ class PreProcessor(stage.PipelineStage):
async def process(
self,
query: pipeline_query.Query,
query: core_entities.Query,
stage_inst_name: str,
) -> entities.StageProcessResult:
"""Process"""
@@ -49,73 +49,80 @@ class PreProcessor(stage.PipelineStage):
query.bot_uuid,
)
# 设置query
conversation.use_llm_model = llm_model
# Set query
query.session = session
query.prompt = conversation.prompt.copy()
query.messages = conversation.messages.copy()
query.use_llm_model_uuid = llm_model.model_entity.uuid
query.use_llm_model = llm_model
if selected_runner == 'local-agent':
query.use_funcs = []
query.use_funcs = (
conversation.use_funcs if query.use_llm_model.model_entity.abilities.__contains__('func_call') else None
)
if llm_model.model_entity.abilities.__contains__('func_call'):
query.use_funcs = await self.ap.tool_mgr.get_all_tools()
variables = {
query.variables = {
'session_id': f'{query.session.launcher_type.value}_{query.session.launcher_id}',
'conversation_id': conversation.uuid,
'msg_create_time': (
int(query.message_event.time) if query.message_event.time else int(datetime.datetime.now().timestamp())
),
}
query.variables.update(variables)
# Check if this model supports vision, if not, remove all images
# TODO this checking should be performed in runner, and in this stage, the image should be reserved
if selected_runner == 'local-agent' and not llm_model.model_entity.abilities.__contains__('vision'):
if selected_runner == 'local-agent' and not query.use_llm_model.model_entity.abilities.__contains__('vision'):
for msg in query.messages:
if isinstance(msg.content, list):
for me in msg.content:
if me.type == 'image_url':
msg.content.remove(me)
content_list: list[provider_message.ContentElement] = []
content_list: list[llm_entities.ContentElement] = []
plain_text = ''
qoute_msg = query.pipeline_config['trigger'].get('misc', '').get('combine-quote-message')
# tidy the content_list
# combine all text content into one, and put it in the first position
for me in query.message_chain:
if isinstance(me, platform_message.Plain):
content_list.append(provider_message.ContentElement.from_text(me.text))
plain_text += me.text
elif isinstance(me, platform_message.Image):
if selected_runner != 'local-agent' or llm_model.model_entity.abilities.__contains__('vision'):
if selected_runner != 'local-agent' or query.use_llm_model.model_entity.abilities.__contains__(
'vision'
):
if me.base64 is not None:
content_list.append(provider_message.ContentElement.from_image_base64(me.base64))
content_list.append(llm_entities.ContentElement.from_image_base64(me.base64))
elif isinstance(me, platform_message.Quote) and qoute_msg:
for msg in me.origin:
if isinstance(msg, platform_message.Plain):
content_list.append(provider_message.ContentElement.from_text(msg.text))
content_list.append(llm_entities.ContentElement.from_text(msg.text))
elif isinstance(msg, platform_message.Image):
if selected_runner != 'local-agent' or llm_model.model_entity.abilities.__contains__('vision'):
if selected_runner != 'local-agent' or query.use_llm_model.model_entity.abilities.__contains__(
'vision'
):
if msg.base64 is not None:
content_list.append(provider_message.ContentElement.from_image_base64(msg.base64))
content_list.append(llm_entities.ContentElement.from_image_base64(msg.base64))
content_list.insert(0, llm_entities.ContentElement.from_text(plain_text))
query.variables['user_message_text'] = plain_text
query.user_message = provider_message.Message(role='user', content=content_list)
# =========== 触发事件 PromptPreProcessing
query.user_message = llm_entities.Message(role='user', content=content_list)
# =========== Trigger event PromptPreProcessing
event = events.PromptPreProcessing(
session_name=f'{query.session.launcher_type.value}_{query.session.launcher_id}',
default_prompt=query.prompt.messages,
prompt=query.messages,
query=query,
event_ctx = await self.ap.plugin_mgr.emit_event(
event=events.PromptPreProcessing(
session_name=f'{query.session.launcher_type.value}_{query.session.launcher_id}',
default_prompt=query.prompt.messages,
prompt=query.messages,
query=query,
)
)
event_ctx = await self.ap.plugin_connector.emit_event(event)
query.prompt.messages = event_ctx.event.default_prompt
query.messages = event_ctx.event.prompt

View File

@@ -3,8 +3,8 @@ from __future__ import annotations
import abc
from ...core import app
from ...core import entities as core_entities
from .. import entities
import langbot_plugin.api.entities.builtin.pipeline.query as pipeline_query
class MessageHandler(metaclass=abc.ABCMeta):
@@ -19,7 +19,7 @@ class MessageHandler(metaclass=abc.ABCMeta):
@abc.abstractmethod
async def handle(
self,
query: pipeline_query.Query,
query: core_entities.Query,
) -> entities.StageProcessResult:
raise NotImplementedError

View File

@@ -7,15 +7,13 @@ import traceback
from .. import handler
from ... import entities
from ....core import entities as core_entities
from ....provider import runner as runner_module
from ....plugin import events
import langbot_plugin.api.entities.builtin.platform.message as platform_message
import langbot_plugin.api.entities.events as events
from ....platform.types import message as platform_message
from ....utils import importutil
from ....provider import runners
import langbot_plugin.api.entities.builtin.provider.session as provider_session
import langbot_plugin.api.entities.builtin.pipeline.query as pipeline_query
importutil.import_modules_in_pkg(runners)
@@ -23,7 +21,7 @@ importutil.import_modules_in_pkg(runners)
class ChatMessageHandler(handler.MessageHandler):
async def handle(
self,
query: pipeline_query.Query,
query: core_entities.Query,
) -> typing.AsyncGenerator[entities.StageProcessResult, None]:
"""处理"""
# 调API
@@ -32,20 +30,20 @@ class ChatMessageHandler(handler.MessageHandler):
# 触发插件事件
event_class = (
events.PersonNormalMessageReceived
if query.launcher_type == provider_session.LauncherTypes.PERSON
if query.launcher_type == core_entities.LauncherTypes.PERSON
else events.GroupNormalMessageReceived
)
event = event_class(
launcher_type=query.launcher_type.value,
launcher_id=query.launcher_id,
sender_id=query.sender_id,
text_message=str(query.message_chain),
query=query,
event_ctx = await self.ap.plugin_mgr.emit_event(
event=event_class(
launcher_type=query.launcher_type.value,
launcher_id=query.launcher_id,
sender_id=query.sender_id,
text_message=str(query.message_chain),
query=query,
)
)
event_ctx = await self.ap.plugin_connector.emit_event(event)
is_create_card = False # 判断下是否需要创建流式卡片
if event_ctx.is_prevented_default():
if event_ctx.event.reply is not None:
mc = platform_message.MessageChain(event_ctx.event.reply)
@@ -74,14 +72,17 @@ class ChatMessageHandler(handler.MessageHandler):
raise ValueError(f'未找到请求运行器: {query.pipeline_config["ai"]["runner"]["runner"]}')
if is_stream:
resp_message_id = uuid.uuid4()
await query.adapter.create_message_card(str(resp_message_id), query.message_event)
async for result in runner.run(query):
result.resp_message_id = str(resp_message_id)
if query.resp_messages:
query.resp_messages.pop()
if query.resp_message_chain:
query.resp_message_chain.pop()
# 此时连接外部 AI 服务正常,创建卡片
if not is_create_card: # 只有不是第一次才创建卡片
await query.adapter.create_message_card(str(resp_message_id), query.message_event)
is_create_card = True
query.resp_messages.append(result)
self.ap.logger.info(f'对话({query.query_id})流式响应: {self.cut_str(result.readable_str())}')
@@ -119,4 +120,4 @@ class ChatMessageHandler(handler.MessageHandler):
)
finally:
# TODO statistics
pass
pass

View File

@@ -4,17 +4,16 @@ import typing
from .. import handler
from ... import entities
import langbot_plugin.api.entities.builtin.provider.message as provider_message
import langbot_plugin.api.entities.builtin.platform.message as platform_message
import langbot_plugin.api.entities.builtin.provider.session as provider_session
import langbot_plugin.api.entities.builtin.pipeline.query as pipeline_query
import langbot_plugin.api.entities.events as events
from ....core import entities as core_entities
from ....provider import entities as llm_entities
from ....plugin import events
from ....platform.types import message as platform_message
class CommandHandler(handler.MessageHandler):
async def handle(
self,
query: pipeline_query.Query,
query: core_entities.Query,
) -> typing.AsyncGenerator[entities.StageProcessResult, None]:
"""Process"""
@@ -29,23 +28,23 @@ class CommandHandler(handler.MessageHandler):
event_class = (
events.PersonCommandSent
if query.launcher_type == provider_session.LauncherTypes.PERSON
if query.launcher_type == core_entities.LauncherTypes.PERSON
else events.GroupCommandSent
)
event = event_class(
launcher_type=query.launcher_type.value,
launcher_id=query.launcher_id,
sender_id=query.sender_id,
command=spt[0],
params=spt[1:] if len(spt) > 1 else [],
text_message=str(query.message_chain),
is_admin=(privilege == 2),
query=query,
event_ctx = await self.ap.plugin_mgr.emit_event(
event=event_class(
launcher_type=query.launcher_type.value,
launcher_id=query.launcher_id,
sender_id=query.sender_id,
command=spt[0],
params=spt[1:] if len(spt) > 1 else [],
text_message=str(query.message_chain),
is_admin=(privilege == 2),
query=query,
)
)
event_ctx = await self.ap.plugin_connector.emit_event(event)
if event_ctx.is_prevented_default():
if event_ctx.event.reply is not None:
mc = platform_message.MessageChain(event_ctx.event.reply)
@@ -65,7 +64,7 @@ class CommandHandler(handler.MessageHandler):
async for ret in self.ap.cmd_mgr.execute(command_text=command_text, query=query, session=session):
if ret.error is not None:
query.resp_messages.append(
provider_message.Message(
llm_entities.Message(
role='command',
content=str(ret.error),
)
@@ -75,16 +74,16 @@ class CommandHandler(handler.MessageHandler):
yield entities.StageProcessResult(result_type=entities.ResultType.CONTINUE, new_query=query)
elif ret.text is not None or ret.image_url is not None:
content: list[provider_message.ContentElement] = []
content: list[llm_entities.ContentElement] = []
if ret.text is not None:
content.append(provider_message.ContentElement.from_text(ret.text))
content.append(llm_entities.ContentElement.from_text(ret.text))
if ret.image_url is not None:
content.append(provider_message.ContentElement.from_image_url(ret.image_url))
content.append(llm_entities.ContentElement.from_image_url(ret.image_url))
query.resp_messages.append(
provider_message.Message(
llm_entities.Message(
role='command',
content=content,
)

View File

@@ -1,10 +1,10 @@
from __future__ import annotations
from ...core import entities as core_entities
from . import handler
from .handlers import chat, command
from .. import entities
from .. import stage
import langbot_plugin.api.entities.builtin.pipeline.query as pipeline_query
@stage.stage_class('MessageProcessor')
@@ -30,7 +30,7 @@ class Processor(stage.PipelineStage):
async def process(
self,
query: pipeline_query.Query,
query: core_entities.Query,
stage_inst_name: str,
) -> entities.StageProcessResult:
"""Process"""
@@ -42,12 +42,14 @@ class Processor(stage.PipelineStage):
async def generator():
cmd_prefix = self.ap.instance_config.data['command']['prefix']
cmd_enable = self.ap.instance_config.data['command'].get('enable', True)
if any(message_text.startswith(prefix) for prefix in cmd_prefix):
async for result in self.cmd_handler.handle(query):
yield result
if cmd_enable and any(message_text.startswith(prefix) for prefix in cmd_prefix):
handler_to_use = self.cmd_handler
else:
async for result in self.chat_handler.handle(query):
yield result
handler_to_use = self.chat_handler
async for result in handler_to_use.handle(query):
yield result
return generator()

View File

@@ -2,8 +2,7 @@ from __future__ import annotations
import abc
import typing
from ...core import app
import langbot_plugin.api.entities.builtin.pipeline.query as pipeline_query
from ...core import app, entities as core_entities
preregistered_algos: list[typing.Type[ReteLimitAlgo]] = []
@@ -34,7 +33,7 @@ class ReteLimitAlgo(metaclass=abc.ABCMeta):
@abc.abstractmethod
async def require_access(
self,
query: pipeline_query.Query,
query: core_entities.Query,
launcher_type: str,
launcher_id: typing.Union[int, str],
) -> bool:
@@ -54,7 +53,7 @@ class ReteLimitAlgo(metaclass=abc.ABCMeta):
@abc.abstractmethod
async def release_access(
self,
query: pipeline_query.Query,
query: core_entities.Query,
launcher_type: str,
launcher_id: typing.Union[int, str],
):

View File

@@ -3,7 +3,7 @@ import asyncio
import time
import typing
from .. import algo
import langbot_plugin.api.entities.builtin.pipeline.query as pipeline_query
from ....core import entities as core_entities
# 固定窗口算法
@@ -32,7 +32,7 @@ class FixedWindowAlgo(algo.ReteLimitAlgo):
async def require_access(
self,
query: pipeline_query.Query,
query: core_entities.Query,
launcher_type: str,
launcher_id: typing.Union[int, str],
) -> bool:
@@ -91,7 +91,7 @@ class FixedWindowAlgo(algo.ReteLimitAlgo):
async def release_access(
self,
query: pipeline_query.Query,
query: core_entities.Query,
launcher_type: str,
launcher_id: typing.Union[int, str],
):

View File

@@ -4,10 +4,9 @@ import typing
from .. import entities, stage
from . import algo
from ...core import entities as core_entities
from ...utils import importutil
import langbot_plugin.api.entities.builtin.pipeline.query as pipeline_query
from . import algos
importutil.import_modules_in_pkg(algos)
@@ -40,7 +39,7 @@ class RateLimit(stage.PipelineStage):
async def process(
self,
query: pipeline_query.Query,
query: core_entities.Query,
stage_inst_name: str,
) -> typing.Union[
entities.StageProcessResult,

View File

@@ -4,19 +4,22 @@ import random
import asyncio
import langbot_plugin.api.entities.builtin.platform.events as platform_events
import langbot_plugin.api.entities.builtin.platform.message as platform_message
import langbot_plugin.api.entities.builtin.provider.message as provider_message
from ...platform.types import events as platform_events
from ...platform.types import message as platform_message
from ...provider import entities as llm_entities
from .. import stage, entities
import langbot_plugin.api.entities.builtin.pipeline.query as pipeline_query
from ...core import entities as core_entities
@stage.stage_class('SendResponseBackStage')
class SendResponseBackStage(stage.PipelineStage):
"""发送响应消息"""
async def process(self, query: pipeline_query.Query, stage_inst_name: str) -> entities.StageProcessResult:
async def process(self, query: core_entities.Query, stage_inst_name: str) -> entities.StageProcessResult:
"""处理"""
random_range = (
@@ -37,7 +40,7 @@ class SendResponseBackStage(stage.PipelineStage):
quote_origin = query.pipeline_config['output']['misc']['quote-origin']
has_chunks = any(isinstance(msg, provider_message.MessageChunk) for msg in query.resp_messages)
has_chunks = any(isinstance(msg, llm_entities.MessageChunk) for msg in query.resp_messages)
# TODO 命令与流式的兼容性问题
if await query.adapter.is_stream_output_supported() and has_chunks:
is_final = [msg.is_final for msg in query.resp_messages][0]

View File

@@ -1,6 +1,6 @@
import pydantic
import pydantic.v1 as pydantic
import langbot_plugin.api.entities.builtin.platform.message as platform_message
from ...platform.types import message as platform_message
class RuleJudgeResult(pydantic.BaseModel):

View File

@@ -4,10 +4,9 @@ from __future__ import annotations
from . import rule
from .. import stage, entities
from ...core import entities as core_entities
from ...utils import importutil
import langbot_plugin.api.entities.builtin.pipeline.query as pipeline_query
from . import rules
importutil.import_modules_in_pkg(rules)
@@ -33,7 +32,7 @@ class GroupRespondRuleCheckStage(stage.PipelineStage):
await rule_inst.initialize()
self.rule_matchers.append(rule_inst)
async def process(self, query: pipeline_query.Query, stage_inst_name: str) -> entities.StageProcessResult:
async def process(self, query: core_entities.Query, stage_inst_name: str) -> entities.StageProcessResult:
if query.launcher_type.value != 'group': # 只处理群消息
return entities.StageProcessResult(result_type=entities.ResultType.CONTINUE, new_query=query)

View File

@@ -2,11 +2,10 @@ from __future__ import annotations
import abc
import typing
from ...core import app
from ...core import app, entities as core_entities
from . import entities
import langbot_plugin.api.entities.builtin.platform.message as platform_message
import langbot_plugin.api.entities.builtin.pipeline.query as pipeline_query
from ...platform.types import message as platform_message
preregisetered_rules: list[typing.Type[GroupRespondRule]] = []
@@ -40,7 +39,7 @@ class GroupRespondRule(metaclass=abc.ABCMeta):
message_text: str,
message_chain: platform_message.MessageChain,
rule_dict: dict,
query: pipeline_query.Query,
query: core_entities.Query,
) -> entities.RuleJudgeResult:
"""判断消息是否匹配规则"""
raise NotImplementedError

View File

@@ -3,8 +3,8 @@ from __future__ import annotations
from .. import rule as rule_model
from .. import entities
import langbot_plugin.api.entities.builtin.platform.message as platform_message
import langbot_plugin.api.entities.builtin.pipeline.query as pipeline_query
from ....core import entities as core_entities
from ....platform.types import message as platform_message
@rule_model.rule_class('at-bot')
@@ -14,28 +14,19 @@ class AtBotRule(rule_model.GroupRespondRule):
message_text: str,
message_chain: platform_message.MessageChain,
rule_dict: dict,
query: pipeline_query.Query,
query: core_entities.Query,
) -> entities.RuleJudgeResult:
def remove_at(message_chain: platform_message.MessageChain):
for component in message_chain.root:
if isinstance(component, platform_message.At) and component.target == query.adapter.bot_account_id:
message_chain.remove(component)
break
if message_chain.has(platform_message.At(query.adapter.bot_account_id)) and rule_dict['at']:
message_chain.remove(platform_message.At(query.adapter.bot_account_id))
remove_at(message_chain)
remove_at(message_chain) # 回复消息时会at两次检查并删除重复的
if message_chain.has(
platform_message.At(query.adapter.bot_account_id)
): # 回复消息时会at两次检查并删除重复的
message_chain.remove(platform_message.At(query.adapter.bot_account_id))
# if message_chain.has(platform_message.At(query.adapter.bot_account_id)) and rule_dict['at']:
# message_chain.remove(platform_message.At(query.adapter.bot_account_id))
# if message_chain.has(
# platform_message.At(query.adapter.bot_account_id)
# ): # 回复消息时会at两次检查并删除重复的
# message_chain.remove(platform_message.At(query.adapter.bot_account_id))
# return entities.RuleJudgeResult(
# matching=True,
# replacement=message_chain,
# )
return entities.RuleJudgeResult(
matching=True,
replacement=message_chain,
)
return entities.RuleJudgeResult(matching=False, replacement=message_chain)

View File

@@ -1,7 +1,7 @@
from .. import rule as rule_model
from .. import entities
import langbot_plugin.api.entities.builtin.platform.message as platform_message
import langbot_plugin.api.entities.builtin.pipeline.query as pipeline_query
from ....core import entities as core_entities
from ....platform.types import message as platform_message
@rule_model.rule_class('prefix')
@@ -11,7 +11,7 @@ class PrefixRule(rule_model.GroupRespondRule):
message_text: str,
message_chain: platform_message.MessageChain,
rule_dict: dict,
query: pipeline_query.Query,
query: core_entities.Query,
) -> entities.RuleJudgeResult:
prefixes = rule_dict['prefix']

View File

@@ -3,8 +3,8 @@ import random
from .. import rule as rule_model
from .. import entities
import langbot_plugin.api.entities.builtin.platform.message as platform_message
import langbot_plugin.api.entities.builtin.pipeline.query as pipeline_query
from ....core import entities as core_entities
from ....platform.types import message as platform_message
@rule_model.rule_class('random')
@@ -14,7 +14,7 @@ class RandomRespRule(rule_model.GroupRespondRule):
message_text: str,
message_chain: platform_message.MessageChain,
rule_dict: dict,
query: pipeline_query.Query,
query: core_entities.Query,
) -> entities.RuleJudgeResult:
random_rate = rule_dict['random']

View File

@@ -3,8 +3,8 @@ import re
from .. import rule as rule_model
from .. import entities
import langbot_plugin.api.entities.builtin.platform.message as platform_message
import langbot_plugin.api.entities.builtin.pipeline.query as pipeline_query
from ....core import entities as core_entities
from ....platform.types import message as platform_message
@rule_model.rule_class('regexp')
@@ -14,7 +14,7 @@ class RegExpRule(rule_model.GroupRespondRule):
message_text: str,
message_chain: platform_message.MessageChain,
rule_dict: dict,
query: pipeline_query.Query,
query: core_entities.Query,
) -> entities.RuleJudgeResult:
regexps = rule_dict['regexp']

View File

@@ -3,9 +3,8 @@ from __future__ import annotations
import abc
import typing
from ..core import app
from ..core import app, entities as core_entities
from . import entities
import langbot_plugin.api.entities.builtin.pipeline.query as pipeline_query
preregistered_stages: dict[str, type[PipelineStage]] = {}
@@ -34,7 +33,7 @@ class PipelineStage(metaclass=abc.ABCMeta):
@abc.abstractmethod
async def process(
self,
query: pipeline_query.Query,
query: core_entities.Query,
stage_inst_name: str,
) -> typing.Union[
entities.StageProcessResult,

View File

@@ -2,12 +2,12 @@ from __future__ import annotations
import typing
from ...core import entities as core_entities
from .. import entities
from .. import stage
import langbot_plugin.api.entities.builtin.platform.message as platform_message
import langbot_plugin.api.entities.builtin.pipeline.query as pipeline_query
import langbot_plugin.api.entities.events as events
from ...plugin import events
from ...platform.types import message as platform_message
@stage.stage_class('ResponseWrapper')
@@ -25,7 +25,7 @@ class ResponseWrapper(stage.PipelineStage):
async def process(
self,
query: pipeline_query.Query,
query: core_entities.Query,
stage_inst_name: str,
) -> typing.AsyncGenerator[entities.StageProcessResult, None]:
"""处理"""
@@ -58,22 +58,21 @@ class ResponseWrapper(stage.PipelineStage):
reply_text = str(result.get_content_platform_message_chain())
# ============= 触发插件事件 ===============
event = events.NormalMessageResponded(
launcher_type=query.launcher_type.value,
launcher_id=query.launcher_id,
sender_id=query.sender_id,
session=session,
prefix='',
response_text=reply_text,
finish_reason='stop',
funcs_called=[fc.function.name for fc in result.tool_calls]
if result.tool_calls is not None
else [],
query=query,
event_ctx = await self.ap.plugin_mgr.emit_event(
event=events.NormalMessageResponded(
launcher_type=query.launcher_type.value,
launcher_id=query.launcher_id,
sender_id=query.sender_id,
session=session,
prefix='',
response_text=reply_text,
finish_reason='stop',
funcs_called=[fc.function.name for fc in result.tool_calls]
if result.tool_calls is not None
else [],
query=query,
)
)
event_ctx = await self.ap.plugin_connector.emit_event(event)
if event_ctx.is_prevented_default():
yield entities.StageProcessResult(
result_type=entities.ResultType.INTERRUPT,
@@ -97,26 +96,26 @@ class ResponseWrapper(stage.PipelineStage):
reply_text = f'调用函数 {".".join(function_names)}...'
query.resp_message_chain.append(
platform_message.MessageChain([platform_message.Plain(text=reply_text)])
platform_message.MessageChain([platform_message.Plain(reply_text)])
)
if query.pipeline_config['output']['misc']['track-function-calls']:
event = events.NormalMessageResponded(
launcher_type=query.launcher_type.value,
launcher_id=query.launcher_id,
sender_id=query.sender_id,
session=session,
prefix='',
response_text=reply_text,
finish_reason='stop',
funcs_called=[fc.function.name for fc in result.tool_calls]
if result.tool_calls is not None
else [],
query=query,
event_ctx = await self.ap.plugin_mgr.emit_event(
event=events.NormalMessageResponded(
launcher_type=query.launcher_type.value,
launcher_id=query.launcher_id,
sender_id=query.sender_id,
session=session,
prefix='',
response_text=reply_text,
finish_reason='stop',
funcs_called=[fc.function.name for fc in result.tool_calls]
if result.tool_calls is not None
else [],
query=query,
)
)
event_ctx = await self.ap.plugin_connector.emit_event(event)
if event_ctx.is_prevented_default():
yield entities.StageProcessResult(
result_type=entities.ResultType.INTERRUPT,
@@ -125,12 +124,12 @@ class ResponseWrapper(stage.PipelineStage):
else:
if event_ctx.event.reply is not None:
query.resp_message_chain.append(
platform_message.MessageChain(text=event_ctx.event.reply)
platform_message.MessageChain(event_ctx.event.reply)
)
else:
query.resp_message_chain.append(
platform_message.MessageChain([platform_message.Plain(text=reply_text)])
platform_message.MessageChain([platform_message.Plain(reply_text)])
)
yield entities.StageProcessResult(

190
pkg/platform/adapter.py Normal file
View File

@@ -0,0 +1,190 @@
from __future__ import annotations
# MessageSource的适配器
import typing
import abc
from ..core import app
from .types import message as platform_message
from .types import events as platform_events
from .logger import EventLogger
class MessagePlatformAdapter(metaclass=abc.ABCMeta):
"""消息平台适配器基类"""
name: str
bot_account_id: int
"""机器人账号ID需要在初始化时设置"""
config: dict
ap: app.Application
logger: EventLogger
def __init__(self, config: dict, ap: app.Application, logger: EventLogger):
"""初始化适配器
Args:
config (dict): 对应的配置
ap (app.Application): 应用上下文
"""
self.config = config
self.ap = ap
self.logger = logger
async def send_message(self, target_type: str, target_id: str, message: platform_message.MessageChain):
"""主动发送消息
Args:
target_type (str): 目标类型,`person`或`group`
target_id (str): 目标ID
message (platform.types.MessageChain): 消息链
"""
raise NotImplementedError
async def reply_message(
self,
message_source: platform_events.MessageEvent,
message: platform_message.MessageChain,
quote_origin: bool = False,
):
"""回复消息
Args:
message_source (platform.types.MessageEvent): 消息源事件
message (platform.types.MessageChain): 消息链
quote_origin (bool, optional): 是否引用原消息. Defaults to False.
"""
raise NotImplementedError
async def reply_message_chunk(
self,
message_source: platform_events.MessageEvent,
bot_message: dict,
message: platform_message.MessageChain,
quote_origin: bool = False,
is_final: bool = False,
):
"""回复消息(流式输出)
Args:
message_source (platform.types.MessageEvent): 消息源事件
message_id (int): 消息ID
message (platform.types.MessageChain): 消息链
quote_origin (bool, optional): 是否引用原消息. Defaults to False.
is_final (bool, optional): 流式是否结束. Defaults to False.
"""
raise NotImplementedError
async def create_message_card(self, message_id: typing.Type[str, int], event: platform_events.MessageEvent) -> bool:
"""创建卡片消息
Args:
message_id (str): 消息ID
event (platform_events.MessageEvent): 消息源事件
"""
return False
async def is_muted(self, group_id: int) -> bool:
"""获取账号是否在指定群被禁言"""
raise NotImplementedError
def register_listener(
self,
event_type: typing.Type[platform_events.Event],
callback: typing.Callable[[platform_events.Event, MessagePlatformAdapter], None],
):
"""注册事件监听器
Args:
event_type (typing.Type[platform.types.Event]): 事件类型
callback (typing.Callable[[platform.types.Event], None]): 回调函数,接收一个参数,为事件
"""
raise NotImplementedError
def unregister_listener(
self,
event_type: typing.Type[platform_events.Event],
callback: typing.Callable[[platform_events.Event, MessagePlatformAdapter], None],
):
"""注销事件监听器
Args:
event_type (typing.Type[platform.types.Event]): 事件类型
callback (typing.Callable[[platform.types.Event], None]): 回调函数,接收一个参数,为事件
"""
raise NotImplementedError
async def run_async(self):
"""异步运行"""
raise NotImplementedError
async def is_stream_output_supported(self) -> bool:
"""是否支持流式输出"""
return False
async def kill(self) -> bool:
"""关闭适配器
Returns:
bool: 是否成功关闭热重载时若此函数返回False则不会重载MessageSource底层
"""
raise NotImplementedError
class MessageConverter:
"""消息链转换器基类"""
@staticmethod
def yiri2target(message_chain: platform_message.MessageChain):
"""将源平台消息链转换为目标平台消息链
Args:
message_chain (platform.types.MessageChain): 源平台消息链
Returns:
typing.Any: 目标平台消息链
"""
raise NotImplementedError
@staticmethod
def target2yiri(message_chain: typing.Any) -> platform_message.MessageChain:
"""将目标平台消息链转换为源平台消息链
Args:
message_chain (typing.Any): 目标平台消息链
Returns:
platform.types.MessageChain: 源平台消息链
"""
raise NotImplementedError
class EventConverter:
"""事件转换器基类"""
@staticmethod
def yiri2target(event: typing.Type[platform_events.Event]):
"""将源平台事件转换为目标平台事件
Args:
event (typing.Type[platform.types.Event]): 源平台事件
Returns:
typing.Any: 目标平台事件
"""
raise NotImplementedError
@staticmethod
def target2yiri(event: typing.Any) -> platform_events.Event:
"""将目标平台事件的调用参数转换为源平台的事件参数对象
Args:
event (typing.Any): 目标平台事件
Returns:
typing.Type[platform.types.Event]: 源平台事件
"""
raise NotImplementedError

14
pkg/platform/adapter.yaml Normal file
View File

@@ -0,0 +1,14 @@
apiVersion: v1
kind: ComponentTemplate
metadata:
name: MessagePlatformAdapter
label:
en_US: Message Platform Adapter
zh_Hans: 消息平台适配器模板类
spec:
type:
- python
execution:
python:
path: ./adapter.py
attr: MessagePlatformAdapter

View File

@@ -1,10 +1,15 @@
from __future__ import annotations
import sys
import asyncio
import traceback
import sqlalchemy
# FriendMessage, Image, MessageChain, Plain
from . import adapter as msadapter
from ..core import app, entities as core_entities, taskmgr
from .types import events as platform_events, message as platform_message
from ..discover import engine
@@ -14,10 +19,10 @@ from ..entity.errors import platform as platform_errors
from .logger import EventLogger
import langbot_plugin.api.entities.builtin.provider.session as provider_session
import langbot_plugin.api.entities.builtin.platform.events as platform_events
import langbot_plugin.api.entities.builtin.platform.message as platform_message
import langbot_plugin.api.definition.abstract.platform.adapter as abstract_platform_adapter
# 处理 3.4 移除了 YiriMirai 之后,插件的兼容性问题
from . import types as mirai
sys.modules['mirai'] = mirai
class RuntimeBot:
@@ -29,7 +34,7 @@ class RuntimeBot:
enable: bool
adapter: abstract_platform_adapter.AbstractMessagePlatformAdapter
adapter: msadapter.MessagePlatformAdapter
task_wrapper: taskmgr.TaskWrapper
@@ -41,7 +46,7 @@ class RuntimeBot:
self,
ap: app.Application,
bot_entity: persistence_bot.Bot,
adapter: abstract_platform_adapter.AbstractMessagePlatformAdapter,
adapter: msadapter.MessagePlatformAdapter,
logger: EventLogger,
):
self.ap = ap
@@ -54,7 +59,7 @@ class RuntimeBot:
async def initialize(self):
async def on_friend_message(
event: platform_events.FriendMessage,
adapter: abstract_platform_adapter.AbstractMessagePlatformAdapter,
adapter: msadapter.MessagePlatformAdapter,
):
image_components = [
component for component in event.message_chain if isinstance(component, platform_message.Image)
@@ -68,7 +73,7 @@ class RuntimeBot:
await self.ap.query_pool.add_query(
bot_uuid=self.bot_entity.uuid,
launcher_type=provider_session.LauncherTypes.PERSON,
launcher_type=core_entities.LauncherTypes.PERSON,
launcher_id=event.sender.id,
sender_id=event.sender.id,
message_event=event,
@@ -79,7 +84,7 @@ class RuntimeBot:
async def on_group_message(
event: platform_events.GroupMessage,
adapter: abstract_platform_adapter.AbstractMessagePlatformAdapter,
adapter: msadapter.MessagePlatformAdapter,
):
image_components = [
component for component in event.message_chain if isinstance(component, platform_message.Image)
@@ -93,7 +98,7 @@ class RuntimeBot:
await self.ap.query_pool.add_query(
bot_uuid=self.bot_entity.uuid,
launcher_type=provider_session.LauncherTypes.GROUP,
launcher_type=core_entities.LauncherTypes.GROUP,
launcher_id=event.group.id,
sender_id=event.sender.id,
message_event=event,
@@ -148,7 +153,7 @@ class PlatformManager:
adapter_components: list[engine.Component]
adapter_dict: dict[str, type[abstract_platform_adapter.AbstractMessagePlatformAdapter]]
adapter_dict: dict[str, type[msadapter.MessagePlatformAdapter]]
def __init__(self, ap: app.Application = None):
self.ap = ap
@@ -158,7 +163,7 @@ class PlatformManager:
async def initialize(self):
self.adapter_components = self.ap.discover.get_components_by_kind('MessagePlatformAdapter')
adapter_dict: dict[str, type[abstract_platform_adapter.AbstractMessagePlatformAdapter]] = {}
adapter_dict: dict[str, type[msadapter.MessagePlatformAdapter]] = {}
for component in self.adapter_components:
adapter_dict[component.metadata.name] = component.get_python_component_class()
self.adapter_dict = adapter_dict
@@ -169,9 +174,8 @@ class PlatformManager:
webchat_logger = EventLogger(name='webchat-adapter', ap=self.ap)
webchat_adapter_inst = webchat_adapter_class(
{},
self.ap,
webchat_logger,
ap=self.ap,
is_stream=False,
)
self.webchat_proxy_bot = RuntimeBot(
@@ -191,7 +195,7 @@ class PlatformManager:
await self.load_bots_from_db()
def get_running_adapters(self) -> list[abstract_platform_adapter.AbstractMessagePlatformAdapter]:
def get_running_adapters(self) -> list[msadapter.MessagePlatformAdapter]:
return [bot.adapter for bot in self.bots if bot.enable]
async def load_bots_from_db(self):
@@ -229,6 +233,7 @@ class PlatformManager:
adapter_inst = self.adapter_dict[bot_entity.adapter](
bot_entity.adapter_config,
self.ap,
logger,
)
@@ -271,6 +276,43 @@ class PlatformManager:
return component
return None
async def write_back_config(
self,
adapter_name: str,
adapter_inst: msadapter.MessagePlatformAdapter,
config: dict,
):
# index = -2
# for i, adapter in enumerate(self.adapters):
# if adapter == adapter_inst:
# index = i
# break
# if index == -2:
# raise Exception('平台适配器未找到')
# # 只修改启用的适配器
# real_index = -1
# for i, adapter in enumerate(self.ap.platform_cfg.data['platform-adapters']):
# if adapter['enable']:
# index -= 1
# if index == -1:
# real_index = i
# break
# new_cfg = {
# 'adapter': adapter_name,
# 'enable': True,
# **config
# }
# self.ap.platform_cfg.data['platform-adapters'][real_index] = new_cfg
# await self.ap.platform_cfg.dump_config()
# TODO implement this
pass
async def run(self):
# This method will only be called when the application launching
await self.webchat_proxy_bot.run()

View File

@@ -9,8 +9,7 @@ import traceback
import uuid
from ..core import app
import langbot_plugin.api.entities.builtin.platform.message as platform_message
import langbot_plugin.api.definition.abstract.platform.event_logger as abstract_platform_event_logger
from .types import message as platform_message
class EventLogLevel(enum.Enum):
@@ -56,7 +55,7 @@ MAX_LOG_COUNT = 200
DELETE_COUNT_PER_TIME = 50
class EventLogger(abstract_platform_event_logger.AbstractEventLogger):
class EventLogger:
"""used for logging bot events"""
ap: app.Application

View File

@@ -5,17 +5,17 @@ import traceback
import datetime
import aiocqhttp
import pydantic
import langbot_plugin.api.definition.abstract.platform.adapter as abstract_platform_adapter
import langbot_plugin.api.entities.builtin.platform.message as platform_message
import langbot_plugin.api.entities.builtin.platform.events as platform_events
import langbot_plugin.api.entities.builtin.platform.entities as platform_entities
from .. import adapter
from ...core import app
from ..types import message as platform_message
from ..types import events as platform_events
from ..types import entities as platform_entities
from ...utils import image
import langbot_plugin.api.definition.abstract.platform.event_logger as abstract_platform_logger
from ..logger import EventLogger
class AiocqhttpMessageConverter(abstract_platform_adapter.AbstractMessageConverter):
class AiocqhttpMessageConverter(adapter.MessageConverter):
@staticmethod
async def yiri2target(
message_chain: platform_message.MessageChain,
@@ -266,21 +266,20 @@ class AiocqhttpMessageConverter(abstract_platform_adapter.AbstractMessageConvert
await process_message_data(msg_data, reply_list)
reply_msg = platform_message.Quote(
message_id=msg.data['id'], sender_id=msg_datas['user_id'], origin=reply_list
message_id=msg.data['id'], sender_id=msg_datas['sender']['user_id'], origin=reply_list
)
yiri_msg_list.append(reply_msg)
elif msg.type == 'file':
pass
# file_name = msg.data['file']
# file_id = msg.data['file_id']
# file_data = await bot.get_file(file_id=file_id)
# file_name = file_data.get('file_name')
# file_path = file_data.get('file')
# _ = file_path
# file_url = file_data.get('file_url')
# file_size = file_data.get('file_size')
# yiri_msg_list.append(platform_message.File(id=file_id, name=file_name,url=file_url,size=file_size))
# 这里下载所有文件会导致下载文件过多,暂时不下载
# elif msg.type == 'file':
# # file_name = msg.data['file']
# file_id = msg.data['file_id']
# file_data = await bot.get_file(file_id=file_id)
# file_name = file_data.get('file_name')
# file_path = file_data.get('file')
# file_url = file_data.get('file_url')
# file_size = file_data.get('file_size')
# yiri_msg_list.append(platform_message.File(id=file_id, name=file_name,url=file_url,size=file_size))
elif msg.type == 'face':
face_id = msg.data['id']
face_name = msg.data['raw']['faceText']
@@ -299,7 +298,7 @@ class AiocqhttpMessageConverter(abstract_platform_adapter.AbstractMessageConvert
return chain
class AiocqhttpEventConverter(abstract_platform_adapter.AbstractEventConverter):
class AiocqhttpEventConverter(adapter.EventConverter):
@staticmethod
async def yiri2target(event: platform_events.MessageEvent, bot_account_id: int):
return event.source_platform_object
@@ -349,19 +348,23 @@ class AiocqhttpEventConverter(abstract_platform_adapter.AbstractEventConverter):
)
class AiocqhttpAdapter(abstract_platform_adapter.AbstractMessagePlatformAdapter):
bot: aiocqhttp.CQHttp = pydantic.Field(exclude=True, default_factory=aiocqhttp.CQHttp)
class AiocqhttpAdapter(adapter.MessagePlatformAdapter):
bot: aiocqhttp.CQHttp
bot_account_id: int
message_converter: AiocqhttpMessageConverter = AiocqhttpMessageConverter()
event_converter: AiocqhttpEventConverter = AiocqhttpEventConverter()
config: dict
ap: app.Application
on_websocket_connection_event_cache: typing.List[typing.Callable[[aiocqhttp.Event], None]] = []
def __init__(self, config: dict, logger: abstract_platform_logger.AbstractEventLogger):
super().__init__(
config=config,
logger=logger,
)
def __init__(self, config: dict, ap: app.Application, logger: EventLogger):
self.config = config
self.logger = logger
async def shutdown_trigger_placeholder():
while True:
@@ -369,6 +372,7 @@ class AiocqhttpAdapter(abstract_platform_adapter.AbstractMessagePlatformAdapter)
self.config['shutdown_trigger'] = shutdown_trigger_placeholder
self.ap = ap
self.on_websocket_connection_event_cache = []
if 'access-token' in config:
@@ -404,9 +408,7 @@ class AiocqhttpAdapter(abstract_platform_adapter.AbstractMessagePlatformAdapter)
def register_listener(
self,
event_type: typing.Type[platform_events.Event],
callback: typing.Callable[
[platform_events.Event, abstract_platform_adapter.AbstractMessagePlatformAdapter], None
],
callback: typing.Callable[[platform_events.Event, adapter.MessagePlatformAdapter], None],
):
async def on_message(event: aiocqhttp.Event):
self.bot_account_id = event.self_id
@@ -437,9 +439,7 @@ class AiocqhttpAdapter(abstract_platform_adapter.AbstractMessagePlatformAdapter)
def unregister_listener(
self,
event_type: typing.Type[platform_events.Event],
callback: typing.Callable[
[platform_events.Event, abstract_platform_adapter.AbstractMessagePlatformAdapter], None
],
callback: typing.Callable[[platform_events.Event, adapter.MessagePlatformAdapter], None],
):
return super().unregister_listener(event_type, callback)

View File

@@ -1,16 +1,19 @@
from re import S
import traceback
import typing
from libs.dingtalk_api.dingtalkevent import DingTalkEvent
import langbot_plugin.api.entities.builtin.platform.message as platform_message
import langbot_plugin.api.definition.abstract.platform.adapter as abstract_platform_adapter
import langbot_plugin.api.entities.builtin.platform.events as platform_events
import langbot_plugin.api.entities.builtin.platform.entities as platform_entities
from pkg.platform.types import message as platform_message
from pkg.platform.adapter import MessagePlatformAdapter
from .. import adapter
from ...core import app
from ..types import events as platform_events
from ..types import entities as platform_entities
from libs.dingtalk_api.api import DingTalkClient
import datetime
from ..logger import EventLogger
class DingTalkMessageConverter(abstract_platform_adapter.AbstractMessageConverter):
class DingTalkMessageConverter(adapter.MessageConverter):
@staticmethod
async def yiri2target(message_chain: platform_message.MessageChain):
content = ''
@@ -20,6 +23,9 @@ class DingTalkMessageConverter(abstract_platform_adapter.AbstractMessageConverte
at = True
if type(msg) is platform_message.Plain:
content += msg.text
if type(msg) is platform_message.Forward:
for node in msg.node_list:
content += (await DingTalkMessageConverter.yiri2target(node.message_chain))[0]
return content, at
@staticmethod
@@ -46,7 +52,7 @@ class DingTalkMessageConverter(abstract_platform_adapter.AbstractMessageConverte
return chain
class DingTalkEventConverter(abstract_platform_adapter.AbstractEventConverter):
class DingTalkEventConverter(adapter.EventConverter):
@staticmethod
async def yiri2target(event: platform_events.MessageEvent):
return event.source_platform_object
@@ -58,7 +64,7 @@ class DingTalkEventConverter(abstract_platform_adapter.AbstractEventConverter):
if event.conversation == 'FriendMessage':
return platform_events.FriendMessage(
sender=platform_entities.Friend(
id=event.incoming_message.sender_id,
id=event.incoming_message.sender_staff_id,
nickname=event.incoming_message.sender_nick,
remark='',
),
@@ -68,7 +74,7 @@ class DingTalkEventConverter(abstract_platform_adapter.AbstractEventConverter):
)
elif event.conversation == 'GroupMessage':
sender = platform_entities.GroupMember(
id=event.incoming_message.sender_id,
id=event.incoming_message.sender_staff_id,
member_name=event.incoming_message.sender_nick,
permission='MEMBER',
group=platform_entities.Group(
@@ -90,19 +96,19 @@ class DingTalkEventConverter(abstract_platform_adapter.AbstractEventConverter):
)
class DingTalkAdapter(abstract_platform_adapter.AbstractMessagePlatformAdapter):
class DingTalkAdapter(adapter.MessagePlatformAdapter):
bot: DingTalkClient
ap: app.Application
bot_account_id: str
message_converter: DingTalkMessageConverter = DingTalkMessageConverter()
event_converter: DingTalkEventConverter = DingTalkEventConverter()
config: dict
card_instance_id_dict: (
dict # 回复卡片消息字典key为消息idvalue为回复卡片实例id用于在流式消息时判断是否发送到指定卡片
)
card_instance_id_dict: dict # 回复卡片消息字典key为消息idvalue为回复卡片实例id用于在流式消息时判断是否发送到指定卡片
seq: int # 消息顺序直接以seq作为标识
def __init__(self, config: dict, logger: EventLogger):
def __init__(self, config: dict, ap: app.Application, logger: EventLogger):
self.config = config
self.ap = ap
self.logger = logger
self.card_instance_id_dict = {}
# self.seq = 1
@@ -159,11 +165,15 @@ class DingTalkAdapter(abstract_platform_adapter.AbstractMessagePlatformAdapter):
msg_seq = bot_message.msg_sequence
if (msg_seq - 1) % 8 == 0 or is_final:
content, at = await DingTalkMessageConverter.yiri2target(message)
card_instance, card_instance_id = self.card_instance_id_dict[message_id]
if not content and bot_message.content:
content = bot_message.content # 兼容直接传入content的情况
# print(card_instance_id)
await self.bot.send_card_message(card_instance, card_instance_id, content, is_final)
if content:
await self.bot.send_card_message(card_instance, card_instance_id, content, is_final)
if is_final and bot_message.tool_calls is None:
# self.seq = 1 # 消息回复结束之后重置seq
self.card_instance_id_dict.pop(message_id) # 消息回复结束之后删除卡片实例id
@@ -192,9 +202,7 @@ class DingTalkAdapter(abstract_platform_adapter.AbstractMessagePlatformAdapter):
def register_listener(
self,
event_type: typing.Type[platform_events.Event],
callback: typing.Callable[
[platform_events.Event, abstract_platform_adapter.AbstractMessagePlatformAdapter], None
],
callback: typing.Callable[[platform_events.Event, adapter.MessagePlatformAdapter], None],
):
async def on_message(event: DingTalkEvent):
try:
@@ -219,8 +227,6 @@ class DingTalkAdapter(abstract_platform_adapter.AbstractMessagePlatformAdapter):
async def unregister_listener(
self,
event_type: type,
callback: typing.Callable[
[platform_events.Event, abstract_platform_adapter.AbstractMessagePlatformAdapter], None
],
callback: typing.Callable[[platform_events.Event, MessagePlatformAdapter], None],
):
return super().unregister_listener(event_type, callback)

View File

@@ -12,14 +12,13 @@ import asyncio
from enum import Enum
import aiohttp
import pydantic
import langbot_plugin.api.definition.abstract.platform.adapter as abstract_platform_adapter
import langbot_plugin.api.entities.builtin.platform.message as platform_message
import langbot_plugin.api.entities.builtin.platform.events as platform_events
import langbot_plugin.api.entities.builtin.platform.entities as platform_entities
import langbot_plugin.api.definition.abstract.platform.event_logger as abstract_platform_logger
from .. import adapter
from ...core import app
from ..logger import EventLogger
from ..types import message as platform_message
from ..types import events as platform_events
from ..types import entities as platform_entities
# 语音功能相关异常定义
@@ -583,7 +582,7 @@ class VoiceConnectionManager:
await self.stop_monitoring()
class DiscordMessageConverter(abstract_platform_adapter.AbstractMessageConverter):
class DiscordMessageConverter(adapter.MessageConverter):
@staticmethod
async def yiri2target(
message_chain: platform_message.MessageChain,
@@ -737,7 +736,7 @@ class DiscordMessageConverter(abstract_platform_adapter.AbstractMessageConverter
return platform_message.MessageChain(element_list)
class DiscordEventConverter(abstract_platform_adapter.AbstractEventConverter):
class DiscordEventConverter(adapter.EventConverter):
@staticmethod
async def yiri2target(event: platform_events.Event) -> discord.Message:
pass
@@ -779,26 +778,32 @@ class DiscordEventConverter(abstract_platform_adapter.AbstractEventConverter):
)
class DiscordAdapter(abstract_platform_adapter.AbstractMessagePlatformAdapter):
bot: discord.Client = pydantic.Field(exclude=True)
class DiscordAdapter(adapter.MessagePlatformAdapter):
bot: discord.Client
bot_account_id: str # 用于在流水线中识别at是否是本bot直接以bot_name作为标识
config: dict
ap: app.Application
message_converter: DiscordMessageConverter = DiscordMessageConverter()
event_converter: DiscordEventConverter = DiscordEventConverter()
listeners: typing.Dict[
typing.Type[platform_events.Event],
typing.Callable[[platform_events.Event, abstract_platform_adapter.AbstractMessagePlatformAdapter], None],
typing.Callable[[platform_events.Event, adapter.MessagePlatformAdapter], None],
] = {}
voice_manager: VoiceConnectionManager | None = pydantic.Field(exclude=True, default=None)
def __init__(self, config: dict, ap: app.Application, logger: EventLogger):
self.config = config
self.ap = ap
self.logger = logger
def __init__(self, config: dict, logger: abstract_platform_logger.AbstractEventLogger, **kwargs):
bot_account_id = config['client_id']
listeners = {}
self.bot_account_id = self.config['client_id']
# 初始化语音连接管理器
# self.voice_manager: VoiceConnectionManager = None
self.voice_manager: VoiceConnectionManager = None
adapter_self = self
@@ -818,17 +823,7 @@ class DiscordAdapter(abstract_platform_adapter.AbstractMessagePlatformAdapter):
if os.getenv('http_proxy'):
args['proxy'] = os.getenv('http_proxy')
bot = MyClient(intents=intents, **args)
super().__init__(
config=config,
logger=logger,
bot_account_id=bot_account_id,
listeners=listeners,
bot=bot,
voice_manager=None,
**kwargs,
)
self.bot = MyClient(intents=intents, **args)
# Voice functionality methods
async def join_voice_channel(self, guild_id: int, channel_id: int, user_id: int = None) -> discord.VoiceClient:
@@ -1034,14 +1029,7 @@ class DiscordAdapter(abstract_platform_adapter.AbstractMessagePlatformAdapter):
if quote_origin:
args['reference'] = message_source.source_platform_object
has_at = False
for component in message.root:
if isinstance(component, platform_message.At):
has_at = True
break
if has_at:
if message.has(platform_message.At):
args['mention_author'] = True
await message_source.source_platform_object.channel.send(**args)
@@ -1052,18 +1040,14 @@ class DiscordAdapter(abstract_platform_adapter.AbstractMessagePlatformAdapter):
def register_listener(
self,
event_type: typing.Type[platform_events.Event],
callback: typing.Callable[
[platform_events.Event, abstract_platform_adapter.AbstractMessagePlatformAdapter], None
],
callback: typing.Callable[[platform_events.Event, adapter.MessagePlatformAdapter], None],
):
self.listeners[event_type] = callback
def unregister_listener(
self,
event_type: typing.Type[platform_events.Event],
callback: typing.Callable[
[platform_events.Event, abstract_platform_adapter.AbstractMessagePlatformAdapter], None
],
callback: typing.Callable[[platform_events.Event, adapter.MessagePlatformAdapter], None],
):
self.listeners.pop(event_type)

View File

Before

Width:  |  Height:  |  Size: 25 KiB

After

Width:  |  Height:  |  Size: 25 KiB

View File

@@ -11,19 +11,19 @@ import threading
import quart
import aiohttp
import langbot_plugin.api.definition.abstract.platform.adapter as abstract_platform_adapter
from ....core import app
import langbot_plugin.api.entities.builtin.platform.message as platform_message
import langbot_plugin.api.entities.builtin.platform.events as platform_events
import langbot_plugin.api.entities.builtin.platform.entities as platform_entities
from ....utils import image
from .. import adapter
from ...core import app
from ..types import message as platform_message
from ..types import events as platform_events
from ..types import entities as platform_entities
from ...utils import image
import xml.etree.ElementTree as ET
from typing import Optional, Tuple
from functools import partial
from ...logger import EventLogger
from ..logger import EventLogger
class GewechatMessageConverter(abstract_platform_adapter.AbstractMessageConverter):
class GewechatMessageConverter(adapter.MessageConverter):
def __init__(self, config: dict):
self.config = config
@@ -398,7 +398,7 @@ class GewechatMessageConverter(abstract_platform_adapter.AbstractMessageConverte
return from_user_name.endswith('@chatroom')
class GewechatEventConverter(abstract_platform_adapter.AbstractEventConverter):
class GewechatEventConverter(adapter.EventConverter):
def __init__(self, config: dict):
self.config = config
self.message_converter = GewechatMessageConverter(config)
@@ -458,7 +458,7 @@ class GewechatEventConverter(abstract_platform_adapter.AbstractEventConverter):
)
class GeWeChatAdapter(abstract_platform_adapter.AbstractMessagePlatformAdapter):
class GeWeChatAdapter(adapter.MessagePlatformAdapter):
name: str = 'gewechat' # 定义适配器名称
bot: gewechat_client.GewechatClient
@@ -475,7 +475,7 @@ class GeWeChatAdapter(abstract_platform_adapter.AbstractMessagePlatformAdapter):
listeners: typing.Dict[
typing.Type[platform_events.Event],
typing.Callable[[platform_events.Event, abstract_platform_adapter.AbstractMessagePlatformAdapter], None],
typing.Callable[[platform_events.Event, adapter.MessagePlatformAdapter], None],
] = {}
def __init__(self, config: dict, ap: app.Application, logger: EventLogger):
@@ -491,7 +491,7 @@ class GeWeChatAdapter(abstract_platform_adapter.AbstractMessagePlatformAdapter):
async def gewechat_callback():
data = await quart.request.json
# print(json.dumps(data, indent=4, ensure_ascii=False))
await self.logger.debug(f'Gewechat callback event: {data}')
self.ap.logger.debug(f'Gewechat callback event: {data}')
if 'data' in data:
data['Data'] = data['data']
@@ -601,7 +601,7 @@ class GeWeChatAdapter(abstract_platform_adapter.AbstractMessagePlatformAdapter):
if handler := handler_map.get(msg['type']):
handler(msg)
else:
await self.logger.warning(f'未处理的消息类型: {msg["type"]}')
self.ap.logger.warning(f'未处理的消息类型: {msg["type"]}')
continue
async def send_message(self, target_type: str, target_id: str, message: platform_message.MessageChain):
@@ -625,18 +625,14 @@ class GeWeChatAdapter(abstract_platform_adapter.AbstractMessagePlatformAdapter):
def register_listener(
self,
event_type: typing.Type[platform_events.Event],
callback: typing.Callable[
[platform_events.Event, abstract_platform_adapter.AbstractMessagePlatformAdapter], None
],
callback: typing.Callable[[platform_events.Event, adapter.MessagePlatformAdapter], None],
):
self.listeners[event_type] = callback
def unregister_listener(
self,
event_type: typing.Type[platform_events.Event],
callback: typing.Callable[
[platform_events.Event, abstract_platform_adapter.AbstractMessagePlatformAdapter], None
],
callback: typing.Callable[[platform_events.Event, adapter.MessagePlatformAdapter], None],
):
pass
@@ -660,7 +656,9 @@ class GeWeChatAdapter(abstract_platform_adapter.AbstractMessagePlatformAdapter):
self.config['app_id'] = app_id
print(f'Gewechat 登录成功app_id: {app_id}')
self.ap.logger.info(f'Gewechat 登录成功app_id: {app_id}')
self.ap.platform_mgr.write_back_config('gewechat', self, self.config)
# 获取 nickname
profile = self.bot.get_profile(self.config['app_id'])

View File

@@ -17,14 +17,14 @@ import aiohttp
import lark_oapi.ws.exception
import quart
from lark_oapi.api.im.v1 import *
import pydantic
from lark_oapi.api.cardkit.v1 import *
import langbot_plugin.api.definition.abstract.platform.adapter as abstract_platform_adapter
import langbot_plugin.api.entities.builtin.platform.message as platform_message
import langbot_plugin.api.entities.builtin.platform.events as platform_events
import langbot_plugin.api.entities.builtin.platform.entities as platform_entities
import langbot_plugin.api.definition.abstract.platform.event_logger as abstract_platform_logger
from .. import adapter
from ...core import app
from ..types import message as platform_message
from ..types import events as platform_events
from ..types import entities as platform_entities
from ..logger import EventLogger
class AESCipher(object):
@@ -53,7 +53,7 @@ class AESCipher(object):
return self.decrypt(enc).decode('utf8')
class LarkMessageConverter(abstract_platform_adapter.AbstractMessageConverter):
class LarkMessageConverter(adapter.MessageConverter):
@staticmethod
async def yiri2target(
message_chain: platform_message.MessageChain, api_client: lark_oapi.Client
@@ -277,7 +277,7 @@ class LarkMessageConverter(abstract_platform_adapter.AbstractMessageConverter):
return platform_message.MessageChain(lb_msg_list)
class LarkEventConverter(abstract_platform_adapter.AbstractEventConverter):
class LarkEventConverter(adapter.EventConverter):
@staticmethod
async def yiri2target(
event: platform_events.MessageEvent,
@@ -325,37 +325,49 @@ CARD_ID_CACHE_SIZE = 500
CARD_ID_CACHE_MAX_LIFETIME = 20 * 60 # 20分钟
class LarkAdapter(abstract_platform_adapter.AbstractMessagePlatformAdapter):
bot: lark_oapi.ws.Client = pydantic.Field(exclude=True)
api_client: lark_oapi.Client = pydantic.Field(exclude=True)
class LarkAdapter(adapter.MessagePlatformAdapter):
bot: lark_oapi.ws.Client
api_client: lark_oapi.Client
bot_account_id: str # 用于在流水线中识别at是否是本bot直接以bot_name作为标识
lark_tenant_key: str = pydantic.Field(exclude=True, default='') # 飞书企业key
lark_tenant_key: str # 飞书企业key
message_converter: LarkMessageConverter = LarkMessageConverter()
event_converter: LarkEventConverter = LarkEventConverter()
listeners: typing.Dict[
typing.Type[platform_events.Event],
typing.Callable[[platform_events.Event, abstract_platform_adapter.AbstractMessagePlatformAdapter], None],
typing.Callable[[platform_events.Event, adapter.MessagePlatformAdapter], None],
]
quart_app: quart.Quart = pydantic.Field(exclude=True)
config: dict
quart_app: quart.Quart
ap: app.Application
card_id_dict: dict[str, str] # 消息id到卡片id的映射便于创建卡片后的发送消息到指定卡片
seq: int # 用于在发送卡片消息中识别消息顺序直接以seq作为标识
def __init__(self, config: dict, logger: abstract_platform_logger.AbstractEventLogger, **kwargs):
quart_app = quart.Quart(__name__)
def __init__(self, config: dict, ap: app.Application, logger: EventLogger):
self.config = config
self.ap = ap
self.logger = logger
self.quart_app = quart.Quart(__name__)
self.listeners = {}
self.card_id_dict = {}
self.seq = 1
@quart_app.route('/lark/callback', methods=['POST'])
@self.quart_app.route('/lark/callback', methods=['POST'])
async def lark_callback():
try:
data = await quart.request.json
self.ap.logger.debug(f'Lark callback event: {data}')
if 'encrypt' in data:
cipher = AESCipher(config['encrypt-key'])
cipher = AESCipher(self.config['encrypt-key'])
data = cipher.decrypt_string(data['encrypt'])
data = json.loads(data)
@@ -402,24 +414,10 @@ class LarkAdapter(abstract_platform_adapter.AbstractMessagePlatformAdapter):
lark_oapi.EventDispatcherHandler.builder('', '').register_p2_im_message_receive_v1(sync_on_message).build()
)
bot_account_id = config['bot_name']
self.bot_account_id = config['bot_name']
bot = lark_oapi.ws.Client(config['app_id'], config['app_secret'], event_handler=event_handler)
api_client = lark_oapi.Client.builder().app_id(config['app_id']).app_secret(config['app_secret']).build()
super().__init__(
config=config,
logger=logger,
lark_tenant_key=config.get('lark_tenant_key', ''),
card_id_dict={},
seq=1,
listeners={},
quart_app=quart_app,
bot=bot,
api_client=api_client,
bot_account_id=bot_account_id,
**kwargs,
)
self.bot = lark_oapi.ws.Client(config['app_id'], config['app_secret'], event_handler=event_handler)
self.api_client = lark_oapi.Client.builder().app_id(config['app_id']).app_secret(config['app_secret']).build()
async def send_message(self, target_type: str, target_id: str, message: platform_message.MessageChain):
pass
@@ -432,177 +430,151 @@ class LarkAdapter(abstract_platform_adapter.AbstractMessagePlatformAdapter):
async def create_card_id(self, message_id):
try:
# self.logger.debug('飞书支持stream输出,创建卡片......')
self.ap.logger.debug('飞书支持stream输出,创建卡片......')
card_data = {
'schema': '2.0',
'config': {
'update_multi': True,
'streaming_mode': True,
'streaming_config': {
'print_step': {'default': 1},
'print_frequency_ms': {'default': 70},
'print_strategy': 'fast',
},
},
'body': {
'direction': 'vertical',
'padding': '12px 12px 12px 12px',
'elements': [
{
'tag': 'div',
'text': {
'tag': 'plain_text',
'content': 'LangBot',
'text_size': 'normal',
'text_align': 'left',
'text_color': 'default',
},
'icon': {
'tag': 'custom_icon',
'img_key': 'img_v3_02p3_05c65d5d-9bad-440a-a2fb-c89571bfd5bg',
},
},
{
'tag': 'markdown',
'content': '',
'text_align': 'left',
'text_size': 'normal',
'margin': '0px 0px 0px 0px',
'element_id': 'streaming_txt',
},
{
'tag': 'markdown',
'content': '',
'text_align': 'left',
'text_size': 'normal',
'margin': '0px 0px 0px 0px',
},
{
'tag': 'column_set',
'horizontal_spacing': '8px',
'horizontal_align': 'left',
'columns': [
{
'tag': 'column',
'width': 'weighted',
'elements': [
{
'tag': 'markdown',
'content': '',
'text_align': 'left',
'text_size': 'normal',
'margin': '0px 0px 0px 0px',
},
{
'tag': 'markdown',
'content': '',
'text_align': 'left',
'text_size': 'normal',
'margin': '0px 0px 0px 0px',
},
{
'tag': 'markdown',
'content': '',
'text_align': 'left',
'text_size': 'normal',
'margin': '0px 0px 0px 0px',
},
],
'padding': '0px 0px 0px 0px',
'direction': 'vertical',
'horizontal_spacing': '8px',
'vertical_spacing': '2px',
'horizontal_align': 'left',
'vertical_align': 'top',
'margin': '0px 0px 0px 0px',
'weight': 1,
}
],
'margin': '0px 0px 0px 0px',
},
{'tag': 'hr', 'margin': '0px 0px 0px 0px'},
{
'tag': 'column_set',
'horizontal_spacing': '12px',
'horizontal_align': 'right',
'columns': [
{
'tag': 'column',
'width': 'weighted',
'elements': [
{
'tag': 'markdown',
'content': '<font color="grey-600">以上内容由 AI 生成,仅供参考。更多详细、准确信息可点击引用链接查看</font>',
'text_align': 'left',
'text_size': 'notation',
'margin': '4px 0px 0px 0px',
'icon': {
'tag': 'standard_icon',
'token': 'robot_outlined',
'color': 'grey',
},
}
],
'padding': '0px 0px 0px 0px',
'direction': 'vertical',
'horizontal_spacing': '8px',
'vertical_spacing': '8px',
'horizontal_align': 'left',
'vertical_align': 'top',
'margin': '0px 0px 0px 0px',
'weight': 1,
},
{
'tag': 'column',
'width': '20px',
'elements': [
{
'tag': 'button',
'text': {'tag': 'plain_text', 'content': ''},
'type': 'text',
'width': 'fill',
'size': 'medium',
'icon': {'tag': 'standard_icon', 'token': 'thumbsup_outlined'},
'hover_tips': {'tag': 'plain_text', 'content': '有帮助'},
'margin': '0px 0px 0px 0px',
}
],
'padding': '0px 0px 0px 0px',
'direction': 'vertical',
'horizontal_spacing': '8px',
'vertical_spacing': '8px',
'horizontal_align': 'left',
'vertical_align': 'top',
'margin': '0px 0px 0px 0px',
},
{
'tag': 'column',
'width': '30px',
'elements': [
{
'tag': 'button',
'text': {'tag': 'plain_text', 'content': ''},
'type': 'text',
'width': 'default',
'size': 'medium',
'icon': {'tag': 'standard_icon', 'token': 'thumbdown_outlined'},
'hover_tips': {'tag': 'plain_text', 'content': '无帮助'},
'margin': '0px 0px 0px 0px',
}
],
'padding': '0px 0px 0px 0px',
'vertical_spacing': '8px',
'horizontal_align': 'left',
'vertical_align': 'top',
'margin': '0px 0px 0px 0px',
},
],
'margin': '0px 0px 4px 0px',
},
],
},
}
card_data = {"schema": "2.0", "config": {"update_multi": True, "streaming_mode": True,
"streaming_config": {"print_step": {"default": 1},
"print_frequency_ms": {"default": 70},
"print_strategy": "fast"}},
"body": {"direction": "vertical", "padding": "12px 12px 12px 12px", "elements": [{"tag": "div",
"text": {
"tag": "plain_text",
"content": "LangBot",
"text_size": "normal",
"text_align": "left",
"text_color": "default"},
"icon": {
"tag": "custom_icon",
"img_key": "img_v3_02p3_05c65d5d-9bad-440a-a2fb-c89571bfd5bg"}},
{
"tag": "markdown",
"content": "",
"text_align": "left",
"text_size": "normal",
"margin": "0px 0px 0px 0px",
"element_id": "streaming_txt"},
{
"tag": "markdown",
"content": "",
"text_align": "left",
"text_size": "normal",
"margin": "0px 0px 0px 0px"},
{
"tag": "column_set",
"horizontal_spacing": "8px",
"horizontal_align": "left",
"columns": [
{
"tag": "column",
"width": "weighted",
"elements": [
{
"tag": "markdown",
"content": "",
"text_align": "left",
"text_size": "normal",
"margin": "0px 0px 0px 0px"},
{
"tag": "markdown",
"content": "",
"text_align": "left",
"text_size": "normal",
"margin": "0px 0px 0px 0px"},
{
"tag": "markdown",
"content": "",
"text_align": "left",
"text_size": "normal",
"margin": "0px 0px 0px 0px"}],
"padding": "0px 0px 0px 0px",
"direction": "vertical",
"horizontal_spacing": "8px",
"vertical_spacing": "2px",
"horizontal_align": "left",
"vertical_align": "top",
"margin": "0px 0px 0px 0px",
"weight": 1}],
"margin": "0px 0px 0px 0px"},
{"tag": "hr",
"margin": "0px 0px 0px 0px"},
{
"tag": "column_set",
"horizontal_spacing": "12px",
"horizontal_align": "right",
"columns": [
{
"tag": "column",
"width": "weighted",
"elements": [
{
"tag": "markdown",
"content": "<font color=\"grey-600\">以上内容由 AI 生成,仅供参考。更多详细、准确信息可点击引用链接查看</font>",
"text_align": "left",
"text_size": "notation",
"margin": "4px 0px 0px 0px",
"icon": {
"tag": "standard_icon",
"token": "robot_outlined",
"color": "grey"}}],
"padding": "0px 0px 0px 0px",
"direction": "vertical",
"horizontal_spacing": "8px",
"vertical_spacing": "8px",
"horizontal_align": "left",
"vertical_align": "top",
"margin": "0px 0px 0px 0px",
"weight": 1},
{
"tag": "column",
"width": "20px",
"elements": [
{
"tag": "button",
"text": {
"tag": "plain_text",
"content": ""},
"type": "text",
"width": "fill",
"size": "medium",
"icon": {
"tag": "standard_icon",
"token": "thumbsup_outlined"},
"hover_tips": {
"tag": "plain_text",
"content": "有帮助"},
"margin": "0px 0px 0px 0px"}],
"padding": "0px 0px 0px 0px",
"direction": "vertical",
"horizontal_spacing": "8px",
"vertical_spacing": "8px",
"horizontal_align": "left",
"vertical_align": "top",
"margin": "0px 0px 0px 0px"},
{
"tag": "column",
"width": "30px",
"elements": [
{
"tag": "button",
"text": {
"tag": "plain_text",
"content": ""},
"type": "text",
"width": "default",
"size": "medium",
"icon": {
"tag": "standard_icon",
"token": "thumbdown_outlined"},
"hover_tips": {
"tag": "plain_text",
"content": "无帮助"},
"margin": "0px 0px 0px 0px"}],
"padding": "0px 0px 0px 0px",
"vertical_spacing": "8px",
"horizontal_align": "left",
"vertical_align": "top",
"margin": "0px 0px 0px 0px"}],
"margin": "0px 0px 4px 0px"}]}}
# delay / fast 创建卡片模板delay 延迟打印fast 实时打印,可以自定义更好看的消息模板
request: CreateCardRequest = (
@@ -640,7 +612,7 @@ class LarkAdapter(abstract_platform_adapter.AbstractMessagePlatformAdapter):
content = {
'type': 'card',
'data': {'card_id': card_id, 'template_variable': {'content': 'Thinking...'}},
} # 当收到消息时发送消息模板,可添加模板变量,详情查看飞书中接口文档
} # 当收到消息时发送消息模板,可添加模板变量,详情查看飞书中接口文档
request: ReplyMessageRequest = (
ReplyMessageRequest.builder()
.message_id(event.message_chain.message_id)
@@ -713,8 +685,10 @@ class LarkAdapter(abstract_platform_adapter.AbstractMessagePlatformAdapter):
message_id = bot_message.resp_message_id
msg_seq = bot_message.msg_sequence
if msg_seq % 8 == 0 or is_final:
lark_message = await self.message_converter.yiri2target(message, self.api_client)
text_message = ''
for ele in lark_message[0]:
if ele['tag'] == 'text':
@@ -760,18 +734,14 @@ class LarkAdapter(abstract_platform_adapter.AbstractMessagePlatformAdapter):
def register_listener(
self,
event_type: typing.Type[platform_events.Event],
callback: typing.Callable[
[platform_events.Event, abstract_platform_adapter.AbstractMessagePlatformAdapter], None
],
callback: typing.Callable[[platform_events.Event, adapter.MessagePlatformAdapter], None],
):
self.listeners[event_type] = callback
def unregister_listener(
self,
event_type: typing.Type[platform_events.Event],
callback: typing.Callable[
[platform_events.Event, abstract_platform_adapter.AbstractMessagePlatformAdapter], None
],
callback: typing.Callable[[platform_events.Event, adapter.MessagePlatformAdapter], None],
):
self.listeners.pop(event_type)
@@ -808,4 +778,4 @@ class LarkAdapter(abstract_platform_adapter.AbstractMessagePlatformAdapter):
# 所以要设置_auto_reconnect=False,让其不重连。
self.bot._auto_reconnect = False
await self.bot._disconnect()
return False
return False

View File

Before

Width:  |  Height:  |  Size: 274 KiB

After

Width:  |  Height:  |  Size: 274 KiB

View File

@@ -9,15 +9,15 @@ import traceback
import nakuru
import nakuru.entities.components as nkc
from ....pipeline.longtext.strategies import forward
import langbot_plugin.api.entities.builtin.platform.message as platform_message
import langbot_plugin.api.entities.builtin.platform.entities as platform_entities
import langbot_plugin.api.entities.builtin.platform.events as platform_events
import langbot_plugin.api.definition.abstract.platform.adapter as abstract_platform_adapter
from ...logger import EventLogger
from .. import adapter as adapter_model
from ...pipeline.longtext.strategies import forward
from ...platform.types import message as platform_message
from ...platform.types import entities as platform_entities
from ...platform.types import events as platform_events
from ..logger import EventLogger
class NakuruProjectMessageConverter(abstract_platform_adapter.AbstractMessageConverter):
class NakuruProjectMessageConverter(adapter_model.MessageConverter):
"""消息转换器"""
@staticmethod
@@ -109,7 +109,7 @@ class NakuruProjectMessageConverter(abstract_platform_adapter.AbstractMessageCon
return chain
class NakuruProjectEventConverter(abstract_platform_adapter.AbstractEventConverter):
class NakuruProjectEventConverter(adapter_model.EventConverter):
"""事件转换器"""
@staticmethod
@@ -164,7 +164,7 @@ class NakuruProjectEventConverter(abstract_platform_adapter.AbstractEventConvert
raise Exception('未支持转换的事件类型: ' + str(event))
class NakuruAdapter(abstract_platform_adapter.AbstractMessagePlatformAdapter):
class NakuruAdapter(adapter_model.MessagePlatformAdapter):
"""nakuru-project适配器"""
bot: nakuru.CQHTTP
@@ -256,15 +256,13 @@ class NakuruAdapter(abstract_platform_adapter.AbstractMessagePlatformAdapter):
def register_listener(
self,
event_type: typing.Type[platform_events.Event],
callback: typing.Callable[
[platform_events.Event, abstract_platform_adapter.AbstractMessagePlatformAdapter], None
],
callback: typing.Callable[[platform_events.Event, adapter_model.MessagePlatformAdapter], None],
):
try:
source_cls = NakuruProjectEventConverter.yiri2target(event_type)
# 包装函数
async def listener_wrapper(app: nakuru.CQHTTP, source: source_cls): # type: ignore
async def listener_wrapper(app: nakuru.CQHTTP, source: source_cls):
await callback(self.event_converter.target2yiri(source), self)
# 将包装函数和原函数的对应关系存入列表
@@ -285,9 +283,7 @@ class NakuruAdapter(abstract_platform_adapter.AbstractMessagePlatformAdapter):
def unregister_listener(
self,
event_type: typing.Type[platform_events.Event],
callback: typing.Callable[
[platform_events.Event, abstract_platform_adapter.AbstractMessagePlatformAdapter], None
],
callback: typing.Callable[[platform_events.Event, adapter_model.MessagePlatformAdapter], None],
):
nakuru_event_name = self.event_converter.yiri2target(event_type).__name__
@@ -326,6 +322,7 @@ class NakuruAdapter(abstract_platform_adapter.AbstractMessagePlatformAdapter):
except Exception:
raise Exception('获取go-cqhttp账号信息失败, 请检查是否已启动go-cqhttp并配置正确')
await self.bot._run()
self.ap.logger.info('运行 Nakuru 适配器')
while True:
await asyncio.sleep(1)

View File

@@ -4,18 +4,19 @@ import asyncio
import traceback
import datetime
import langbot_plugin.api.definition.abstract.platform.adapter as abstract_platform_adapter
from pkg.platform.adapter import MessagePlatformAdapter
from pkg.platform.types import events as platform_events, message as platform_message
from libs.official_account_api.oaevent import OAEvent
from libs.official_account_api.api import OAClient
from libs.official_account_api.api import OAClientForLongerResponse
import langbot_plugin.api.entities.builtin.platform.entities as platform_entities
import langbot_plugin.api.entities.builtin.platform.message as platform_message
import langbot_plugin.api.entities.builtin.platform.events as platform_events
from langbot_plugin.api.entities.builtin.command import errors as command_errors
from .. import adapter
from ...core import app
from ..types import entities as platform_entities
from ...command.errors import ParamNotEnoughError
from ..logger import EventLogger
class OAMessageConverter(abstract_platform_adapter.AbstractMessageConverter):
class OAMessageConverter(adapter.MessageConverter):
@staticmethod
async def yiri2target(message_chain: platform_message.MessageChain):
for msg in message_chain:
@@ -33,7 +34,7 @@ class OAMessageConverter(abstract_platform_adapter.AbstractMessageConverter):
return chain
class OAEventConverter(abstract_platform_adapter.AbstractEventConverter):
class OAEventConverter(adapter.EventConverter):
@staticmethod
async def target2yiri(event: OAEvent):
if event.type == 'text':
@@ -55,15 +56,17 @@ class OAEventConverter(abstract_platform_adapter.AbstractEventConverter):
return None
class OfficialAccountAdapter(abstract_platform_adapter.AbstractMessagePlatformAdapter):
class OfficialAccountAdapter(adapter.MessagePlatformAdapter):
bot: OAClient | OAClientForLongerResponse
ap: app.Application
bot_account_id: str
message_converter: OAMessageConverter = OAMessageConverter()
event_converter: OAEventConverter = OAEventConverter()
config: dict
def __init__(self, config: dict, logger: EventLogger):
def __init__(self, config: dict, ap: app.Application, logger: EventLogger):
self.config = config
self.ap = ap
self.logger = logger
required_keys = [
@@ -75,7 +78,7 @@ class OfficialAccountAdapter(abstract_platform_adapter.AbstractMessagePlatformAd
]
missing_keys = [key for key in required_keys if key not in config]
if missing_keys:
raise command_errors.ParamNotEnoughError('微信公众号缺少相关配置项,请查看文档或联系管理员')
raise ParamNotEnoughError('微信公众号缺少相关配置项,请查看文档或联系管理员')
if self.config['Mode'] == 'drop':
self.bot = OAClient(
@@ -116,9 +119,7 @@ class OfficialAccountAdapter(abstract_platform_adapter.AbstractMessagePlatformAd
def register_listener(
self,
event_type: type,
callback: typing.Callable[
[platform_events.Event, abstract_platform_adapter.AbstractMessagePlatformAdapter], None
],
callback: typing.Callable[[platform_events.Event, MessagePlatformAdapter], None],
):
async def on_message(event: OAEvent):
self.bot_account_id = event.receiver_id
@@ -149,8 +150,6 @@ class OfficialAccountAdapter(abstract_platform_adapter.AbstractMessagePlatformAd
async def unregister_listener(
self,
event_type: type,
callback: typing.Callable[
[platform_events.Event, abstract_platform_adapter.AbstractMessagePlatformAdapter], None
],
callback: typing.Callable[[platform_events.Event, MessagePlatformAdapter], None],
):
return super().unregister_listener(event_type, callback)

View File

@@ -10,14 +10,14 @@ import botpy
import botpy.message as botpy_message
import botpy.types.message as botpy_message_type
import langbot_plugin.api.definition.abstract.platform.adapter as abstract_platform_adapter
from ....pipeline.longtext.strategies import forward
from ....core import app
from ....config import manager as cfg_mgr
import langbot_plugin.api.entities.builtin.platform.entities as platform_entities
import langbot_plugin.api.entities.builtin.platform.events as platform_events
import langbot_plugin.api.entities.builtin.platform.message as platform_message
from ...logger import EventLogger
from .. import adapter as adapter_model
from ...pipeline.longtext.strategies import forward
from ...core import app
from ...config import manager as cfg_mgr
from ...platform.types import entities as platform_entities
from ...platform.types import events as platform_events
from ...platform.types import message as platform_message
from ..logger import EventLogger
class OfficialGroupMessage(platform_events.GroupMessage):
@@ -133,7 +133,7 @@ class OpenIDMapping(typing.Generic[K, V]):
return value
class OfficialMessageConverter(abstract_platform_adapter.AbstractMessageConverter):
class OfficialMessageConverter(adapter_model.MessageConverter):
"""QQ 官方消息转换器"""
@staticmethod
@@ -237,7 +237,7 @@ class OfficialMessageConverter(abstract_platform_adapter.AbstractMessageConverte
return chain
class OfficialEventConverter(abstract_platform_adapter.AbstractEventConverter):
class OfficialEventConverter(adapter_model.EventConverter):
"""事件转换器"""
def __init__(self):
@@ -333,7 +333,7 @@ class OfficialEventConverter(abstract_platform_adapter.AbstractEventConverter):
)
class OfficialAdapter(abstract_platform_adapter.AbstractMessagePlatformAdapter):
class OfficialAdapter(adapter_model.MessagePlatformAdapter):
"""QQ 官方消息适配器"""
bot: botpy.Client = None
@@ -484,9 +484,7 @@ class OfficialAdapter(abstract_platform_adapter.AbstractMessagePlatformAdapter):
def register_listener(
self,
event_type: typing.Type[platform_events.Event],
callback: typing.Callable[
[platform_events.Event, abstract_platform_adapter.AbstractMessagePlatformAdapter], None
],
callback: typing.Callable[[platform_events.Event, adapter_model.MessagePlatformAdapter], None],
):
try:
@@ -509,9 +507,7 @@ class OfficialAdapter(abstract_platform_adapter.AbstractMessagePlatformAdapter):
def unregister_listener(
self,
event_type: typing.Type[platform_events.Event],
callback: typing.Callable[
[platform_events.Event, abstract_platform_adapter.AbstractMessagePlatformAdapter], None
],
callback: typing.Callable[[platform_events.Event, adapter_model.MessagePlatformAdapter], None],
):
delattr(self.bot, event_handler_mapping[event_type])
@@ -523,7 +519,7 @@ class OfficialAdapter(abstract_platform_adapter.AbstractMessagePlatformAdapter):
self.cfg['ret_coro'] = True
await self.logger.info('运行 QQ 官方适配器')
self.ap.logger.info('运行 QQ 官方适配器')
await (await self.bot.start(**self.cfg))
async def kill(self) -> bool:

View File

Before

Width:  |  Height:  |  Size: 1.7 KiB

After

Width:  |  Height:  |  Size: 1.7 KiB

View File

@@ -5,18 +5,19 @@ import traceback
import datetime
import langbot_plugin.api.definition.abstract.platform.adapter as abstract_platform_adapter
import langbot_plugin.api.entities.builtin.platform.message as platform_message
import langbot_plugin.api.entities.builtin.platform.events as platform_events
import langbot_plugin.api.entities.builtin.platform.entities as platform_entities
from langbot_plugin.api.entities.builtin.command import errors as command_errors
from pkg.platform.adapter import MessagePlatformAdapter
from pkg.platform.types import events as platform_events, message as platform_message
from .. import adapter
from ...core import app
from ..types import entities as platform_entities
from ...command.errors import ParamNotEnoughError
from libs.qq_official_api.api import QQOfficialClient
from libs.qq_official_api.qqofficialevent import QQOfficialEvent
from ...utils import image
from ..logger import EventLogger
class QQOfficialMessageConverter(abstract_platform_adapter.AbstractMessageConverter):
class QQOfficialMessageConverter(adapter.MessageConverter):
@staticmethod
async def yiri2target(message_chain: platform_message.MessageChain):
content_list = []
@@ -45,7 +46,7 @@ class QQOfficialMessageConverter(abstract_platform_adapter.AbstractMessageConver
return chain
class QQOfficialEventConverter(abstract_platform_adapter.AbstractEventConverter):
class QQOfficialEventConverter(adapter.EventConverter):
@staticmethod
async def yiri2target(event: platform_events.MessageEvent) -> QQOfficialEvent:
return event.source_platform_object
@@ -131,15 +132,17 @@ class QQOfficialEventConverter(abstract_platform_adapter.AbstractEventConverter)
)
class QQOfficialAdapter(abstract_platform_adapter.AbstractMessagePlatformAdapter):
class QQOfficialAdapter(adapter.MessagePlatformAdapter):
bot: QQOfficialClient
ap: app.Application
config: dict
bot_account_id: str
message_converter: QQOfficialMessageConverter = QQOfficialMessageConverter()
event_converter: QQOfficialEventConverter = QQOfficialEventConverter()
def __init__(self, config: dict, logger: EventLogger):
def __init__(self, config: dict, ap: app.Application, logger: EventLogger):
self.config = config
self.ap = ap
self.logger = logger
required_keys = [
@@ -148,7 +151,7 @@ class QQOfficialAdapter(abstract_platform_adapter.AbstractMessagePlatformAdapter
]
missing_keys = [key for key in required_keys if key not in config]
if missing_keys:
raise command_errors.ParamNotEnoughError('QQ官方机器人缺少相关配置项请查看文档或联系管理员')
raise ParamNotEnoughError('QQ官方机器人缺少相关配置项请查看文档或联系管理员')
self.bot = QQOfficialClient(
app_id=config['appid'], secret=config['secret'], token=config['token'], logger=self.logger
@@ -212,9 +215,7 @@ class QQOfficialAdapter(abstract_platform_adapter.AbstractMessagePlatformAdapter
def register_listener(
self,
event_type: typing.Type[platform_events.Event],
callback: typing.Callable[
[platform_events.Event, abstract_platform_adapter.AbstractMessagePlatformAdapter], None
],
callback: typing.Callable[[platform_events.Event, adapter.MessagePlatformAdapter], None],
):
async def on_message(event: QQOfficialEvent):
self.bot_account_id = 'justbot'
@@ -247,8 +248,6 @@ class QQOfficialAdapter(abstract_platform_adapter.AbstractMessagePlatformAdapter
def unregister_listener(
self,
event_type: type,
callback: typing.Callable[
[platform_events.Event, abstract_platform_adapter.AbstractMessagePlatformAdapter], None
],
callback: typing.Callable[[platform_events.Event, MessagePlatformAdapter], None],
):
return super().unregister_listener(event_type, callback)

View File

@@ -6,17 +6,18 @@ import traceback
import datetime
from libs.slack_api.api import SlackClient
import langbot_plugin.api.definition.abstract.platform.adapter as abstract_platform_adapter
from pkg.platform.adapter import MessagePlatformAdapter
from pkg.platform.types import events as platform_events, message as platform_message
from libs.slack_api.slackevent import SlackEvent
import langbot_plugin.api.entities.builtin.platform.events as platform_events
import langbot_plugin.api.entities.builtin.platform.message as platform_message
import langbot_plugin.api.entities.builtin.platform.entities as platform_entities
from langbot_plugin.api.entities.builtin.command import errors as command_errors
from pkg.core import app
from .. import adapter
from ..types import entities as platform_entities
from ...command.errors import ParamNotEnoughError
from ...utils import image
from ..logger import EventLogger
class SlackMessageConverter(abstract_platform_adapter.AbstractMessageConverter):
class SlackMessageConverter(adapter.MessageConverter):
@staticmethod
async def yiri2target(message_chain: platform_message.MessageChain):
content_list = []
@@ -43,7 +44,7 @@ class SlackMessageConverter(abstract_platform_adapter.AbstractMessageConverter):
return chain
class SlackEventConverter(abstract_platform_adapter.AbstractEventConverter):
class SlackEventConverter(adapter.EventConverter):
@staticmethod
async def yiri2target(event: platform_events.MessageEvent) -> SlackEvent:
return event.source_platform_object
@@ -83,15 +84,17 @@ class SlackEventConverter(abstract_platform_adapter.AbstractEventConverter):
)
class SlackAdapter(abstract_platform_adapter.AbstractMessagePlatformAdapter):
class SlackAdapter(adapter.MessagePlatformAdapter):
bot: SlackClient
ap: app.Application
bot_account_id: str
message_converter: SlackMessageConverter = SlackMessageConverter()
event_converter: SlackEventConverter = SlackEventConverter()
config: dict
def __init__(self, config: dict, logger: EventLogger):
def __init__(self, config: dict, ap: app.Application, logger: EventLogger):
self.config = config
self.ap = ap
self.logger = logger
required_keys = [
'bot_token',
@@ -99,7 +102,7 @@ class SlackAdapter(abstract_platform_adapter.AbstractMessagePlatformAdapter):
]
missing_keys = [key for key in required_keys if key not in config]
if missing_keys:
raise command_errors.ParamNotEnoughError('Slack机器人缺少相关配置项请查看文档或联系管理员')
raise ParamNotEnoughError('Slack机器人缺少相关配置项请查看文档或联系管理员')
self.bot = SlackClient(
bot_token=self.config['bot_token'], signing_secret=self.config['signing_secret'], logger=self.logger
@@ -132,9 +135,7 @@ class SlackAdapter(abstract_platform_adapter.AbstractMessagePlatformAdapter):
def register_listener(
self,
event_type: typing.Type[platform_events.Event],
callback: typing.Callable[
[platform_events.Event, abstract_platform_adapter.AbstractMessagePlatformAdapter], None
],
callback: typing.Callable[[platform_events.Event, adapter.MessagePlatformAdapter], None],
):
async def on_message(event: SlackEvent):
self.bot_account_id = 'SlackBot'
@@ -165,8 +166,6 @@ class SlackAdapter(abstract_platform_adapter.AbstractMessagePlatformAdapter):
async def unregister_listener(
self,
event_type: type,
callback: typing.Callable[
[platform_events.Event, abstract_platform_adapter.AbstractMessagePlatformAdapter], None
],
callback: typing.Callable[[platform_events.Event, MessagePlatformAdapter], None],
):
return super().unregister_listener(event_type, callback)

View File

@@ -10,16 +10,18 @@ import typing
import traceback
import base64
import aiohttp
import pydantic
import langbot_plugin.api.definition.abstract.platform.adapter as abstract_platform_adapter
import langbot_plugin.api.entities.builtin.platform.message as platform_message
import langbot_plugin.api.entities.builtin.platform.events as platform_events
import langbot_plugin.api.entities.builtin.platform.entities as platform_entities
import langbot_plugin.api.definition.abstract.platform.event_logger as abstract_platform_logger
from lark_oapi.api.im.v1 import *
from .. import adapter
from ...core import app
from ..types import message as platform_message
from ..types import events as platform_events
from ..types import entities as platform_entities
from ..logger import EventLogger
class TelegramMessageConverter(abstract_platform_adapter.AbstractMessageConverter):
class TelegramMessageConverter(adapter.MessageConverter):
@staticmethod
async def yiri2target(message_chain: platform_message.MessageChain, bot: telegram.Bot) -> list[dict]:
components = []
@@ -88,7 +90,7 @@ class TelegramMessageConverter(abstract_platform_adapter.AbstractMessageConverte
return platform_message.MessageChain(message_components)
class TelegramEventConverter(abstract_platform_adapter.AbstractEventConverter):
class TelegramEventConverter(adapter.EventConverter):
@staticmethod
async def yiri2target(event: platform_events.MessageEvent, bot: telegram.Bot):
return event.source_platform_object
@@ -130,14 +132,17 @@ class TelegramEventConverter(abstract_platform_adapter.AbstractEventConverter):
)
class TelegramAdapter(abstract_platform_adapter.AbstractMessagePlatformAdapter):
bot: telegram.Bot = pydantic.Field(exclude=True)
application: telegram.ext.Application = pydantic.Field(exclude=True)
class TelegramAdapter(adapter.MessagePlatformAdapter):
bot: telegram.Bot
application: telegram.ext.Application
bot_account_id: str
message_converter: TelegramMessageConverter = TelegramMessageConverter()
event_converter: TelegramEventConverter = TelegramEventConverter()
config: dict
ap: app.Application
msg_stream_id: dict # 流式消息id字典key为流式消息idvalue为首次消息源id用于在流式消息时判断编辑那条消息
@@ -145,10 +150,16 @@ class TelegramAdapter(abstract_platform_adapter.AbstractMessagePlatformAdapter):
listeners: typing.Dict[
typing.Type[platform_events.Event],
typing.Callable[[platform_events.Event, abstract_platform_adapter.AbstractMessagePlatformAdapter], None],
typing.Callable[[platform_events.Event, adapter.MessagePlatformAdapter], None],
] = {}
def __init__(self, config: dict, logger: abstract_platform_logger.AbstractEventLogger):
def __init__(self, config: dict, ap: app.Application, logger: EventLogger):
self.config = config
self.ap = ap
self.logger = logger
self.msg_stream_id = {}
# self.seq = 1
async def telegram_callback(update: Update, context: ContextTypes.DEFAULT_TYPE):
if update.message.from_user.is_bot:
return
@@ -160,18 +171,10 @@ class TelegramAdapter(abstract_platform_adapter.AbstractMessagePlatformAdapter):
except Exception:
await self.logger.error(f'Error in telegram callback: {traceback.format_exc()}')
application = ApplicationBuilder().token(config['token']).build()
bot = application.bot
application.add_handler(MessageHandler(filters.TEXT | (filters.COMMAND) | filters.PHOTO, telegram_callback))
super().__init__(
config=config,
logger=logger,
msg_stream_id={},
seq=1,
bot=bot,
application=application,
bot_account_id='',
listeners={},
self.application = ApplicationBuilder().token(self.config['token']).build()
self.bot = self.application.bot
self.application.add_handler(
MessageHandler(filters.TEXT | (filters.COMMAND) | filters.PHOTO, telegram_callback)
)
async def send_message(self, target_type: str, target_id: str, message: platform_message.MessageChain):
@@ -275,18 +278,14 @@ class TelegramAdapter(abstract_platform_adapter.AbstractMessagePlatformAdapter):
def register_listener(
self,
event_type: typing.Type[platform_events.Event],
callback: typing.Callable[
[platform_events.Event, abstract_platform_adapter.AbstractMessagePlatformAdapter], None
],
callback: typing.Callable[[platform_events.Event, adapter.MessagePlatformAdapter], None],
):
self.listeners[event_type] = callback
def unregister_listener(
self,
event_type: typing.Type[platform_events.Event],
callback: typing.Callable[
[platform_events.Event, abstract_platform_adapter.AbstractMessagePlatformAdapter], None
],
callback: typing.Callable[[platform_events.Event, adapter.MessagePlatformAdapter], None],
):
self.listeners.pop(event_type)

View File

@@ -3,19 +3,17 @@ import logging
import typing
from datetime import datetime
import pydantic
from pydantic import BaseModel
import langbot_plugin.api.definition.abstract.platform.adapter as abstract_platform_adapter
import langbot_plugin.api.entities.builtin.platform.message as platform_message
import langbot_plugin.api.entities.builtin.platform.events as platform_events
import langbot_plugin.api.entities.builtin.platform.entities as platform_entities
import langbot_plugin.api.definition.abstract.platform.event_logger as abstract_platform_logger
from .. import adapter as msadapter
from ..types import events as platform_events, message as platform_message, entities as platform_entities
from ...core import app
from ..logger import EventLogger
logger = logging.getLogger(__name__)
class WebChatMessage(pydantic.BaseModel):
class WebChatMessage(BaseModel):
id: int
role: str
content: str
@@ -43,35 +41,30 @@ class WebChatSession:
return self.message_lists[pipeline_uuid]
class WebChatAdapter(abstract_platform_adapter.AbstractMessagePlatformAdapter):
class WebChatAdapter(msadapter.MessagePlatformAdapter):
"""WebChat调试适配器用于流水线调试"""
webchat_person_session: WebChatSession = pydantic.Field(exclude=True, default_factory=WebChatSession)
webchat_group_session: WebChatSession = pydantic.Field(exclude=True, default_factory=WebChatSession)
webchat_person_session: WebChatSession
webchat_group_session: WebChatSession
listeners: dict[
listeners: typing.Dict[
typing.Type[platform_events.Event],
typing.Callable[[platform_events.Event, abstract_platform_adapter.AbstractMessagePlatformAdapter], None],
] = pydantic.Field(default_factory=dict, exclude=True)
typing.Callable[[platform_events.Event, msadapter.MessagePlatformAdapter], None],
] = {}
is_stream: bool = pydantic.Field(exclude=True)
debug_messages: dict[str, list[dict]] = pydantic.Field(default_factory=dict, exclude=True)
is_stream: bool
ap: app.Application = pydantic.Field(exclude=True)
def __init__(self, config: dict, logger: abstract_platform_logger.AbstractEventLogger, **kwargs):
super().__init__(
config=config,
logger=logger,
**kwargs,
)
def __init__(self, config: dict, ap: app.Application, logger: EventLogger):
self.ap = ap
self.logger = logger
self.config = config
self.webchat_person_session = WebChatSession(id='webchatperson')
self.webchat_group_session = WebChatSession(id='webchatgroup')
self.bot_account_id = 'webchatbot'
self.debug_messages = {}
self.is_stream = False
async def send_message(
self,
@@ -166,9 +159,7 @@ class WebChatAdapter(abstract_platform_adapter.AbstractMessagePlatformAdapter):
def register_listener(
self,
event_type: typing.Type[platform_events.Event],
func: typing.Callable[
[platform_events.Event, abstract_platform_adapter.AbstractMessagePlatformAdapter], typing.Awaitable[None]
],
func: typing.Callable[[platform_events.Event, msadapter.MessagePlatformAdapter], typing.Awaitable[None]],
):
"""注册事件监听器"""
self.listeners[event_type] = func
@@ -176,16 +167,11 @@ class WebChatAdapter(abstract_platform_adapter.AbstractMessagePlatformAdapter):
def unregister_listener(
self,
event_type: typing.Type[platform_events.Event],
func: typing.Callable[
[platform_events.Event, abstract_platform_adapter.AbstractMessagePlatformAdapter], typing.Awaitable[None]
],
func: typing.Callable[[platform_events.Event, msadapter.MessagePlatformAdapter], typing.Awaitable[None]],
):
"""取消注册事件监听器"""
del self.listeners[event_type]
async def is_muted(self, group_id: int) -> bool:
return False
async def run_async(self):
"""运行适配器"""
await self.logger.info('WebChat调试适配器已启动')
@@ -235,7 +221,7 @@ class WebChatAdapter(abstract_platform_adapter.AbstractMessagePlatformAdapter):
message_chain.insert(0, platform_message.Source(id=message_id, time=datetime.now().timestamp()))
if session_type == 'person':
sender = platform_entities.Friend(id='webchatperson', nickname='User', remark='User')
sender = platform_entities.Friend(id='webchatperson', nickname='User')
event = platform_events.FriendMessage(
sender=sender, message_chain=message_chain, time=datetime.now().timestamp()
)

Some files were not shown because too many files have changed in this diff Show More