feat: switch Query to langbot-plugin definition

This commit is contained in:
Junyan Qin
2025-06-15 22:04:31 +08:00
parent 0c2560cafb
commit 6b782f8761
88 changed files with 248 additions and 348 deletions

10
main.py
View File

@@ -47,13 +47,13 @@ async def main_entry(loop: asyncio.AbstractEventLoop):
if not args.skip_plugin_deps_check: if not args.skip_plugin_deps_check:
await deps.precheck_plugin_deps() await deps.precheck_plugin_deps()
# 检查pydantic版本如果没有 pydantic.v1则把 pydantic 映射为 v1 # # 检查pydantic版本如果没有 pydantic.v1则把 pydantic 映射为 v1
import pydantic.version # import pydantic.version
if pydantic.version.VERSION < '2.0': # if pydantic.version.VERSION < '2.0':
import pydantic # import pydantic
sys.modules['pydantic.v1'] = pydantic # sys.modules['pydantic.v1'] = pydantic
# 检查配置文件 # 检查配置文件

View File

@@ -35,15 +35,6 @@ class SystemRouterGroup(group.RouterGroup):
return self.success(data=task.to_dict()) return self.success(data=task.to_dict())
@self.route('/reload', methods=['POST'], auth_type=group.AuthType.USER_TOKEN)
async def _() -> str:
json_data = await quart.request.json
scope = json_data.get('scope')
await self.ap.reload(scope=scope)
return self.success()
@self.route('/_debug/exec', methods=['POST'], auth_type=group.AuthType.USER_TOKEN) @self.route('/_debug/exec', methods=['POST'], auth_type=group.AuthType.USER_TOKEN)
async def _() -> str: async def _() -> str:
if not constants.debug_mode: if not constants.debug_mode:

View File

@@ -2,10 +2,11 @@ from __future__ import annotations
import typing import typing
from ..core import app, entities as core_entities from ..core import app
from . import entities, operator, errors from . import entities, operator, errors
from ..utils import importutil from ..utils import importutil
import langbot_plugin.api.entities.builtin.provider.session as provider_session import langbot_plugin.api.entities.builtin.provider.session as provider_session
import langbot_plugin.api.entities.builtin.pipeline.query as pipeline_query
# 引入所有算子以便注册 # 引入所有算子以便注册
from . import operators from . import operators
@@ -90,7 +91,7 @@ class CommandManager:
async def execute( async def execute(
self, self,
command_text: str, command_text: str,
query: core_entities.Query, query: pipeline_query.Query,
session: provider_session.Session, session: provider_session.Session,
) -> typing.AsyncGenerator[entities.CommandReturn, None]: ) -> typing.AsyncGenerator[entities.CommandReturn, None]:
"""执行命令""" """执行命令"""

View File

@@ -2,12 +2,12 @@ from __future__ import annotations
import typing import typing
import pydantic.v1 as pydantic import pydantic
import langbot_plugin.api.entities.builtin.provider.session as provider_session import langbot_plugin.api.entities.builtin.provider.session as provider_session
from ..core import entities as core_entities
from . import errors from . import errors
from ..platform.types import message as platform_message from ..platform.types import message as platform_message
import langbot_plugin.api.entities.builtin.pipeline.query as pipeline_query
class CommandReturn(pydantic.BaseModel): class CommandReturn(pydantic.BaseModel):
@@ -35,7 +35,7 @@ class CommandReturn(pydantic.BaseModel):
class ExecuteContext(pydantic.BaseModel): class ExecuteContext(pydantic.BaseModel):
"""单次命令执行上下文""" """单次命令执行上下文"""
query: core_entities.Query query: pipeline_query.Query
"""本次消息的请求对象""" """本次消息的请求对象"""
session: provider_session.Session session: provider_session.Session

View File

@@ -3,7 +3,6 @@ from __future__ import annotations
import logging import logging
import asyncio import asyncio
import traceback import traceback
import sys
import os import os
from ..platform import botmgr as im_mgr from ..platform import botmgr as im_mgr
@@ -183,59 +182,3 @@ class Application:
""".strip() """.strip()
for line in tips.split('\n'): for line in tips.split('\n'):
self.logger.info(line) self.logger.info(line)
async def reload(
self,
scope: core_entities.LifecycleControlScope,
):
match scope:
case core_entities.LifecycleControlScope.PLATFORM.value:
self.logger.info('执行热重载 scope=' + scope)
await self.platform_mgr.shutdown()
self.platform_mgr = im_mgr.PlatformManager(self)
await self.platform_mgr.initialize()
self.task_mgr.create_task(
self.platform_mgr.run(),
name='platform-manager',
scopes=[
core_entities.LifecycleControlScope.APPLICATION,
core_entities.LifecycleControlScope.PLATFORM,
],
)
case core_entities.LifecycleControlScope.PLUGIN.value:
self.logger.info('执行热重载 scope=' + scope)
await self.plugin_mgr.destroy_plugins()
# 删除 sys.module 中所有的 plugins/* 下的模块
for mod in list(sys.modules.keys()):
if mod.startswith('plugins.'):
del sys.modules[mod]
self.plugin_mgr = plugin_mgr.PluginManager(self)
await self.plugin_mgr.initialize()
await self.plugin_mgr.initialize_plugins()
await self.plugin_mgr.load_plugins()
await self.plugin_mgr.initialize_plugins()
case core_entities.LifecycleControlScope.PROVIDER.value:
self.logger.info('执行热重载 scope=' + scope)
await self.tool_mgr.shutdown()
llm_model_mgr_inst = llm_model_mgr.ModelManager(self)
await llm_model_mgr_inst.initialize()
self.model_mgr = llm_model_mgr_inst
llm_session_mgr_inst = llm_session_mgr.SessionManager(self)
await llm_session_mgr_inst.initialize()
self.sess_mgr = llm_session_mgr_inst
llm_tool_mgr_inst = llm_tool_mgr.ToolManager(self)
await llm_tool_mgr_inst.initialize()
self.tool_mgr = llm_tool_mgr_inst
case _:
pass

View File

@@ -3,7 +3,7 @@ from __future__ import annotations
import enum import enum
import typing import typing
import pydantic.v1 as pydantic import pydantic
from ..provider import entities as llm_entities from ..provider import entities as llm_entities
from ..platform import adapter as msadapter from ..platform import adapter as msadapter
@@ -20,23 +20,13 @@ class LifecycleControlScope(enum.Enum):
PROVIDER = 'provider' PROVIDER = 'provider'
class LauncherTypes(enum.Enum):
"""一个请求的发起者类型"""
PERSON = 'person'
"""私聊"""
GROUP = 'group'
"""群聊"""
class Query(pydantic.BaseModel): class Query(pydantic.BaseModel):
"""一次请求的信息封装""" """一次请求的信息封装"""
query_id: int query_id: int
"""请求ID添加进请求池时生成""" """请求ID添加进请求池时生成"""
launcher_type: LauncherTypes launcher_type: provider_session.LauncherTypes
"""会话类型platform处理阶段设置""" """会话类型platform处理阶段设置"""
launcher_id: typing.Union[int, str] launcher_id: typing.Union[int, str]

View File

@@ -1,7 +1,7 @@
from __future__ import annotations from __future__ import annotations
from .. import stage, entities from .. import stage, entities
from ...core import entities as core_entities import langbot_plugin.api.entities.builtin.pipeline.query as pipeline_query
@stage.stage_class('BanSessionCheckStage') @stage.stage_class('BanSessionCheckStage')
@@ -14,7 +14,7 @@ class BanSessionCheckStage(stage.PipelineStage):
async def initialize(self, pipeline_config: dict): async def initialize(self, pipeline_config: dict):
pass pass
async def process(self, query: core_entities.Query, stage_inst_name: str) -> entities.StageProcessResult: async def process(self, query: pipeline_query.Query, stage_inst_name: str) -> entities.StageProcessResult:
found = False found = False
mode = query.pipeline_config['trigger']['access-control']['mode'] mode = query.pipeline_config['trigger']['access-control']['mode']

View File

@@ -3,12 +3,11 @@ from __future__ import annotations
from ...core import app from ...core import app
from .. import stage, entities from .. import stage, entities
from ...core import entities as core_entities
from . import filter as filter_model, entities as filter_entities from . import filter as filter_model, entities as filter_entities
from langbot_plugin.api.entities.builtin.provider import message as provider_message from langbot_plugin.api.entities.builtin.provider import message as provider_message
from ...platform.types import message as platform_message from ...platform.types import message as platform_message
from ...utils import importutil from ...utils import importutil
import langbot_plugin.api.entities.builtin.pipeline.query as pipeline_query
from . import filters from . import filters
importutil.import_modules_in_pkg(filters) importutil.import_modules_in_pkg(filters)
@@ -58,7 +57,7 @@ class ContentFilterStage(stage.PipelineStage):
async def _pre_process( async def _pre_process(
self, self,
message: str, message: str,
query: core_entities.Query, query: pipeline_query.Query,
) -> entities.StageProcessResult: ) -> entities.StageProcessResult:
"""请求llm前处理消息 """请求llm前处理消息
只要有一个不通过就不放行,只放行 PASS 的消息 只要有一个不通过就不放行,只放行 PASS 的消息
@@ -93,7 +92,7 @@ class ContentFilterStage(stage.PipelineStage):
async def _post_process( async def _post_process(
self, self,
message: str, message: str,
query: core_entities.Query, query: pipeline_query.Query,
) -> entities.StageProcessResult: ) -> entities.StageProcessResult:
"""请求llm后处理响应 """请求llm后处理响应
只要是 PASS 或者 MASKED 的就通过此 filter将其 replacement 设置为message进入下一个 filter 只要是 PASS 或者 MASKED 的就通过此 filter将其 replacement 设置为message进入下一个 filter
@@ -123,7 +122,7 @@ class ContentFilterStage(stage.PipelineStage):
return entities.StageProcessResult(result_type=entities.ResultType.CONTINUE, new_query=query) return entities.StageProcessResult(result_type=entities.ResultType.CONTINUE, new_query=query)
async def process(self, query: core_entities.Query, stage_inst_name: str) -> entities.StageProcessResult: async def process(self, query: pipeline_query.Query, stage_inst_name: str) -> entities.StageProcessResult:
"""处理""" """处理"""
if stage_inst_name == 'PreContentFilterStage': if stage_inst_name == 'PreContentFilterStage':
contain_non_text = False contain_non_text = False

View File

@@ -1,6 +1,6 @@
import enum import enum
import pydantic.v1 as pydantic import pydantic
class ResultLevel(enum.Enum): class ResultLevel(enum.Enum):

View File

@@ -3,9 +3,9 @@ from __future__ import annotations
import abc import abc
import typing import typing
from ...core import app, entities as core_entities from ...core import app
from . import entities from . import entities
import langbot_plugin.api.entities.builtin.pipeline.query as pipeline_query
preregistered_filters: list[typing.Type[ContentFilter]] = [] preregistered_filters: list[typing.Type[ContentFilter]] = []
@@ -60,7 +60,7 @@ class ContentFilter(metaclass=abc.ABCMeta):
pass pass
@abc.abstractmethod @abc.abstractmethod
async def process(self, query: core_entities.Query, message: str = None, image_url=None) -> entities.FilterResult: async def process(self, query: pipeline_query.Query, message: str = None, image_url=None) -> entities.FilterResult:
"""处理消息 """处理消息
分为前后阶段,具体取决于 enable_stages 的值。 分为前后阶段,具体取决于 enable_stages 的值。

View File

@@ -4,8 +4,7 @@ import aiohttp
from .. import entities from .. import entities
from .. import filter as filter_model from .. import filter as filter_model
from ....core import entities as core_entities import langbot_plugin.api.entities.builtin.pipeline.query as pipeline_query
BAIDU_EXAMINE_URL = 'https://aip.baidubce.com/rest/2.0/solution/v1/text_censor/v2/user_defined?access_token={}' BAIDU_EXAMINE_URL = 'https://aip.baidubce.com/rest/2.0/solution/v1/text_censor/v2/user_defined?access_token={}'
BAIDU_EXAMINE_TOKEN_URL = 'https://aip.baidubce.com/oauth/2.0/token' BAIDU_EXAMINE_TOKEN_URL = 'https://aip.baidubce.com/oauth/2.0/token'
@@ -27,7 +26,7 @@ class BaiduCloudExamine(filter_model.ContentFilter):
) as resp: ) as resp:
return (await resp.json())['access_token'] return (await resp.json())['access_token']
async def process(self, query: core_entities.Query, message: str) -> entities.FilterResult: async def process(self, query: pipeline_query.Query, message: str) -> entities.FilterResult:
async with aiohttp.ClientSession() as session: async with aiohttp.ClientSession() as session:
async with session.post( async with session.post(
BAIDU_EXAMINE_URL.format(await self._get_token()), BAIDU_EXAMINE_URL.format(await self._get_token()),

View File

@@ -3,7 +3,7 @@ import re
from .. import filter as filter_model from .. import filter as filter_model
from .. import entities from .. import entities
from ....core import entities as core_entities import langbot_plugin.api.entities.builtin.pipeline.query as pipeline_query
@filter_model.filter_class('ban-word-filter') @filter_model.filter_class('ban-word-filter')
@@ -13,7 +13,7 @@ class BanWordFilter(filter_model.ContentFilter):
async def initialize(self): async def initialize(self):
pass pass
async def process(self, query: core_entities.Query, message: str) -> entities.FilterResult: async def process(self, query: pipeline_query.Query, message: str) -> entities.FilterResult:
found = False found = False
for word in self.ap.sensitive_meta.data['words']: for word in self.ap.sensitive_meta.data['words']:

View File

@@ -3,7 +3,7 @@ import re
from .. import entities from .. import entities
from .. import filter as filter_model from .. import filter as filter_model
from ....core import entities as core_entities import langbot_plugin.api.entities.builtin.pipeline.query as pipeline_query
@filter_model.filter_class('content-ignore') @filter_model.filter_class('content-ignore')
@@ -16,7 +16,7 @@ class ContentIgnore(filter_model.ContentFilter):
entities.EnableStage.PRE, entities.EnableStage.PRE,
] ]
async def process(self, query: core_entities.Query, message: str) -> entities.FilterResult: async def process(self, query: pipeline_query.Query, message: str) -> entities.FilterResult:
if 'prefix' in query.pipeline_config['trigger']['ignore-rules']: if 'prefix' in query.pipeline_config['trigger']['ignore-rules']:
for rule in query.pipeline_config['trigger']['ignore-rules']['prefix']: for rule in query.pipeline_config['trigger']['ignore-rules']['prefix']:
if message.startswith(rule): if message.startswith(rule):

View File

@@ -3,7 +3,10 @@ from __future__ import annotations
import asyncio import asyncio
import traceback import traceback
from ..core import app, entities from ..core import app
from ..core import entities as core_entities
import langbot_plugin.api.entities.builtin.pipeline.query as pipeline_query
class Controller: class Controller:
@@ -22,11 +25,11 @@ class Controller:
"""事件处理循环""" """事件处理循环"""
try: try:
while True: while True:
selected_query: entities.Query = None selected_query: pipeline_query.Query = None
# 取请求 # 取请求
async with self.ap.query_pool: async with self.ap.query_pool:
queries: list[entities.Query] = self.ap.query_pool.queries queries: list[pipeline_query.Query] = self.ap.query_pool.queries
for query in queries: for query in queries:
session = await self.ap.sess_mgr.get_session(query) session = await self.ap.sess_mgr.get_session(query)
@@ -46,7 +49,7 @@ class Controller:
if selected_query: if selected_query:
async def _process_query(selected_query: entities.Query): async def _process_query(selected_query: pipeline_query.Query):
async with self.semaphore: # 总并发上限 async with self.semaphore: # 总并发上限
# find pipeline # find pipeline
# Here firstly find the bot, then find the pipeline, in case the bot adapter's config is not the latest one. # Here firstly find the bot, then find the pipeline, in case the bot adapter's config is not the latest one.
@@ -68,8 +71,8 @@ class Controller:
kind='query', kind='query',
name=f'query-{selected_query.query_id}', name=f'query-{selected_query.query_id}',
scopes=[ scopes=[
entities.LifecycleControlScope.APPLICATION, core_entities.LifecycleControlScope.APPLICATION,
entities.LifecycleControlScope.PLATFORM, core_entities.LifecycleControlScope.PLATFORM,
], ],
) )

View File

@@ -3,10 +3,10 @@ from __future__ import annotations
import enum import enum
import typing import typing
import pydantic.v1 as pydantic import pydantic
from ..platform.types import message as platform_message from ..platform.types import message as platform_message
from ..core import entities import langbot_plugin.api.entities.builtin.pipeline.query as pipeline_query
class ResultType(enum.Enum): class ResultType(enum.Enum):
@@ -20,7 +20,7 @@ class ResultType(enum.Enum):
class StageProcessResult(pydantic.BaseModel): class StageProcessResult(pydantic.BaseModel):
result_type: ResultType result_type: ResultType
new_query: entities.Query new_query: pipeline_query.Query
user_notice: typing.Optional[ user_notice: typing.Optional[
typing.Union[ typing.Union[

View File

@@ -5,10 +5,9 @@ import traceback
from . import strategy from . import strategy
from .. import stage, entities from .. import stage, entities
from ...core import entities as core_entities
from ...platform.types import message as platform_message from ...platform.types import message as platform_message
from ...utils import importutil from ...utils import importutil
import langbot_plugin.api.entities.builtin.pipeline.query as pipeline_query
from . import strategies from . import strategies
importutil.import_modules_in_pkg(strategies) importutil.import_modules_in_pkg(strategies)
@@ -67,7 +66,7 @@ class LongTextProcessStage(stage.PipelineStage):
await self.strategy_impl.initialize() await self.strategy_impl.initialize()
async def process(self, query: core_entities.Query, stage_inst_name: str) -> entities.StageProcessResult: async def process(self, query: pipeline_query.Query, stage_inst_name: str) -> entities.StageProcessResult:
# 检查是否包含非 Plain 组件 # 检查是否包含非 Plain 组件
contains_non_plain = False contains_non_plain = False

View File

@@ -3,9 +3,9 @@ from __future__ import annotations
from .. import strategy as strategy_model from .. import strategy as strategy_model
from ....core import entities as core_entities
from ....platform.types import message as platform_message
from ....platform.types import message as platform_message
import langbot_plugin.api.entities.builtin.pipeline.query as pipeline_query
ForwardMessageDiaplay = platform_message.ForwardMessageDiaplay ForwardMessageDiaplay = platform_message.ForwardMessageDiaplay
Forward = platform_message.Forward Forward = platform_message.Forward
@@ -13,7 +13,7 @@ Forward = platform_message.Forward
@strategy_model.strategy_class('forward') @strategy_model.strategy_class('forward')
class ForwardComponentStrategy(strategy_model.LongTextStrategy): class ForwardComponentStrategy(strategy_model.LongTextStrategy):
async def process(self, message: str, query: core_entities.Query) -> list[platform_message.MessageComponent]: async def process(self, message: str, query: pipeline_query.Query) -> list[platform_message.MessageComponent]:
display = ForwardMessageDiaplay( display = ForwardMessageDiaplay(
title='群聊的聊天记录', title='群聊的聊天记录',
brief='[聊天记录]', brief='[聊天记录]',

View File

@@ -11,7 +11,7 @@ import functools
from ....platform.types import message as platform_message from ....platform.types import message as platform_message
from .. import strategy as strategy_model from .. import strategy as strategy_model
from ....core import entities as core_entities import langbot_plugin.api.entities.builtin.pipeline.query as pipeline_query
@strategy_model.strategy_class('image') @strategy_model.strategy_class('image')
@@ -27,7 +27,7 @@ class Text2ImageStrategy(strategy_model.LongTextStrategy):
encoding='utf-8', encoding='utf-8',
) )
async def process(self, message: str, query: core_entities.Query) -> list[platform_message.MessageComponent]: async def process(self, message: str, query: pipeline_query.Query) -> list[platform_message.MessageComponent]:
img_path = self.text_to_image( img_path = self.text_to_image(
text_str=message, text_str=message,
save_as='temp/{}.png'.format(int(time.time())), save_as='temp/{}.png'.format(int(time.time())),
@@ -131,7 +131,7 @@ class Text2ImageStrategy(strategy_model.LongTextStrategy):
text_str: str, text_str: str,
save_as='temp.png', save_as='temp.png',
width=800, width=800,
query: core_entities.Query = None, query: pipeline_query.Query = None,
): ):
text_str = text_str.replace('\t', ' ') text_str = text_str.replace('\t', ' ')

View File

@@ -4,8 +4,8 @@ import typing
from ...core import app from ...core import app
from ...core import entities as core_entities
from ...platform.types import message as platform_message from ...platform.types import message as platform_message
import langbot_plugin.api.entities.builtin.pipeline.query as pipeline_query
preregistered_strategies: list[typing.Type[LongTextStrategy]] = [] preregistered_strategies: list[typing.Type[LongTextStrategy]] = []
@@ -49,7 +49,7 @@ class LongTextStrategy(metaclass=abc.ABCMeta):
pass pass
@abc.abstractmethod @abc.abstractmethod
async def process(self, message: str, query: core_entities.Query) -> list[platform_message.MessageComponent]: async def process(self, message: str, query: pipeline_query.Query) -> list[platform_message.MessageComponent]:
"""处理长文本 """处理长文本
在 platform.json 中配置 long-text-process 字段,只要 文本长度超过了 threshold 就会调用此方法 在 platform.json 中配置 long-text-process 字段,只要 文本长度超过了 threshold 就会调用此方法

View File

@@ -1,10 +1,9 @@
from __future__ import annotations from __future__ import annotations
from .. import stage, entities from .. import stage, entities
from ...core import entities as core_entities
from . import truncator from . import truncator
from ...utils import importutil from ...utils import importutil
import langbot_plugin.api.entities.builtin.pipeline.query as pipeline_query
from . import truncators from . import truncators
importutil.import_modules_in_pkg(truncators) importutil.import_modules_in_pkg(truncators)
@@ -29,7 +28,7 @@ class ConversationMessageTruncator(stage.PipelineStage):
else: else:
raise ValueError(f'未知的截断器: {use_method}') raise ValueError(f'未知的截断器: {use_method}')
async def process(self, query: core_entities.Query, stage_inst_name: str) -> entities.StageProcessResult: async def process(self, query: pipeline_query.Query, stage_inst_name: str) -> entities.StageProcessResult:
"""处理""" """处理"""
query = await self.trun.truncate(query) query = await self.trun.truncate(query)

View File

@@ -3,8 +3,8 @@ from __future__ import annotations
import typing import typing
import abc import abc
from ...core import entities as core_entities, app from ...core import app
import langbot_plugin.api.entities.builtin.pipeline.query as pipeline_query
preregistered_truncators: list[typing.Type[Truncator]] = [] preregistered_truncators: list[typing.Type[Truncator]] = []
@@ -47,7 +47,7 @@ class Truncator(abc.ABC):
pass pass
@abc.abstractmethod @abc.abstractmethod
async def truncate(self, query: core_entities.Query) -> core_entities.Query: async def truncate(self, query: pipeline_query.Query) -> pipeline_query.Query:
"""截断 """截断
一般只需要操作query.messages也可以扩展操作query.prompt, query.user_message。 一般只需要操作query.messages也可以扩展操作query.prompt, query.user_message。

View File

@@ -1,14 +1,14 @@
from __future__ import annotations from __future__ import annotations
from .. import truncator from .. import truncator
from ....core import entities as core_entities import langbot_plugin.api.entities.builtin.pipeline.query as pipeline_query
@truncator.truncator_class('round') @truncator.truncator_class('round')
class RoundTruncator(truncator.Truncator): class RoundTruncator(truncator.Truncator):
"""前文回合数阶段器""" """前文回合数阶段器"""
async def truncate(self, query: core_entities.Query) -> core_entities.Query: async def truncate(self, query: pipeline_query.Query) -> pipeline_query.Query:
"""截断""" """截断"""
max_round = query.pipeline_config['ai']['local-agent']['max-round'] max_round = query.pipeline_config['ai']['local-agent']['max-round']

View File

@@ -5,7 +5,7 @@ import traceback
import sqlalchemy import sqlalchemy
from ..core import app, entities from ..core import app
from . import entities as pipeline_entities from . import entities as pipeline_entities
from ..entity.persistence import pipeline as persistence_pipeline from ..entity.persistence import pipeline as persistence_pipeline
from . import stage from . import stage
@@ -13,6 +13,9 @@ from ..platform.types import message as platform_message, events as platform_eve
from ..plugin import events from ..plugin import events
from ..utils import importutil from ..utils import importutil
import langbot_plugin.api.entities.builtin.provider.session as provider_session
import langbot_plugin.api.entities.builtin.pipeline.query as pipeline_query
from . import ( from . import (
resprule, resprule,
bansess, bansess,
@@ -75,11 +78,11 @@ class RuntimePipeline:
self.pipeline_entity = pipeline_entity self.pipeline_entity = pipeline_entity
self.stage_containers = stage_containers self.stage_containers = stage_containers
async def run(self, query: entities.Query): async def run(self, query: pipeline_query.Query):
query.pipeline_config = self.pipeline_entity.config query.pipeline_config = self.pipeline_entity.config
await self.process_query(query) await self.process_query(query)
async def _check_output(self, query: entities.Query, result: pipeline_entities.StageProcessResult): async def _check_output(self, query: pipeline_query.Query, result: pipeline_entities.StageProcessResult):
"""检查输出""" """检查输出"""
if result.user_notice: if result.user_notice:
# 处理str类型 # 处理str类型
@@ -109,7 +112,7 @@ class RuntimePipeline:
async def _execute_from_stage( async def _execute_from_stage(
self, self,
stage_index: int, stage_index: int,
query: entities.Query, query: pipeline_query.Query,
): ):
"""从指定阶段开始执行,实现了责任链模式和基于生成器的阶段分叉功能。 """从指定阶段开始执行,实现了责任链模式和基于生成器的阶段分叉功能。
@@ -169,13 +172,13 @@ class RuntimePipeline:
i += 1 i += 1
async def process_query(self, query: entities.Query): async def process_query(self, query: pipeline_query.Query):
"""处理请求""" """处理请求"""
try: try:
# ======== 触发 MessageReceived 事件 ======== # ======== 触发 MessageReceived 事件 ========
event_type = ( event_type = (
events.PersonMessageReceived events.PersonMessageReceived
if query.launcher_type == entities.LauncherTypes.PERSON if query.launcher_type == provider_session.LauncherTypes.PERSON
else events.GroupMessageReceived else events.GroupMessageReceived
) )

View File

@@ -3,10 +3,11 @@ from __future__ import annotations
import asyncio import asyncio
import typing import typing
from ..core import entities
from ..platform import adapter as msadapter from ..platform import adapter as msadapter
from ..platform.types import message as platform_message from ..platform.types import message as platform_message
from ..platform.types import events as platform_events from ..platform.types import events as platform_events
import langbot_plugin.api.entities.builtin.provider.session as provider_session
import langbot_plugin.api.entities.builtin.pipeline.query as pipeline_query
class QueryPool: class QueryPool:
@@ -16,7 +17,7 @@ class QueryPool:
pool_lock: asyncio.Lock pool_lock: asyncio.Lock
queries: list[entities.Query] queries: list[pipeline_query.Query]
condition: asyncio.Condition condition: asyncio.Condition
@@ -29,16 +30,16 @@ class QueryPool:
async def add_query( async def add_query(
self, self,
bot_uuid: str, bot_uuid: str,
launcher_type: entities.LauncherTypes, launcher_type: provider_session.LauncherTypes,
launcher_id: typing.Union[int, str], launcher_id: typing.Union[int, str],
sender_id: typing.Union[int, str], sender_id: typing.Union[int, str],
message_event: platform_events.MessageEvent, message_event: platform_events.MessageEvent,
message_chain: platform_message.MessageChain, message_chain: platform_message.MessageChain,
adapter: msadapter.MessagePlatformAdapter, adapter: msadapter.MessagePlatformAdapter,
pipeline_uuid: typing.Optional[str] = None, pipeline_uuid: typing.Optional[str] = None,
) -> entities.Query: ) -> pipeline_query.Query:
async with self.condition: async with self.condition:
query = entities.Query( query = pipeline_query.Query(
bot_uuid=bot_uuid, bot_uuid=bot_uuid,
query_id=self.query_id_counter, query_id=self.query_id_counter,
launcher_type=launcher_type, launcher_type=launcher_type,

View File

@@ -3,10 +3,10 @@ from __future__ import annotations
import datetime import datetime
from .. import stage, entities from .. import stage, entities
from ...core import entities as core_entities
from langbot_plugin.api.entities.builtin.provider import message as provider_message from langbot_plugin.api.entities.builtin.provider import message as provider_message
from ...plugin import events from ...plugin import events
from ...platform.types import message as platform_message from ...platform.types import message as platform_message
import langbot_plugin.api.entities.builtin.pipeline.query as pipeline_query
@stage.stage_class('PreProcessor') @stage.stage_class('PreProcessor')
@@ -26,7 +26,7 @@ class PreProcessor(stage.PipelineStage):
async def process( async def process(
self, self,
query: core_entities.Query, query: pipeline_query.Query,
stage_inst_name: str, stage_inst_name: str,
) -> entities.StageProcessResult: ) -> entities.StageProcessResult:
"""处理""" """处理"""

View File

@@ -3,8 +3,8 @@ from __future__ import annotations
import abc import abc
from ...core import app from ...core import app
from ...core import entities as core_entities
from .. import entities from .. import entities
import langbot_plugin.api.entities.builtin.pipeline.query as pipeline_query
class MessageHandler(metaclass=abc.ABCMeta): class MessageHandler(metaclass=abc.ABCMeta):
@@ -19,7 +19,7 @@ class MessageHandler(metaclass=abc.ABCMeta):
@abc.abstractmethod @abc.abstractmethod
async def handle( async def handle(
self, self,
query: core_entities.Query, query: pipeline_query.Query,
) -> entities.StageProcessResult: ) -> entities.StageProcessResult:
raise NotImplementedError raise NotImplementedError

View File

@@ -6,13 +6,15 @@ import traceback
from .. import handler from .. import handler
from ... import entities from ... import entities
from ....core import entities as core_entities
from ....provider import runner as runner_module from ....provider import runner as runner_module
from ....plugin import events from ....plugin import events
from ....platform.types import message as platform_message from ....platform.types import message as platform_message
from ....utils import importutil from ....utils import importutil
from ....provider import runners from ....provider import runners
import langbot_plugin.api.entities.builtin.provider.session as provider_session
import langbot_plugin.api.entities.builtin.pipeline.query as pipeline_query
importutil.import_modules_in_pkg(runners) importutil.import_modules_in_pkg(runners)
@@ -20,7 +22,7 @@ importutil.import_modules_in_pkg(runners)
class ChatMessageHandler(handler.MessageHandler): class ChatMessageHandler(handler.MessageHandler):
async def handle( async def handle(
self, self,
query: core_entities.Query, query: pipeline_query.Query,
) -> typing.AsyncGenerator[entities.StageProcessResult, None]: ) -> typing.AsyncGenerator[entities.StageProcessResult, None]:
"""处理""" """处理"""
# 调API # 调API
@@ -29,7 +31,7 @@ class ChatMessageHandler(handler.MessageHandler):
# 触发插件事件 # 触发插件事件
event_class = ( event_class = (
events.PersonNormalMessageReceived events.PersonNormalMessageReceived
if query.launcher_type == core_entities.LauncherTypes.PERSON if query.launcher_type == provider_session.LauncherTypes.PERSON
else events.GroupNormalMessageReceived else events.GroupNormalMessageReceived
) )

View File

@@ -4,16 +4,17 @@ import typing
from .. import handler from .. import handler
from ... import entities from ... import entities
from ....core import entities as core_entities
from langbot_plugin.api.entities.builtin.provider import message as provider_message from langbot_plugin.api.entities.builtin.provider import message as provider_message
from ....plugin import events from ....plugin import events
from ....platform.types import message as platform_message from ....platform.types import message as platform_message
import langbot_plugin.api.entities.builtin.provider.session as provider_session
import langbot_plugin.api.entities.builtin.pipeline.query as pipeline_query
class CommandHandler(handler.MessageHandler): class CommandHandler(handler.MessageHandler):
async def handle( async def handle(
self, self,
query: core_entities.Query, query: pipeline_query.Query,
) -> typing.AsyncGenerator[entities.StageProcessResult, None]: ) -> typing.AsyncGenerator[entities.StageProcessResult, None]:
"""处理""" """处理"""
@@ -28,7 +29,7 @@ class CommandHandler(handler.MessageHandler):
event_class = ( event_class = (
events.PersonCommandSent events.PersonCommandSent
if query.launcher_type == core_entities.LauncherTypes.PERSON if query.launcher_type == provider_session.LauncherTypes.PERSON
else events.GroupCommandSent else events.GroupCommandSent
) )

View File

@@ -1,10 +1,10 @@
from __future__ import annotations from __future__ import annotations
from ...core import entities as core_entities
from . import handler from . import handler
from .handlers import chat, command from .handlers import chat, command
from .. import entities from .. import entities
from .. import stage from .. import stage
import langbot_plugin.api.entities.builtin.pipeline.query as pipeline_query
@stage.stage_class('MessageProcessor') @stage.stage_class('MessageProcessor')
@@ -30,7 +30,7 @@ class Processor(stage.PipelineStage):
async def process( async def process(
self, self,
query: core_entities.Query, query: pipeline_query.Query,
stage_inst_name: str, stage_inst_name: str,
) -> entities.StageProcessResult: ) -> entities.StageProcessResult:
"""处理""" """处理"""

View File

@@ -2,7 +2,8 @@ from __future__ import annotations
import abc import abc
import typing import typing
from ...core import app, entities as core_entities from ...core import app
import langbot_plugin.api.entities.builtin.pipeline.query as pipeline_query
preregistered_algos: list[typing.Type[ReteLimitAlgo]] = [] preregistered_algos: list[typing.Type[ReteLimitAlgo]] = []
@@ -33,7 +34,7 @@ class ReteLimitAlgo(metaclass=abc.ABCMeta):
@abc.abstractmethod @abc.abstractmethod
async def require_access( async def require_access(
self, self,
query: core_entities.Query, query: pipeline_query.Query,
launcher_type: str, launcher_type: str,
launcher_id: typing.Union[int, str], launcher_id: typing.Union[int, str],
) -> bool: ) -> bool:
@@ -53,7 +54,7 @@ class ReteLimitAlgo(metaclass=abc.ABCMeta):
@abc.abstractmethod @abc.abstractmethod
async def release_access( async def release_access(
self, self,
query: core_entities.Query, query: pipeline_query.Query,
launcher_type: str, launcher_type: str,
launcher_id: typing.Union[int, str], launcher_id: typing.Union[int, str],
): ):

View File

@@ -3,7 +3,7 @@ import asyncio
import time import time
import typing import typing
from .. import algo from .. import algo
from ....core import entities as core_entities import langbot_plugin.api.entities.builtin.pipeline.query as pipeline_query
# 固定窗口算法 # 固定窗口算法
@@ -32,7 +32,7 @@ class FixedWindowAlgo(algo.ReteLimitAlgo):
async def require_access( async def require_access(
self, self,
query: core_entities.Query, query: pipeline_query.Query,
launcher_type: str, launcher_type: str,
launcher_id: typing.Union[int, str], launcher_id: typing.Union[int, str],
) -> bool: ) -> bool:
@@ -91,7 +91,7 @@ class FixedWindowAlgo(algo.ReteLimitAlgo):
async def release_access( async def release_access(
self, self,
query: core_entities.Query, query: pipeline_query.Query,
launcher_type: str, launcher_type: str,
launcher_id: typing.Union[int, str], launcher_id: typing.Union[int, str],
): ):

View File

@@ -4,9 +4,10 @@ import typing
from .. import entities, stage from .. import entities, stage
from . import algo from . import algo
from ...core import entities as core_entities
from ...utils import importutil from ...utils import importutil
import langbot_plugin.api.entities.builtin.pipeline.query as pipeline_query
from . import algos from . import algos
importutil.import_modules_in_pkg(algos) importutil.import_modules_in_pkg(algos)
@@ -39,7 +40,7 @@ class RateLimit(stage.PipelineStage):
async def process( async def process(
self, self,
query: core_entities.Query, query: pipeline_query.Query,
stage_inst_name: str, stage_inst_name: str,
) -> typing.Union[ ) -> typing.Union[
entities.StageProcessResult, entities.StageProcessResult,

View File

@@ -8,14 +8,14 @@ from ...platform.types import events as platform_events
from ...platform.types import message as platform_message from ...platform.types import message as platform_message
from .. import stage, entities from .. import stage, entities
from ...core import entities as core_entities import langbot_plugin.api.entities.builtin.pipeline.query as pipeline_query
@stage.stage_class('SendResponseBackStage') @stage.stage_class('SendResponseBackStage')
class SendResponseBackStage(stage.PipelineStage): class SendResponseBackStage(stage.PipelineStage):
"""发送响应消息""" """发送响应消息"""
async def process(self, query: core_entities.Query, stage_inst_name: str) -> entities.StageProcessResult: async def process(self, query: pipeline_query.Query, stage_inst_name: str) -> entities.StageProcessResult:
"""处理""" """处理"""
random_range = ( random_range = (

View File

@@ -1,4 +1,4 @@
import pydantic.v1 as pydantic import pydantic
from ...platform.types import message as platform_message from ...platform.types import message as platform_message

View File

@@ -4,9 +4,10 @@ from __future__ import annotations
from . import rule from . import rule
from .. import stage, entities from .. import stage, entities
from ...core import entities as core_entities
from ...utils import importutil from ...utils import importutil
import langbot_plugin.api.entities.builtin.pipeline.query as pipeline_query
from . import rules from . import rules
importutil.import_modules_in_pkg(rules) importutil.import_modules_in_pkg(rules)
@@ -32,7 +33,7 @@ class GroupRespondRuleCheckStage(stage.PipelineStage):
await rule_inst.initialize() await rule_inst.initialize()
self.rule_matchers.append(rule_inst) self.rule_matchers.append(rule_inst)
async def process(self, query: core_entities.Query, stage_inst_name: str) -> entities.StageProcessResult: async def process(self, query: pipeline_query.Query, stage_inst_name: str) -> entities.StageProcessResult:
if query.launcher_type.value != 'group': # 只处理群消息 if query.launcher_type.value != 'group': # 只处理群消息
return entities.StageProcessResult(result_type=entities.ResultType.CONTINUE, new_query=query) return entities.StageProcessResult(result_type=entities.ResultType.CONTINUE, new_query=query)

View File

@@ -2,10 +2,11 @@ from __future__ import annotations
import abc import abc
import typing import typing
from ...core import app, entities as core_entities from ...core import app
from . import entities from . import entities
from ...platform.types import message as platform_message from ...platform.types import message as platform_message
import langbot_plugin.api.entities.builtin.pipeline.query as pipeline_query
preregisetered_rules: list[typing.Type[GroupRespondRule]] = [] preregisetered_rules: list[typing.Type[GroupRespondRule]] = []
@@ -39,7 +40,7 @@ class GroupRespondRule(metaclass=abc.ABCMeta):
message_text: str, message_text: str,
message_chain: platform_message.MessageChain, message_chain: platform_message.MessageChain,
rule_dict: dict, rule_dict: dict,
query: core_entities.Query, query: pipeline_query.Query,
) -> entities.RuleJudgeResult: ) -> entities.RuleJudgeResult:
"""判断消息是否匹配规则""" """判断消息是否匹配规则"""
raise NotImplementedError raise NotImplementedError

View File

@@ -3,8 +3,8 @@ from __future__ import annotations
from .. import rule as rule_model from .. import rule as rule_model
from .. import entities from .. import entities
from ....core import entities as core_entities
from ....platform.types import message as platform_message from ....platform.types import message as platform_message
import langbot_plugin.api.entities.builtin.pipeline.query as pipeline_query
@rule_model.rule_class('at-bot') @rule_model.rule_class('at-bot')
@@ -14,7 +14,7 @@ class AtBotRule(rule_model.GroupRespondRule):
message_text: str, message_text: str,
message_chain: platform_message.MessageChain, message_chain: platform_message.MessageChain,
rule_dict: dict, rule_dict: dict,
query: core_entities.Query, query: pipeline_query.Query,
) -> entities.RuleJudgeResult: ) -> entities.RuleJudgeResult:
if message_chain.has(platform_message.At(query.adapter.bot_account_id)) and rule_dict['at']: if message_chain.has(platform_message.At(query.adapter.bot_account_id)) and rule_dict['at']:
message_chain.remove(platform_message.At(query.adapter.bot_account_id)) message_chain.remove(platform_message.At(query.adapter.bot_account_id))

View File

@@ -1,7 +1,7 @@
from .. import rule as rule_model from .. import rule as rule_model
from .. import entities from .. import entities
from ....core import entities as core_entities
from ....platform.types import message as platform_message from ....platform.types import message as platform_message
import langbot_plugin.api.entities.builtin.pipeline.query as pipeline_query
@rule_model.rule_class('prefix') @rule_model.rule_class('prefix')
@@ -11,7 +11,7 @@ class PrefixRule(rule_model.GroupRespondRule):
message_text: str, message_text: str,
message_chain: platform_message.MessageChain, message_chain: platform_message.MessageChain,
rule_dict: dict, rule_dict: dict,
query: core_entities.Query, query: pipeline_query.Query,
) -> entities.RuleJudgeResult: ) -> entities.RuleJudgeResult:
prefixes = rule_dict['prefix'] prefixes = rule_dict['prefix']

View File

@@ -3,8 +3,8 @@ import random
from .. import rule as rule_model from .. import rule as rule_model
from .. import entities from .. import entities
from ....core import entities as core_entities
from ....platform.types import message as platform_message from ....platform.types import message as platform_message
import langbot_plugin.api.entities.builtin.pipeline.query as pipeline_query
@rule_model.rule_class('random') @rule_model.rule_class('random')
@@ -14,7 +14,7 @@ class RandomRespRule(rule_model.GroupRespondRule):
message_text: str, message_text: str,
message_chain: platform_message.MessageChain, message_chain: platform_message.MessageChain,
rule_dict: dict, rule_dict: dict,
query: core_entities.Query, query: pipeline_query.Query,
) -> entities.RuleJudgeResult: ) -> entities.RuleJudgeResult:
random_rate = rule_dict['random'] random_rate = rule_dict['random']

View File

@@ -3,8 +3,8 @@ import re
from .. import rule as rule_model from .. import rule as rule_model
from .. import entities from .. import entities
from ....core import entities as core_entities
from ....platform.types import message as platform_message from ....platform.types import message as platform_message
import langbot_plugin.api.entities.builtin.pipeline.query as pipeline_query
@rule_model.rule_class('regexp') @rule_model.rule_class('regexp')
@@ -14,7 +14,7 @@ class RegExpRule(rule_model.GroupRespondRule):
message_text: str, message_text: str,
message_chain: platform_message.MessageChain, message_chain: platform_message.MessageChain,
rule_dict: dict, rule_dict: dict,
query: core_entities.Query, query: pipeline_query.Query,
) -> entities.RuleJudgeResult: ) -> entities.RuleJudgeResult:
regexps = rule_dict['regexp'] regexps = rule_dict['regexp']

View File

@@ -3,8 +3,9 @@ from __future__ import annotations
import abc import abc
import typing import typing
from ..core import app, entities as core_entities from ..core import app
from . import entities from . import entities
import langbot_plugin.api.entities.builtin.pipeline.query as pipeline_query
preregistered_stages: dict[str, type[PipelineStage]] = {} preregistered_stages: dict[str, type[PipelineStage]] = {}
@@ -33,7 +34,7 @@ class PipelineStage(metaclass=abc.ABCMeta):
@abc.abstractmethod @abc.abstractmethod
async def process( async def process(
self, self,
query: core_entities.Query, query: pipeline_query.Query,
stage_inst_name: str, stage_inst_name: str,
) -> typing.Union[ ) -> typing.Union[
entities.StageProcessResult, entities.StageProcessResult,

View File

@@ -2,12 +2,11 @@ from __future__ import annotations
import typing import typing
from ...core import entities as core_entities
from .. import entities from .. import entities
from .. import stage from .. import stage
from ...plugin import events from ...plugin import events
from ...platform.types import message as platform_message from ...platform.types import message as platform_message
import langbot_plugin.api.entities.builtin.pipeline.query as pipeline_query
@stage.stage_class('ResponseWrapper') @stage.stage_class('ResponseWrapper')
@@ -25,7 +24,7 @@ class ResponseWrapper(stage.PipelineStage):
async def process( async def process(
self, self,
query: core_entities.Query, query: pipeline_query.Query,
stage_inst_name: str, stage_inst_name: str,
) -> typing.AsyncGenerator[entities.StageProcessResult, None]: ) -> typing.AsyncGenerator[entities.StageProcessResult, None]:
"""处理""" """处理"""

View File

@@ -3,15 +3,14 @@ from __future__ import annotations
# MessageSource的适配器 # MessageSource的适配器
import typing import typing
import abc import abc
import pydantic
from ..core import app
from .types import message as platform_message from .types import message as platform_message
from .types import events as platform_events from .types import events as platform_events
from .logger import EventLogger from .logger import EventLogger
class MessagePlatformAdapter(metaclass=abc.ABCMeta): class MessagePlatformAdapter(pydantic.BaseModel, metaclass=abc.ABCMeta):
"""消息平台适配器基类""" """消息平台适配器基类"""
name: str name: str
@@ -21,11 +20,9 @@ class MessagePlatformAdapter(metaclass=abc.ABCMeta):
config: dict config: dict
ap: app.Application logger: EventLogger = pydantic.Field(exclude=True)
logger: EventLogger def __init__(self, config: dict, logger: EventLogger):
def __init__(self, config: dict, ap: app.Application, logger: EventLogger):
"""初始化适配器 """初始化适配器
Args: Args:
@@ -33,7 +30,6 @@ class MessagePlatformAdapter(metaclass=abc.ABCMeta):
ap (app.Application): 应用上下文 ap (app.Application): 应用上下文
""" """
self.config = config self.config = config
self.ap = ap
self.logger = logger self.logger = logger
async def send_message(self, target_type: str, target_id: str, message: platform_message.MessageChain): async def send_message(self, target_type: str, target_id: str, message: platform_message.MessageChain):

View File

@@ -19,6 +19,8 @@ from ..entity.errors import platform as platform_errors
from .logger import EventLogger from .logger import EventLogger
import langbot_plugin.api.entities.builtin.provider.session as provider_session
# 处理 3.4 移除了 YiriMirai 之后,插件的兼容性问题 # 处理 3.4 移除了 YiriMirai 之后,插件的兼容性问题
from . import types as mirai from . import types as mirai
@@ -73,7 +75,7 @@ class RuntimeBot:
await self.ap.query_pool.add_query( await self.ap.query_pool.add_query(
bot_uuid=self.bot_entity.uuid, bot_uuid=self.bot_entity.uuid,
launcher_type=core_entities.LauncherTypes.PERSON, launcher_type=provider_session.LauncherTypes.PERSON,
launcher_id=event.sender.id, launcher_id=event.sender.id,
sender_id=event.sender.id, sender_id=event.sender.id,
message_event=event, message_event=event,
@@ -98,7 +100,7 @@ class RuntimeBot:
await self.ap.query_pool.add_query( await self.ap.query_pool.add_query(
bot_uuid=self.bot_entity.uuid, bot_uuid=self.bot_entity.uuid,
launcher_type=core_entities.LauncherTypes.GROUP, launcher_type=provider_session.LauncherTypes.GROUP,
launcher_id=event.group.id, launcher_id=event.group.id,
sender_id=event.sender.id, sender_id=event.sender.id,
message_event=event, message_event=event,
@@ -172,9 +174,9 @@ class PlatformManager:
webchat_logger = EventLogger(name='webchat-adapter', ap=self.ap) webchat_logger = EventLogger(name='webchat-adapter', ap=self.ap)
webchat_adapter_inst = webchat_adapter_class( webchat_adapter_inst = webchat_adapter_class(
{}, {},
self.ap,
webchat_logger, webchat_logger,
) )
webchat_adapter_inst.ap = self.ap
self.webchat_proxy_bot = RuntimeBot( self.webchat_proxy_bot = RuntimeBot(
ap=self.ap, ap=self.ap,
@@ -231,7 +233,6 @@ class PlatformManager:
adapter_inst = self.adapter_dict[bot_entity.adapter]( adapter_inst = self.adapter_dict[bot_entity.adapter](
bot_entity.adapter_config, bot_entity.adapter_config,
self.ap,
logger, logger,
) )

View File

@@ -7,7 +7,6 @@ import datetime
import aiocqhttp import aiocqhttp
from .. import adapter from .. import adapter
from ...core import app
from ..types import message as platform_message from ..types import message as platform_message
from ..types import events as platform_events from ..types import events as platform_events
from ..types import entities as platform_entities from ..types import entities as platform_entities
@@ -273,11 +272,9 @@ class AiocqhttpAdapter(adapter.MessagePlatformAdapter):
config: dict config: dict
ap: app.Application
on_websocket_connection_event_cache: typing.List[typing.Callable[[aiocqhttp.Event], None]] = [] on_websocket_connection_event_cache: typing.List[typing.Callable[[aiocqhttp.Event], None]] = []
def __init__(self, config: dict, ap: app.Application, logger: EventLogger): def __init__(self, config: dict, logger: EventLogger):
self.config = config self.config = config
self.logger = logger self.logger = logger
@@ -287,7 +284,6 @@ class AiocqhttpAdapter(adapter.MessagePlatformAdapter):
self.config['shutdown_trigger'] = shutdown_trigger_placeholder self.config['shutdown_trigger'] = shutdown_trigger_placeholder
self.ap = ap
self.on_websocket_connection_event_cache = [] self.on_websocket_connection_event_cache = []
if 'access-token' in config: if 'access-token' in config:

View File

@@ -4,7 +4,6 @@ from libs.dingtalk_api.dingtalkevent import DingTalkEvent
from pkg.platform.types import message as platform_message from pkg.platform.types import message as platform_message
from pkg.platform.adapter import MessagePlatformAdapter from pkg.platform.adapter import MessagePlatformAdapter
from .. import adapter from .. import adapter
from ...core import app
from ..types import events as platform_events from ..types import events as platform_events
from ..types import entities as platform_entities from ..types import entities as platform_entities
from libs.dingtalk_api.api import DingTalkClient from libs.dingtalk_api.api import DingTalkClient
@@ -94,15 +93,13 @@ class DingTalkEventConverter(adapter.EventConverter):
class DingTalkAdapter(adapter.MessagePlatformAdapter): class DingTalkAdapter(adapter.MessagePlatformAdapter):
bot: DingTalkClient bot: DingTalkClient
ap: app.Application
bot_account_id: str bot_account_id: str
message_converter: DingTalkMessageConverter = DingTalkMessageConverter() message_converter: DingTalkMessageConverter = DingTalkMessageConverter()
event_converter: DingTalkEventConverter = DingTalkEventConverter() event_converter: DingTalkEventConverter = DingTalkEventConverter()
config: dict config: dict
def __init__(self, config: dict, ap: app.Application, logger: EventLogger): def __init__(self, config: dict, logger: EventLogger):
self.config = config self.config = config
self.ap = ap
self.logger = logger self.logger = logger
required_keys = [ required_keys = [
'client_id', 'client_id',

View File

@@ -12,7 +12,6 @@ import datetime
import aiohttp import aiohttp
from .. import adapter from .. import adapter
from ...core import app
from ..types import message as platform_message from ..types import message as platform_message
from ..types import events as platform_events from ..types import events as platform_events
from ..types import entities as platform_entities from ..types import entities as platform_entities
@@ -161,8 +160,6 @@ class DiscordAdapter(adapter.MessagePlatformAdapter):
config: dict config: dict
ap: app.Application
message_converter: DiscordMessageConverter = DiscordMessageConverter() message_converter: DiscordMessageConverter = DiscordMessageConverter()
event_converter: DiscordEventConverter = DiscordEventConverter() event_converter: DiscordEventConverter = DiscordEventConverter()
@@ -171,9 +168,8 @@ class DiscordAdapter(adapter.MessagePlatformAdapter):
typing.Callable[[platform_events.Event, adapter.MessagePlatformAdapter], None], typing.Callable[[platform_events.Event, adapter.MessagePlatformAdapter], None],
] = {} ] = {}
def __init__(self, config: dict, ap: app.Application, logger: EventLogger): def __init__(self, config: dict, logger: EventLogger):
self.config = config self.config = config
self.ap = ap
self.logger = logger self.logger = logger
self.bot_account_id = self.config['client_id'] self.bot_account_id = self.config['client_id']

View File

@@ -19,7 +19,6 @@ import quart
from lark_oapi.api.im.v1 import * from lark_oapi.api.im.v1 import *
from .. import adapter from .. import adapter
from ...core import app
from ..types import message as platform_message from ..types import message as platform_message
from ..types import events as platform_events from ..types import events as platform_events
from ..types import entities as platform_entities from ..types import entities as platform_entities
@@ -337,11 +336,9 @@ class LarkAdapter(adapter.MessagePlatformAdapter):
config: dict config: dict
quart_app: quart.Quart quart_app: quart.Quart
ap: app.Application
def __init__(self, config: dict, ap: app.Application, logger: EventLogger): def __init__(self, config: dict, logger: EventLogger):
self.config = config self.config = config
self.ap = ap
self.logger = logger self.logger = logger
self.quart_app = quart.Quart(__name__) self.quart_app = quart.Quart(__name__)
self.listeners = {} self.listeners = {}
@@ -351,8 +348,6 @@ class LarkAdapter(adapter.MessagePlatformAdapter):
try: try:
data = await quart.request.json data = await quart.request.json
self.ap.logger.debug(f'Lark callback event: {data}')
if 'encrypt' in data: if 'encrypt' in data:
cipher = AESCipher(self.config['encrypt-key']) cipher = AESCipher(self.config['encrypt-key'])
data = cipher.decrypt_string(data['encrypt']) data = cipher.decrypt_string(data['encrypt'])

View File

Before

Width:  |  Height:  |  Size: 25 KiB

After

Width:  |  Height:  |  Size: 25 KiB

View File

@@ -11,16 +11,16 @@ import threading
import quart import quart
import aiohttp import aiohttp
from .. import adapter from ... import adapter
from ...core import app from ....core import app
from ..types import message as platform_message from ...types import message as platform_message
from ..types import events as platform_events from ...types import events as platform_events
from ..types import entities as platform_entities from ...types import entities as platform_entities
from ...utils import image from ....utils import image
import xml.etree.ElementTree as ET import xml.etree.ElementTree as ET
from typing import Optional, Tuple from typing import Optional, Tuple
from functools import partial from functools import partial
from ..logger import EventLogger from ...logger import EventLogger
class GewechatMessageConverter(adapter.MessageConverter): class GewechatMessageConverter(adapter.MessageConverter):
@@ -491,7 +491,7 @@ class GeWeChatAdapter(adapter.MessagePlatformAdapter):
async def gewechat_callback(): async def gewechat_callback():
data = await quart.request.json data = await quart.request.json
# print(json.dumps(data, indent=4, ensure_ascii=False)) # print(json.dumps(data, indent=4, ensure_ascii=False))
self.ap.logger.debug(f'Gewechat callback event: {data}') await self.logger.debug(f'Gewechat callback event: {data}')
if 'data' in data: if 'data' in data:
data['Data'] = data['data'] data['Data'] = data['data']
@@ -601,7 +601,7 @@ class GeWeChatAdapter(adapter.MessagePlatformAdapter):
if handler := handler_map.get(msg['type']): if handler := handler_map.get(msg['type']):
handler(msg) handler(msg)
else: else:
self.ap.logger.warning(f'未处理的消息类型: {msg["type"]}') await self.logger.warning(f'未处理的消息类型: {msg["type"]}')
continue continue
async def send_message(self, target_type: str, target_id: str, message: platform_message.MessageChain): async def send_message(self, target_type: str, target_id: str, message: platform_message.MessageChain):
@@ -656,9 +656,7 @@ class GeWeChatAdapter(adapter.MessagePlatformAdapter):
self.config['app_id'] = app_id self.config['app_id'] = app_id
self.ap.logger.info(f'Gewechat 登录成功app_id: {app_id}') print(f'Gewechat 登录成功app_id: {app_id}')
self.ap.platform_mgr.write_back_config('gewechat', self, self.config)
# 获取 nickname # 获取 nickname
profile = self.bot.get_profile(self.config['app_id']) profile = self.bot.get_profile(self.config['app_id'])

View File

Before

Width:  |  Height:  |  Size: 274 KiB

After

Width:  |  Height:  |  Size: 274 KiB

View File

@@ -9,12 +9,12 @@ import traceback
import nakuru import nakuru
import nakuru.entities.components as nkc import nakuru.entities.components as nkc
from .. import adapter as adapter_model from ... import adapter as adapter_model
from ...pipeline.longtext.strategies import forward from ....pipeline.longtext.strategies import forward
from ...platform.types import message as platform_message from ...types import message as platform_message
from ...platform.types import entities as platform_entities from ...types import entities as platform_entities
from ...platform.types import events as platform_events from ...types import events as platform_events
from ..logger import EventLogger from ...logger import EventLogger
class NakuruProjectMessageConverter(adapter_model.MessageConverter): class NakuruProjectMessageConverter(adapter_model.MessageConverter):
@@ -262,7 +262,7 @@ class NakuruAdapter(adapter_model.MessagePlatformAdapter):
source_cls = NakuruProjectEventConverter.yiri2target(event_type) source_cls = NakuruProjectEventConverter.yiri2target(event_type)
# 包装函数 # 包装函数
async def listener_wrapper(app: nakuru.CQHTTP, source: source_cls): async def listener_wrapper(app: nakuru.CQHTTP, source: source_cls): # type: ignore
await callback(self.event_converter.target2yiri(source), self) await callback(self.event_converter.target2yiri(source), self)
# 将包装函数和原函数的对应关系存入列表 # 将包装函数和原函数的对应关系存入列表
@@ -322,7 +322,6 @@ class NakuruAdapter(adapter_model.MessagePlatformAdapter):
except Exception: except Exception:
raise Exception('获取go-cqhttp账号信息失败, 请检查是否已启动go-cqhttp并配置正确') raise Exception('获取go-cqhttp账号信息失败, 请检查是否已启动go-cqhttp并配置正确')
await self.bot._run() await self.bot._run()
self.ap.logger.info('运行 Nakuru 适配器')
while True: while True:
await asyncio.sleep(1) await asyncio.sleep(1)

View File

@@ -10,14 +10,14 @@ import botpy
import botpy.message as botpy_message import botpy.message as botpy_message
import botpy.types.message as botpy_message_type import botpy.types.message as botpy_message_type
from .. import adapter as adapter_model from ... import adapter as adapter_model
from ...pipeline.longtext.strategies import forward from ....pipeline.longtext.strategies import forward
from ...core import app from ....core import app
from ...config import manager as cfg_mgr from ....config import manager as cfg_mgr
from ...platform.types import entities as platform_entities from ...types import entities as platform_entities
from ...platform.types import events as platform_events from ...types import events as platform_events
from ...platform.types import message as platform_message from ...types import message as platform_message
from ..logger import EventLogger from ...logger import EventLogger
class OfficialGroupMessage(platform_events.GroupMessage): class OfficialGroupMessage(platform_events.GroupMessage):
@@ -519,7 +519,7 @@ class OfficialAdapter(adapter_model.MessagePlatformAdapter):
self.cfg['ret_coro'] = True self.cfg['ret_coro'] = True
self.ap.logger.info('运行 QQ 官方适配器') await self.logger.info('运行 QQ 官方适配器')
await (await self.bot.start(**self.cfg)) await (await self.bot.start(**self.cfg))
async def kill(self) -> bool: async def kill(self) -> bool:

View File

Before

Width:  |  Height:  |  Size: 1.7 KiB

After

Width:  |  Height:  |  Size: 1.7 KiB

View File

@@ -10,7 +10,6 @@ from libs.official_account_api.oaevent import OAEvent
from libs.official_account_api.api import OAClient from libs.official_account_api.api import OAClient
from libs.official_account_api.api import OAClientForLongerResponse from libs.official_account_api.api import OAClientForLongerResponse
from .. import adapter from .. import adapter
from ...core import app
from ..types import entities as platform_entities from ..types import entities as platform_entities
from ...command.errors import ParamNotEnoughError from ...command.errors import ParamNotEnoughError
from ..logger import EventLogger from ..logger import EventLogger
@@ -58,15 +57,13 @@ class OAEventConverter(adapter.EventConverter):
class OfficialAccountAdapter(adapter.MessagePlatformAdapter): class OfficialAccountAdapter(adapter.MessagePlatformAdapter):
bot: OAClient | OAClientForLongerResponse bot: OAClient | OAClientForLongerResponse
ap: app.Application
bot_account_id: str bot_account_id: str
message_converter: OAMessageConverter = OAMessageConverter() message_converter: OAMessageConverter = OAMessageConverter()
event_converter: OAEventConverter = OAEventConverter() event_converter: OAEventConverter = OAEventConverter()
config: dict config: dict
def __init__(self, config: dict, ap: app.Application, logger: EventLogger): def __init__(self, config: dict, logger: EventLogger):
self.config = config self.config = config
self.ap = ap
self.logger = logger self.logger = logger
required_keys = [ required_keys = [

View File

@@ -8,7 +8,6 @@ import datetime
from pkg.platform.adapter import MessagePlatformAdapter from pkg.platform.adapter import MessagePlatformAdapter
from pkg.platform.types import events as platform_events, message as platform_message from pkg.platform.types import events as platform_events, message as platform_message
from .. import adapter from .. import adapter
from ...core import app
from ..types import entities as platform_entities from ..types import entities as platform_entities
from ...command.errors import ParamNotEnoughError from ...command.errors import ParamNotEnoughError
from libs.qq_official_api.api import QQOfficialClient from libs.qq_official_api.api import QQOfficialClient
@@ -134,15 +133,13 @@ class QQOfficialEventConverter(adapter.EventConverter):
class QQOfficialAdapter(adapter.MessagePlatformAdapter): class QQOfficialAdapter(adapter.MessagePlatformAdapter):
bot: QQOfficialClient bot: QQOfficialClient
ap: app.Application
config: dict config: dict
bot_account_id: str bot_account_id: str
message_converter: QQOfficialMessageConverter = QQOfficialMessageConverter() message_converter: QQOfficialMessageConverter = QQOfficialMessageConverter()
event_converter: QQOfficialEventConverter = QQOfficialEventConverter() event_converter: QQOfficialEventConverter = QQOfficialEventConverter()
def __init__(self, config: dict, ap: app.Application, logger: EventLogger): def __init__(self, config: dict, logger: EventLogger):
self.config = config self.config = config
self.ap = ap
self.logger = logger self.logger = logger
required_keys = [ required_keys = [

View File

@@ -9,7 +9,6 @@ from libs.slack_api.api import SlackClient
from pkg.platform.adapter import MessagePlatformAdapter from pkg.platform.adapter import MessagePlatformAdapter
from pkg.platform.types import events as platform_events, message as platform_message from pkg.platform.types import events as platform_events, message as platform_message
from libs.slack_api.slackevent import SlackEvent from libs.slack_api.slackevent import SlackEvent
from pkg.core import app
from .. import adapter from .. import adapter
from ..types import entities as platform_entities from ..types import entities as platform_entities
from ...command.errors import ParamNotEnoughError from ...command.errors import ParamNotEnoughError
@@ -86,15 +85,13 @@ class SlackEventConverter(adapter.EventConverter):
class SlackAdapter(adapter.MessagePlatformAdapter): class SlackAdapter(adapter.MessagePlatformAdapter):
bot: SlackClient bot: SlackClient
ap: app.Application
bot_account_id: str bot_account_id: str
message_converter: SlackMessageConverter = SlackMessageConverter() message_converter: SlackMessageConverter = SlackMessageConverter()
event_converter: SlackEventConverter = SlackEventConverter() event_converter: SlackEventConverter = SlackEventConverter()
config: dict config: dict
def __init__(self, config: dict, ap: app.Application, logger: EventLogger): def __init__(self, config: dict, logger: EventLogger):
self.config = config self.config = config
self.ap = ap
self.logger = logger self.logger = logger
required_keys = [ required_keys = [
'bot_token', 'bot_token',

View File

@@ -10,10 +10,7 @@ import traceback
import base64 import base64
import aiohttp import aiohttp
from lark_oapi.api.im.v1 import *
from .. import adapter from .. import adapter
from ...core import app
from ..types import message as platform_message from ..types import message as platform_message
from ..types import events as platform_events from ..types import events as platform_events
from ..types import entities as platform_entities from ..types import entities as platform_entities
@@ -141,16 +138,14 @@ class TelegramAdapter(adapter.MessagePlatformAdapter):
event_converter: TelegramEventConverter = TelegramEventConverter() event_converter: TelegramEventConverter = TelegramEventConverter()
config: dict config: dict
ap: app.Application
listeners: typing.Dict[ listeners: typing.Dict[
typing.Type[platform_events.Event], typing.Type[platform_events.Event],
typing.Callable[[platform_events.Event, adapter.MessagePlatformAdapter], None], typing.Callable[[platform_events.Event, adapter.MessagePlatformAdapter], None],
] = {} ] = {}
def __init__(self, config: dict, ap: app.Application, logger: EventLogger): def __init__(self, config: dict, logger: EventLogger):
self.config = config self.config = config
self.ap = ap
self.logger = logger self.logger = logger
async def telegram_callback(update: Update, context: ContextTypes.DEFAULT_TYPE): async def telegram_callback(update: Update, context: ContextTypes.DEFAULT_TYPE):

View File

@@ -44,13 +44,14 @@ class WebChatAdapter(msadapter.MessagePlatformAdapter):
webchat_person_session: WebChatSession webchat_person_session: WebChatSession
webchat_group_session: WebChatSession webchat_group_session: WebChatSession
ap: app.Application # set by bot manager
listeners: typing.Dict[ listeners: typing.Dict[
typing.Type[platform_events.Event], typing.Type[platform_events.Event],
typing.Callable[[platform_events.Event, msadapter.MessagePlatformAdapter], None], typing.Callable[[platform_events.Event, msadapter.MessagePlatformAdapter], None],
] = {} ] = {}
def __init__(self, config: dict, ap: app.Application, logger: EventLogger): def __init__(self, config: dict, logger: EventLogger):
self.ap = ap
self.logger = logger self.logger = logger
self.config = config self.config = config

View File

@@ -488,6 +488,8 @@ class WeChatPadAdapter(adapter.MessagePlatformAdapter):
ap: app.Application ap: app.Application
logger: EventLogger
message_converter: WeChatPadMessageConverter message_converter: WeChatPadMessageConverter
event_converter: WeChatPadEventConverter event_converter: WeChatPadEventConverter
@@ -507,8 +509,6 @@ class WeChatPadAdapter(adapter.MessagePlatformAdapter):
async def ws_message(self, data): async def ws_message(self, data):
"""处理接收到的消息""" """处理接收到的消息"""
# self.ap.logger.debug(f"Gewechat callback event: {data}")
# print(data)
try: try:
event = await self.event_converter.target2yiri(data.copy(), self.bot_account_id) event = await self.event_converter.target2yiri(data.copy(), self.bot_account_id)
@@ -571,9 +571,8 @@ class WeChatPadAdapter(adapter.MessagePlatformAdapter):
if handler := handler_map.get(msg['type']): if handler := handler_map.get(msg['type']):
handler(msg) handler(msg)
# self.ap.logger.warning(f"未处理的消息类型: {ret}")
else: else:
self.ap.logger.warning(f'未处理的消息类型: {msg["type"]}') print(f'未处理的消息类型: {msg["type"]}')
continue continue
async def send_message(self, target_type: str, target_id: str, message: platform_message.MessageChain): async def send_message(self, target_type: str, target_id: str, message: platform_message.MessageChain):
@@ -615,7 +614,6 @@ class WeChatPadAdapter(adapter.MessagePlatformAdapter):
if self.config['token']: if self.config['token']:
self.bot = WeChatPadClient(self.config['wechatpad_url'], self.config['token']) self.bot = WeChatPadClient(self.config['wechatpad_url'], self.config['token'])
data = self.bot.get_login_status() data = self.bot.get_login_status()
self.ap.logger.info(data)
if data['Code'] == 300 and data['Text'] == '你已退出微信': if data['Code'] == 300 and data['Text'] == '你已退出微信':
response = requests.post( response = requests.post(
f'{self.config["wechatpad_url"]}/admin/GenAuthKey1?key={self.config["admin_key"]}', f'{self.config["wechatpad_url"]}/admin/GenAuthKey1?key={self.config["admin_key"]}',
@@ -635,7 +633,7 @@ class WeChatPadAdapter(adapter.MessagePlatformAdapter):
self.config['token'] = response.json()['Data'][0] self.config['token'] = response.json()['Data'][0]
self.bot = WeChatPadClient(self.config['wechatpad_url'], self.config['token'], logger=self.logger) self.bot = WeChatPadClient(self.config['wechatpad_url'], self.config['token'], logger=self.logger)
self.ap.logger.info(self.config['token']) await self.logger.info(self.config['token'])
thread_1 = threading.Event() thread_1 = threading.Event()
def wechat_login_process(): def wechat_login_process():
@@ -643,10 +641,9 @@ class WeChatPadAdapter(adapter.MessagePlatformAdapter):
# login_data =self.bot.get_login_qr() # login_data =self.bot.get_login_qr()
# url = login_data['Data']["QrCodeUrl"] # url = login_data['Data']["QrCodeUrl"]
# self.ap.logger.info(login_data)
profile = self.bot.get_profile() profile = self.bot.get_profile()
self.ap.logger.info(profile) self.logger.info(profile)
self.bot_account_id = profile['Data']['userInfo']['nickName']['str'] self.bot_account_id = profile['Data']['userInfo']['nickName']['str']
self.config['wxid'] = profile['Data']['userInfo']['userName']['str'] self.config['wxid'] = profile['Data']['userInfo']['userName']['str']
@@ -658,27 +655,26 @@ class WeChatPadAdapter(adapter.MessagePlatformAdapter):
def connect_websocket_sync() -> None: def connect_websocket_sync() -> None:
thread_1.wait() thread_1.wait()
uri = f'{self.config["wechatpad_ws"]}/GetSyncMsg?key={self.config["token"]}' uri = f'{self.config["wechatpad_ws"]}/GetSyncMsg?key={self.config["token"]}'
self.ap.logger.info(f'Connecting to WebSocket: {uri}') print(f'Connecting to WebSocket: {uri}')
def on_message(ws, message): def on_message(ws, message):
try: try:
data = json.loads(message) data = json.loads(message)
self.ap.logger.debug(f'Received message: {data}')
# 这里需要确保ws_message是同步的或者使用asyncio.run调用异步方法 # 这里需要确保ws_message是同步的或者使用asyncio.run调用异步方法
asyncio.run(self.ws_message(data)) asyncio.run(self.ws_message(data))
except json.JSONDecodeError: except json.JSONDecodeError:
self.ap.logger.error(f'Non-JSON message: {message[:100]}...') print(f'Non-JSON message: {message[:100]}...')
def on_error(ws, error): def on_error(ws, error):
self.ap.logger.error(f'WebSocket error: {str(error)[:200]}') print(f'WebSocket error: {str(error)[:200]}')
def on_close(ws, close_status_code, close_msg): def on_close(ws, close_status_code, close_msg):
self.ap.logger.info('WebSocket closed, reconnecting...') print('WebSocket closed, reconnecting...')
time.sleep(5) time.sleep(5)
connect_websocket_sync() # 自动重连 connect_websocket_sync() # 自动重连
def on_open(ws): def on_open(ws):
self.ap.logger.info('WebSocket connected successfully!') print('WebSocket connected successfully!')
ws = websocket.WebSocketApp( ws = websocket.WebSocketApp(
uri, on_message=on_message, on_error=on_error, on_close=on_close, on_open=on_open uri, on_message=on_message, on_error=on_error, on_close=on_close, on_open=on_open
@@ -689,10 +685,9 @@ class WeChatPadAdapter(adapter.MessagePlatformAdapter):
# connect_websocket_sync() # connect_websocket_sync()
# 这行代码会在WebSocket连接断开后才会执行 # 这行代码会在WebSocket连接断开后才会执行
# self.ap.logger.info("WebSocket client thread started")
thread = threading.Thread(target=connect_websocket_sync, name='WebSocketClientThread', daemon=True) thread = threading.Thread(target=connect_websocket_sync, name='WebSocketClientThread', daemon=True)
thread.start() thread.start()
self.ap.logger.info('WebSocket client thread started') self.logger.info('WebSocket client thread started')
async def kill(self) -> bool: async def kill(self) -> bool:
pass pass

View File

@@ -10,7 +10,6 @@ from pkg.platform.adapter import MessagePlatformAdapter
from pkg.platform.types import events as platform_events, message as platform_message from pkg.platform.types import events as platform_events, message as platform_message
from libs.wecom_api.wecomevent import WecomEvent from libs.wecom_api.wecomevent import WecomEvent
from .. import adapter from .. import adapter
from ...core import app
from ..types import entities as platform_entities from ..types import entities as platform_entities
from ...command.errors import ParamNotEnoughError from ...command.errors import ParamNotEnoughError
from ...utils import image from ...utils import image
@@ -129,15 +128,13 @@ class WecomEventConverter:
class WecomAdapter(adapter.MessagePlatformAdapter): class WecomAdapter(adapter.MessagePlatformAdapter):
bot: WecomClient bot: WecomClient
ap: app.Application
bot_account_id: str bot_account_id: str
message_converter: WecomMessageConverter = WecomMessageConverter() message_converter: WecomMessageConverter = WecomMessageConverter()
event_converter: WecomEventConverter = WecomEventConverter() event_converter: WecomEventConverter = WecomEventConverter()
config: dict config: dict
def __init__(self, config: dict, ap: app.Application, logger: EventLogger): def __init__(self, config: dict, logger: EventLogger):
self.config = config self.config = config
self.ap = ap
self.logger = logger self.logger = logger
required_keys = [ required_keys = [

View File

@@ -9,7 +9,6 @@ from libs.wecom_customer_service_api.api import WecomCSClient
from pkg.platform.adapter import MessagePlatformAdapter from pkg.platform.adapter import MessagePlatformAdapter
from pkg.platform.types import events as platform_events, message as platform_message from pkg.platform.types import events as platform_events, message as platform_message
from libs.wecom_customer_service_api.wecomcsevent import WecomCSEvent from libs.wecom_customer_service_api.wecomcsevent import WecomCSEvent
from pkg.core import app
from .. import adapter from .. import adapter
from ..types import entities as platform_entities from ..types import entities as platform_entities
from ...command.errors import ParamNotEnoughError from ...command.errors import ParamNotEnoughError
@@ -119,15 +118,13 @@ class WecomEventConverter:
class WecomCSAdapter(adapter.MessagePlatformAdapter): class WecomCSAdapter(adapter.MessagePlatformAdapter):
bot: WecomCSClient bot: WecomCSClient
ap: app.Application
bot_account_id: str bot_account_id: str
message_converter: WecomMessageConverter = WecomMessageConverter() message_converter: WecomMessageConverter = WecomMessageConverter()
event_converter: WecomEventConverter = WecomEventConverter() event_converter: WecomEventConverter = WecomEventConverter()
config: dict config: dict
def __init__(self, config: dict, ap: app.Application, logger: EventLogger): def __init__(self, config: dict, logger: EventLogger):
self.config = config self.config = config
self.ap = ap
self.logger = logger self.logger = logger
required_keys = [ required_keys = [

View File

@@ -4,16 +4,16 @@ import typing
import pydantic.v1 as pydantic import pydantic.v1 as pydantic
from ..core import entities as core_entities
from ..provider import entities as llm_entities from ..provider import entities as llm_entities
from ..platform.types import message as platform_message from ..platform.types import message as platform_message
import langbot_plugin.api.entities.builtin.provider.session as provider_session import langbot_plugin.api.entities.builtin.provider.session as provider_session
import langbot_plugin.api.entities.builtin.pipeline.query as pipeline_query
class BaseEventModel(pydantic.BaseModel): class BaseEventModel(pydantic.BaseModel):
"""事件模型基类""" """事件模型基类"""
query: typing.Union[core_entities.Query, None] query: typing.Union[pipeline_query.Query, None]
"""此次请求的query对象非请求过程的事件时为None""" """此次请求的query对象非请求过程的事件时为None"""
class Config: class Config:

View File

@@ -6,10 +6,10 @@ import importlib
import traceback import traceback
from .. import loader, events, context, models from .. import loader, events, context, models
from ...core import entities as core_entities
from langbot_plugin.api.entities.builtin.resource import tool as resource_tool from langbot_plugin.api.entities.builtin.resource import tool as resource_tool
from ...utils import funcschema from ...utils import funcschema
from ...discover import engine as discover_engine from ...discover import engine as discover_engine
import langbot_plugin.api.entities.builtin.pipeline.query as pipeline_query
class PluginLoader(loader.PluginLoader): class PluginLoader(loader.PluginLoader):
@@ -98,7 +98,7 @@ class PluginLoader(loader.PluginLoader):
function_schema = funcschema.get_func_schema(func) function_schema = funcschema.get_func_schema(func)
function_name = self._current_container.plugin_name + '-' + (func.__name__ if name is None else name) function_name = self._current_container.plugin_name + '-' + (func.__name__ if name is None else name)
async def handler(plugin: context.BasePlugin, query: core_entities.Query, *args, **kwargs): async def handler(plugin: context.BasePlugin, query: pipeline_query.Query, *args, **kwargs):
return func(*args, **kwargs) return func(*args, **kwargs)
llm_function = resource_tool.LLMTool( llm_function = resource_tool.LLMTool(

View File

@@ -1,7 +1,7 @@
from __future__ import annotations from __future__ import annotations
import typing import typing
import pydantic.v1 as pydantic import pydantic
from pkg.provider import entities from pkg.provider import entities

View File

@@ -2,7 +2,7 @@ from __future__ import annotations
import typing import typing
import pydantic.v1 as pydantic import pydantic
from . import requester from . import requester
from . import token from . import token

View File

@@ -4,11 +4,11 @@ import abc
import typing import typing
from ...core import app from ...core import app
from ...core import entities as core_entities
from .. import entities as llm_entities from .. import entities as llm_entities
from ...entity.persistence import model as persistence_model from ...entity.persistence import model as persistence_model
import langbot_plugin.api.entities.builtin.resource.tool as resource_tool import langbot_plugin.api.entities.builtin.resource.tool as resource_tool
from . import token from . import token
import langbot_plugin.api.entities.builtin.pipeline.query as pipeline_query
class RuntimeLLMModel: class RuntimeLLMModel:
@@ -56,7 +56,7 @@ class LLMAPIRequester(metaclass=abc.ABCMeta):
@abc.abstractmethod @abc.abstractmethod
async def invoke_llm( async def invoke_llm(
self, self,
query: core_entities.Query, query: pipeline_query.Query,
model: RuntimeLLMModel, model: RuntimeLLMModel,
messages: typing.List[llm_entities.Message], messages: typing.List[llm_entities.Message],
funcs: typing.List[resource_tool.LLMTool] = None, funcs: typing.List[resource_tool.LLMTool] = None,

View File

@@ -9,10 +9,10 @@ import httpx
from .. import errors, requester from .. import errors, requester
from ....core import entities as core_entities
from ... import entities as llm_entities from ... import entities as llm_entities
from ....utils import image from ....utils import image
import langbot_plugin.api.entities.builtin.resource.tool as resource_tool import langbot_plugin.api.entities.builtin.resource.tool as resource_tool
import langbot_plugin.api.entities.builtin.pipeline.query as pipeline_query
class AnthropicMessages(requester.LLMAPIRequester): class AnthropicMessages(requester.LLMAPIRequester):
@@ -48,7 +48,7 @@ class AnthropicMessages(requester.LLMAPIRequester):
async def invoke_llm( async def invoke_llm(
self, self,
query: core_entities.Query, query: pipeline_query.Query,
model: requester.RuntimeLLMModel, model: requester.RuntimeLLMModel,
messages: typing.List[llm_entities.Message], messages: typing.List[llm_entities.Message],
funcs: typing.List[resource_tool.LLMTool] = None, funcs: typing.List[resource_tool.LLMTool] = None,

View File

@@ -8,9 +8,9 @@ import openai.types.chat.chat_completion as chat_completion
import httpx import httpx
from .. import errors, requester from .. import errors, requester
from ....core import entities as core_entities
from ... import entities as llm_entities from ... import entities as llm_entities
import langbot_plugin.api.entities.builtin.resource.tool as resource_tool import langbot_plugin.api.entities.builtin.resource.tool as resource_tool
import langbot_plugin.api.entities.builtin.pipeline.query as pipeline_query
class OpenAIChatCompletions(requester.LLMAPIRequester): class OpenAIChatCompletions(requester.LLMAPIRequester):
@@ -60,7 +60,7 @@ class OpenAIChatCompletions(requester.LLMAPIRequester):
async def _closure( async def _closure(
self, self,
query: core_entities.Query, query: pipeline_query.Query,
req_messages: list[dict], req_messages: list[dict],
use_model: requester.RuntimeLLMModel, use_model: requester.RuntimeLLMModel,
use_funcs: list[resource_tool.LLMTool] = None, use_funcs: list[resource_tool.LLMTool] = None,
@@ -101,7 +101,7 @@ class OpenAIChatCompletions(requester.LLMAPIRequester):
async def invoke_llm( async def invoke_llm(
self, self,
query: core_entities.Query, query: pipeline_query.Query,
model: requester.RuntimeLLMModel, model: requester.RuntimeLLMModel,
messages: typing.List[llm_entities.Message], messages: typing.List[llm_entities.Message],
funcs: typing.List[resource_tool.LLMTool] = None, funcs: typing.List[resource_tool.LLMTool] = None,

View File

@@ -4,9 +4,9 @@ import typing
from . import chatcmpl from . import chatcmpl
from .. import errors, requester from .. import errors, requester
from ....core import entities as core_entities
from ... import entities as llm_entities from ... import entities as llm_entities
import langbot_plugin.api.entities.builtin.resource.tool as resource_tool import langbot_plugin.api.entities.builtin.resource.tool as resource_tool
import langbot_plugin.api.entities.builtin.pipeline.query as pipeline_query
class DeepseekChatCompletions(chatcmpl.OpenAIChatCompletions): class DeepseekChatCompletions(chatcmpl.OpenAIChatCompletions):
@@ -19,7 +19,7 @@ class DeepseekChatCompletions(chatcmpl.OpenAIChatCompletions):
async def _closure( async def _closure(
self, self,
query: core_entities.Query, query: pipeline_query.Query,
req_messages: list[dict], req_messages: list[dict],
use_model: requester.RuntimeLLMModel, use_model: requester.RuntimeLLMModel,
use_funcs: list[resource_tool.LLMTool] = None, use_funcs: list[resource_tool.LLMTool] = None,

View File

@@ -5,9 +5,9 @@ import typing
from . import chatcmpl from . import chatcmpl
from .. import requester from .. import requester
from ....core import entities as core_entities
from ... import entities as llm_entities from ... import entities as llm_entities
import langbot_plugin.api.entities.builtin.resource.tool as resource_tool import langbot_plugin.api.entities.builtin.resource.tool as resource_tool
import langbot_plugin.api.entities.builtin.pipeline.query as pipeline_query
class GiteeAIChatCompletions(chatcmpl.OpenAIChatCompletions): class GiteeAIChatCompletions(chatcmpl.OpenAIChatCompletions):
@@ -20,7 +20,7 @@ class GiteeAIChatCompletions(chatcmpl.OpenAIChatCompletions):
async def _closure( async def _closure(
self, self,
query: core_entities.Query, query: pipeline_query.Query,
req_messages: list[dict], req_messages: list[dict],
use_model: requester.RuntimeLLMModel, use_model: requester.RuntimeLLMModel,
use_funcs: list[resource_tool.LLMTool] = None, use_funcs: list[resource_tool.LLMTool] = None,

View File

@@ -9,9 +9,9 @@ import openai.types.chat.chat_completion_message_tool_call as chat_completion_me
import httpx import httpx
from .. import entities, errors, requester from .. import entities, errors, requester
from ....core import entities as core_entities
from ... import entities as llm_entities from ... import entities as llm_entities
import langbot_plugin.api.entities.builtin.resource.tool as resource_tool import langbot_plugin.api.entities.builtin.resource.tool as resource_tool
import langbot_plugin.api.entities.builtin.pipeline.query as pipeline_query
class ModelScopeChatCompletions(requester.LLMAPIRequester): class ModelScopeChatCompletions(requester.LLMAPIRequester):
@@ -125,7 +125,7 @@ class ModelScopeChatCompletions(requester.LLMAPIRequester):
async def _closure( async def _closure(
self, self,
query: core_entities.Query, query: pipeline_query.Query,
req_messages: list[dict], req_messages: list[dict],
use_model: requester.RuntimeLLMModel, use_model: requester.RuntimeLLMModel,
use_funcs: list[resource_tool.LLMTool] = None, use_funcs: list[resource_tool.LLMTool] = None,
@@ -166,7 +166,7 @@ class ModelScopeChatCompletions(requester.LLMAPIRequester):
async def invoke_llm( async def invoke_llm(
self, self,
query: core_entities.Query, query: pipeline_query.Query,
model: entities.LLMModelInfo, model: entities.LLMModelInfo,
messages: typing.List[llm_entities.Message], messages: typing.List[llm_entities.Message],
funcs: typing.List[resource_tool.LLMTool] = None, funcs: typing.List[resource_tool.LLMTool] = None,

View File

@@ -5,9 +5,9 @@ import typing
from . import chatcmpl from . import chatcmpl
from .. import requester from .. import requester
from ....core import entities as core_entities
from ... import entities as llm_entities from ... import entities as llm_entities
import langbot_plugin.api.entities.builtin.resource.tool as resource_tool import langbot_plugin.api.entities.builtin.resource.tool as resource_tool
import langbot_plugin.api.entities.builtin.pipeline.query as pipeline_query
class MoonshotChatCompletions(chatcmpl.OpenAIChatCompletions): class MoonshotChatCompletions(chatcmpl.OpenAIChatCompletions):
@@ -20,7 +20,7 @@ class MoonshotChatCompletions(chatcmpl.OpenAIChatCompletions):
async def _closure( async def _closure(
self, self,
query: core_entities.Query, query: pipeline_query.Query,
req_messages: list[dict], req_messages: list[dict],
use_model: requester.RuntimeLLMModel, use_model: requester.RuntimeLLMModel,
use_funcs: list[resource_tool.LLMTool] = None, use_funcs: list[resource_tool.LLMTool] = None,

View File

@@ -12,7 +12,7 @@ import ollama
from .. import errors, requester from .. import errors, requester
from ... import entities as llm_entities from ... import entities as llm_entities
import langbot_plugin.api.entities.builtin.resource.tool as resource_tool import langbot_plugin.api.entities.builtin.resource.tool as resource_tool
from ....core import entities as core_entities import langbot_plugin.api.entities.builtin.pipeline.query as pipeline_query
REQUESTER_NAME: str = 'ollama-chat' REQUESTER_NAME: str = 'ollama-chat'
@@ -39,7 +39,7 @@ class OllamaChatCompletions(requester.LLMAPIRequester):
async def _closure( async def _closure(
self, self,
query: core_entities.Query, query: pipeline_query.Query,
req_messages: list[dict], req_messages: list[dict],
use_model: requester.RuntimeLLMModel, use_model: requester.RuntimeLLMModel,
use_funcs: list[resource_tool.LLMTool] = None, use_funcs: list[resource_tool.LLMTool] = None,
@@ -105,7 +105,7 @@ class OllamaChatCompletions(requester.LLMAPIRequester):
async def invoke_llm( async def invoke_llm(
self, self,
query: core_entities.Query, query: pipeline_query.Query,
model: requester.RuntimeLLMModel, model: requester.RuntimeLLMModel,
messages: typing.List[llm_entities.Message], messages: typing.List[llm_entities.Message],
funcs: typing.List[resource_tool.LLMTool] = None, funcs: typing.List[resource_tool.LLMTool] = None,

View File

@@ -3,8 +3,9 @@ from __future__ import annotations
import abc import abc
import typing import typing
from ..core import app, entities as core_entities from ..core import app
from . import entities as llm_entities from . import entities as llm_entities
import langbot_plugin.api.entities.builtin.pipeline.query as pipeline_query
preregistered_runners: list[typing.Type[RequestRunner]] = [] preregistered_runners: list[typing.Type[RequestRunner]] = []
@@ -35,6 +36,6 @@ class RequestRunner(abc.ABC):
self.pipeline_config = pipeline_config self.pipeline_config = pipeline_config
@abc.abstractmethod @abc.abstractmethod
async def run(self, query: core_entities.Query) -> typing.AsyncGenerator[llm_entities.Message, None]: async def run(self, query: pipeline_query.Query) -> typing.AsyncGenerator[llm_entities.Message, None]:
"""运行请求""" """运行请求"""
pass pass

View File

@@ -6,8 +6,9 @@ import re
import dashscope import dashscope
from .. import runner from .. import runner
from ...core import app, entities as core_entities from ...core import app
from .. import entities as llm_entities from .. import entities as llm_entities
import langbot_plugin.api.entities.builtin.pipeline.query as pipeline_query
class DashscopeAPIError(Exception): class DashscopeAPIError(Exception):
@@ -65,7 +66,7 @@ class DashScopeAPIRunner(runner.RequestRunner):
# 使用 re.sub() 进行替换 # 使用 re.sub() 进行替换
return pattern.sub(replacement, text) return pattern.sub(replacement, text)
async def _preprocess_user_message(self, query: core_entities.Query) -> tuple[str, list[str]]: async def _preprocess_user_message(self, query: pipeline_query.Query) -> tuple[str, list[str]]:
"""预处理用户消息,提取纯文本,阿里云提供的上传文件方法过于复杂,暂不支持上传文件(包括图片)""" """预处理用户消息,提取纯文本,阿里云提供的上传文件方法过于复杂,暂不支持上传文件(包括图片)"""
plain_text = '' plain_text = ''
image_ids = [] image_ids = []
@@ -89,7 +90,7 @@ class DashScopeAPIRunner(runner.RequestRunner):
return plain_text, image_ids return plain_text, image_ids
async def _agent_messages(self, query: core_entities.Query) -> typing.AsyncGenerator[llm_entities.Message, None]: async def _agent_messages(self, query: pipeline_query.Query) -> typing.AsyncGenerator[llm_entities.Message, None]:
"""Dashscope 智能体对话请求""" """Dashscope 智能体对话请求"""
# 局部变量 # 局部变量
@@ -147,7 +148,9 @@ class DashScopeAPIRunner(runner.RequestRunner):
content=pending_content, content=pending_content,
) )
async def _workflow_messages(self, query: core_entities.Query) -> typing.AsyncGenerator[llm_entities.Message, None]: async def _workflow_messages(
self, query: pipeline_query.Query
) -> typing.AsyncGenerator[llm_entities.Message, None]:
"""Dashscope 工作流对话请求""" """Dashscope 工作流对话请求"""
# 局部变量 # 局部变量
@@ -210,7 +213,7 @@ class DashScopeAPIRunner(runner.RequestRunner):
content=pending_content, content=pending_content,
) )
async def run(self, query: core_entities.Query) -> typing.AsyncGenerator[llm_entities.Message, None]: async def run(self, query: pipeline_query.Query) -> typing.AsyncGenerator[llm_entities.Message, None]:
"""运行""" """运行"""
if self.app_type == 'agent': if self.app_type == 'agent':
async for msg in self._agent_messages(query): async for msg in self._agent_messages(query):

View File

@@ -8,10 +8,10 @@ import base64
from .. import runner from .. import runner
from ...core import app, entities as core_entities from ...core import app
from .. import entities as llm_entities from .. import entities as llm_entities
from ...utils import image from ...utils import image
import langbot_plugin.api.entities.builtin.pipeline.query as pipeline_query
from libs.dify_service_api.v1 import client, errors from libs.dify_service_api.v1 import client, errors
@@ -62,7 +62,7 @@ class DifyServiceAPIRunner(runner.RequestRunner):
content_text = re.sub(pattern, '', resp_text, flags=re.DOTALL) content_text = re.sub(pattern, '', resp_text, flags=re.DOTALL)
return f'<think>{thinking_text.group(1)}</think>\n{content_text}' return f'<think>{thinking_text.group(1)}</think>\n{content_text}'
async def _preprocess_user_message(self, query: core_entities.Query) -> tuple[str, list[str]]: async def _preprocess_user_message(self, query: pipeline_query.Query) -> tuple[str, list[str]]:
"""预处理用户消息,提取纯文本,并将图片上传到 Dify 服务 """预处理用户消息,提取纯文本,并将图片上传到 Dify 服务
Returns: Returns:
@@ -90,7 +90,7 @@ class DifyServiceAPIRunner(runner.RequestRunner):
return plain_text, image_ids return plain_text, image_ids
async def _chat_messages(self, query: core_entities.Query) -> typing.AsyncGenerator[llm_entities.Message, None]: async def _chat_messages(self, query: pipeline_query.Query) -> typing.AsyncGenerator[llm_entities.Message, None]:
"""调用聊天助手""" """调用聊天助手"""
cov_id = query.session.using_conversation.uuid or '' cov_id = query.session.using_conversation.uuid or ''
query.variables['conversation_id'] = cov_id query.variables['conversation_id'] = cov_id
@@ -152,7 +152,7 @@ class DifyServiceAPIRunner(runner.RequestRunner):
query.session.using_conversation.uuid = chunk['conversation_id'] query.session.using_conversation.uuid = chunk['conversation_id']
async def _agent_chat_messages( async def _agent_chat_messages(
self, query: core_entities.Query self, query: pipeline_query.Query
) -> typing.AsyncGenerator[llm_entities.Message, None]: ) -> typing.AsyncGenerator[llm_entities.Message, None]:
"""调用聊天助手""" """调用聊天助手"""
cov_id = query.session.using_conversation.uuid or '' cov_id = query.session.using_conversation.uuid or ''
@@ -244,7 +244,9 @@ class DifyServiceAPIRunner(runner.RequestRunner):
query.session.using_conversation.uuid = chunk['conversation_id'] query.session.using_conversation.uuid = chunk['conversation_id']
async def _workflow_messages(self, query: core_entities.Query) -> typing.AsyncGenerator[llm_entities.Message, None]: async def _workflow_messages(
self, query: pipeline_query.Query
) -> typing.AsyncGenerator[llm_entities.Message, None]:
"""调用工作流""" """调用工作流"""
if not query.session.using_conversation.uuid: if not query.session.using_conversation.uuid:
@@ -316,7 +318,7 @@ class DifyServiceAPIRunner(runner.RequestRunner):
yield msg yield msg
async def run(self, query: core_entities.Query) -> typing.AsyncGenerator[llm_entities.Message, None]: async def run(self, query: pipeline_query.Query) -> typing.AsyncGenerator[llm_entities.Message, None]:
"""运行请求""" """运行请求"""
if self.pipeline_config['ai']['dify-service-api']['app-type'] == 'chat': if self.pipeline_config['ai']['dify-service-api']['app-type'] == 'chat':
async for msg in self._chat_messages(query): async for msg in self._chat_messages(query):

View File

@@ -4,15 +4,15 @@ import json
import typing import typing
from .. import runner from .. import runner
from ...core import entities as core_entities
from .. import entities as llm_entities from .. import entities as llm_entities
import langbot_plugin.api.entities.builtin.pipeline.query as pipeline_query
@runner.runner_class('local-agent') @runner.runner_class('local-agent')
class LocalAgentRunner(runner.RequestRunner): class LocalAgentRunner(runner.RequestRunner):
"""本地Agent请求运行器""" """本地Agent请求运行器"""
async def run(self, query: core_entities.Query) -> typing.AsyncGenerator[llm_entities.Message, None]: async def run(self, query: pipeline_query.Query) -> typing.AsyncGenerator[llm_entities.Message, None]:
"""运行请求""" """运行请求"""
pending_tool_calls = [] pending_tool_calls = []

View File

@@ -6,8 +6,9 @@ import uuid
import aiohttp import aiohttp
from .. import runner from .. import runner
from ...core import app, entities as core_entities from ...core import app
from .. import entities as llm_entities from .. import entities as llm_entities
import langbot_plugin.api.entities.builtin.pipeline.query as pipeline_query
class N8nAPIError(Exception): class N8nAPIError(Exception):
@@ -49,7 +50,7 @@ class N8nServiceAPIRunner(runner.RequestRunner):
self.header_name = self.pipeline_config['ai']['n8n-service-api'].get('header-name', '') self.header_name = self.pipeline_config['ai']['n8n-service-api'].get('header-name', '')
self.header_value = self.pipeline_config['ai']['n8n-service-api'].get('header-value', '') self.header_value = self.pipeline_config['ai']['n8n-service-api'].get('header-value', '')
async def _preprocess_user_message(self, query: core_entities.Query) -> str: async def _preprocess_user_message(self, query: pipeline_query.Query) -> str:
"""预处理用户消息,提取纯文本 """预处理用户消息,提取纯文本
Returns: Returns:
@@ -67,7 +68,7 @@ class N8nServiceAPIRunner(runner.RequestRunner):
return plain_text return plain_text
async def _call_webhook(self, query: core_entities.Query) -> typing.AsyncGenerator[llm_entities.Message, None]: async def _call_webhook(self, query: pipeline_query.Query) -> typing.AsyncGenerator[llm_entities.Message, None]:
"""调用n8n webhook""" """调用n8n webhook"""
# 生成会话ID如果不存在 # 生成会话ID如果不存在
if not query.session.using_conversation.uuid: if not query.session.using_conversation.uuid:
@@ -153,7 +154,7 @@ class N8nServiceAPIRunner(runner.RequestRunner):
self.ap.logger.error(f'n8n webhook call exception: {str(e)}') self.ap.logger.error(f'n8n webhook call exception: {str(e)}')
raise N8nAPIError(f'n8n webhook call exception: {str(e)}') raise N8nAPIError(f'n8n webhook call exception: {str(e)}')
async def run(self, query: core_entities.Query) -> typing.AsyncGenerator[llm_entities.Message, None]: async def run(self, query: pipeline_query.Query) -> typing.AsyncGenerator[llm_entities.Message, None]:
"""运行请求""" """运行请求"""
async for msg in self._call_webhook(query): async for msg in self._call_webhook(query):
yield msg yield msg

View File

@@ -2,9 +2,10 @@ from __future__ import annotations
import asyncio import asyncio
from ...core import app, entities as core_entities from ...core import app
from langbot_plugin.api.entities.builtin.provider import message as provider_message, prompt as provider_prompt from langbot_plugin.api.entities.builtin.provider import message as provider_message, prompt as provider_prompt
import langbot_plugin.api.entities.builtin.provider.session as provider_session import langbot_plugin.api.entities.builtin.provider.session as provider_session
import langbot_plugin.api.entities.builtin.pipeline.query as pipeline_query
class SessionManager: class SessionManager:
@@ -21,7 +22,7 @@ class SessionManager:
async def initialize(self): async def initialize(self):
pass pass
async def get_session(self, query: core_entities.Query) -> provider_session.Session: async def get_session(self, query: pipeline_query.Query) -> provider_session.Session:
"""获取会话""" """获取会话"""
for session in self.session_list: for session in self.session_list:
if query.launcher_type == session.launcher_type and query.launcher_id == session.launcher_id: if query.launcher_type == session.launcher_type and query.launcher_id == session.launcher_id:
@@ -39,7 +40,7 @@ class SessionManager:
async def get_conversation( async def get_conversation(
self, self,
query: core_entities.Query, query: pipeline_query.Query,
session: provider_session.Session, session: provider_session.Session,
prompt_config: list[dict], prompt_config: list[dict],
pipeline_uuid: str, pipeline_uuid: str,

View File

@@ -3,8 +3,9 @@ from __future__ import annotations
import abc import abc
import typing import typing
from ...core import app, entities as core_entities from ...core import app
import langbot_plugin.api.entities.builtin.resource.tool as resource_tool import langbot_plugin.api.entities.builtin.resource.tool as resource_tool
import langbot_plugin.api.entities.builtin.pipeline.query as pipeline_query
preregistered_loaders: list[typing.Type[ToolLoader]] = [] preregistered_loaders: list[typing.Type[ToolLoader]] = []
@@ -45,7 +46,7 @@ class ToolLoader(abc.ABC):
pass pass
@abc.abstractmethod @abc.abstractmethod
async def invoke_tool(self, query: core_entities.Query, name: str, parameters: dict) -> typing.Any: async def invoke_tool(self, query: pipeline_query.Query, name: str, parameters: dict) -> typing.Any:
"""执行工具调用""" """执行工具调用"""
pass pass

View File

@@ -8,8 +8,9 @@ from mcp.client.stdio import stdio_client
from mcp.client.sse import sse_client from mcp.client.sse import sse_client
from .. import loader from .. import loader
from ....core import app, entities as core_entities from ....core import app
import langbot_plugin.api.entities.builtin.resource.tool as resource_tool import langbot_plugin.api.entities.builtin.resource.tool as resource_tool
import langbot_plugin.api.entities.builtin.pipeline.query as pipeline_query
class RuntimeMCPSession: class RuntimeMCPSession:
@@ -83,7 +84,7 @@ class RuntimeMCPSession:
for tool in tools.tools: for tool in tools.tools:
async def func(query: core_entities.Query, *, _tool=tool, **kwargs): async def func(query: pipeline_query.Query, *, _tool=tool, **kwargs):
result = await self.session.call_tool(_tool.name, kwargs) result = await self.session.call_tool(_tool.name, kwargs)
if result.isError: if result.isError:
raise Exception(result.content[0].text) raise Exception(result.content[0].text)
@@ -144,7 +145,7 @@ class MCPLoader(loader.ToolLoader):
async def has_tool(self, name: str) -> bool: async def has_tool(self, name: str) -> bool:
return name in [f.name for f in self._last_listed_functions] return name in [f.name for f in self._last_listed_functions]
async def invoke_tool(self, query: core_entities.Query, name: str, parameters: dict) -> typing.Any: async def invoke_tool(self, query: pipeline_query.Query, name: str, parameters: dict) -> typing.Any:
for server_name, session in self.sessions.items(): for server_name, session in self.sessions.items():
for function in session.functions: for function in session.functions:
if function.name == name: if function.name == name:

View File

@@ -4,9 +4,9 @@ import typing
import traceback import traceback
from .. import loader from .. import loader
from ....core import entities as core_entities
from ....plugin import context as plugin_context from ....plugin import context as plugin_context
import langbot_plugin.api.entities.builtin.resource.tool as resource_tool import langbot_plugin.api.entities.builtin.resource.tool as resource_tool
import langbot_plugin.api.entities.builtin.pipeline.query as pipeline_query
@loader.loader_class('plugin-tool-loader') @loader.loader_class('plugin-tool-loader')
@@ -49,7 +49,7 @@ class PluginToolLoader(loader.ToolLoader):
return function, plugin.plugin_inst return function, plugin.plugin_inst
return None, None return None, None
async def invoke_tool(self, query: core_entities.Query, name: str, parameters: dict) -> typing.Any: async def invoke_tool(self, query: pipeline_query.Query, name: str, parameters: dict) -> typing.Any:
try: try:
function, plugin = await self._get_function_and_plugin(name) function, plugin = await self._get_function_and_plugin(name)
if function is None: if function is None:

View File

@@ -2,11 +2,12 @@ from __future__ import annotations
import typing import typing
from ...core import app, entities as core_entities from ...core import app
from . import loader as tools_loader from . import loader as tools_loader
from ...utils import importutil from ...utils import importutil
from . import loaders from . import loaders
import langbot_plugin.api.entities.builtin.resource.tool as resource_tool import langbot_plugin.api.entities.builtin.resource.tool as resource_tool
import langbot_plugin.api.entities.builtin.pipeline.query as pipeline_query
importutil.import_modules_in_pkg(loaders) importutil.import_modules_in_pkg(loaders)
@@ -90,7 +91,7 @@ class ToolManager:
return tools return tools
async def execute_func_call(self, query: core_entities.Query, name: str, parameters: dict) -> typing.Any: async def execute_func_call(self, query: pipeline_query.Query, name: str, parameters: dict) -> typing.Any:
"""执行函数调用""" """执行函数调用"""
for loader in self.loaders: for loader in self.loaders:

View File

@@ -6,7 +6,7 @@ import os
import base64 import base64
import logging import logging
import pydantic.v1 as pydantic import pydantic
import requests import requests
from ..core import app from ..core import app