diff --git a/main.py b/main.py index 19cb32d6..b7d62e07 100644 --- a/main.py +++ b/main.py @@ -47,13 +47,13 @@ async def main_entry(loop: asyncio.AbstractEventLoop): if not args.skip_plugin_deps_check: await deps.precheck_plugin_deps() - # 检查pydantic版本,如果没有 pydantic.v1,则把 pydantic 映射为 v1 - import pydantic.version + # # 检查pydantic版本,如果没有 pydantic.v1,则把 pydantic 映射为 v1 + # import pydantic.version - if pydantic.version.VERSION < '2.0': - import pydantic + # if pydantic.version.VERSION < '2.0': + # import pydantic - sys.modules['pydantic.v1'] = pydantic + # sys.modules['pydantic.v1'] = pydantic # 检查配置文件 diff --git a/pkg/api/http/controller/groups/system.py b/pkg/api/http/controller/groups/system.py index c4cab602..1089626d 100644 --- a/pkg/api/http/controller/groups/system.py +++ b/pkg/api/http/controller/groups/system.py @@ -35,15 +35,6 @@ class SystemRouterGroup(group.RouterGroup): return self.success(data=task.to_dict()) - @self.route('/reload', methods=['POST'], auth_type=group.AuthType.USER_TOKEN) - async def _() -> str: - json_data = await quart.request.json - - scope = json_data.get('scope') - - await self.ap.reload(scope=scope) - return self.success() - @self.route('/_debug/exec', methods=['POST'], auth_type=group.AuthType.USER_TOKEN) async def _() -> str: if not constants.debug_mode: diff --git a/pkg/command/cmdmgr.py b/pkg/command/cmdmgr.py index a8cf5eae..14c3d9e4 100644 --- a/pkg/command/cmdmgr.py +++ b/pkg/command/cmdmgr.py @@ -2,10 +2,11 @@ from __future__ import annotations import typing -from ..core import app, entities as core_entities +from ..core import app from . import entities, operator, errors from ..utils import importutil import langbot_plugin.api.entities.builtin.provider.session as provider_session +import langbot_plugin.api.entities.builtin.pipeline.query as pipeline_query # 引入所有算子以便注册 from . import operators @@ -90,7 +91,7 @@ class CommandManager: async def execute( self, command_text: str, - query: core_entities.Query, + query: pipeline_query.Query, session: provider_session.Session, ) -> typing.AsyncGenerator[entities.CommandReturn, None]: """执行命令""" diff --git a/pkg/command/entities.py b/pkg/command/entities.py index e80d203f..7d6eecdc 100644 --- a/pkg/command/entities.py +++ b/pkg/command/entities.py @@ -2,12 +2,12 @@ from __future__ import annotations import typing -import pydantic.v1 as pydantic +import pydantic import langbot_plugin.api.entities.builtin.provider.session as provider_session -from ..core import entities as core_entities from . import errors from ..platform.types import message as platform_message +import langbot_plugin.api.entities.builtin.pipeline.query as pipeline_query class CommandReturn(pydantic.BaseModel): @@ -35,7 +35,7 @@ class CommandReturn(pydantic.BaseModel): class ExecuteContext(pydantic.BaseModel): """单次命令执行上下文""" - query: core_entities.Query + query: pipeline_query.Query """本次消息的请求对象""" session: provider_session.Session diff --git a/pkg/core/app.py b/pkg/core/app.py index c795d6c0..4b3e3b82 100644 --- a/pkg/core/app.py +++ b/pkg/core/app.py @@ -3,7 +3,6 @@ from __future__ import annotations import logging import asyncio import traceback -import sys import os from ..platform import botmgr as im_mgr @@ -183,59 +182,3 @@ class Application: """.strip() for line in tips.split('\n'): self.logger.info(line) - - async def reload( - self, - scope: core_entities.LifecycleControlScope, - ): - match scope: - case core_entities.LifecycleControlScope.PLATFORM.value: - self.logger.info('执行热重载 scope=' + scope) - await self.platform_mgr.shutdown() - - self.platform_mgr = im_mgr.PlatformManager(self) - - await self.platform_mgr.initialize() - - self.task_mgr.create_task( - self.platform_mgr.run(), - name='platform-manager', - scopes=[ - core_entities.LifecycleControlScope.APPLICATION, - core_entities.LifecycleControlScope.PLATFORM, - ], - ) - case core_entities.LifecycleControlScope.PLUGIN.value: - self.logger.info('执行热重载 scope=' + scope) - await self.plugin_mgr.destroy_plugins() - - # 删除 sys.module 中所有的 plugins/* 下的模块 - for mod in list(sys.modules.keys()): - if mod.startswith('plugins.'): - del sys.modules[mod] - - self.plugin_mgr = plugin_mgr.PluginManager(self) - await self.plugin_mgr.initialize() - - await self.plugin_mgr.initialize_plugins() - - await self.plugin_mgr.load_plugins() - await self.plugin_mgr.initialize_plugins() - case core_entities.LifecycleControlScope.PROVIDER.value: - self.logger.info('执行热重载 scope=' + scope) - - await self.tool_mgr.shutdown() - - llm_model_mgr_inst = llm_model_mgr.ModelManager(self) - await llm_model_mgr_inst.initialize() - self.model_mgr = llm_model_mgr_inst - - llm_session_mgr_inst = llm_session_mgr.SessionManager(self) - await llm_session_mgr_inst.initialize() - self.sess_mgr = llm_session_mgr_inst - - llm_tool_mgr_inst = llm_tool_mgr.ToolManager(self) - await llm_tool_mgr_inst.initialize() - self.tool_mgr = llm_tool_mgr_inst - case _: - pass diff --git a/pkg/core/entities.py b/pkg/core/entities.py index 3bc0349c..5abb7c74 100644 --- a/pkg/core/entities.py +++ b/pkg/core/entities.py @@ -3,7 +3,7 @@ from __future__ import annotations import enum import typing -import pydantic.v1 as pydantic +import pydantic from ..provider import entities as llm_entities from ..platform import adapter as msadapter @@ -20,23 +20,13 @@ class LifecycleControlScope(enum.Enum): PROVIDER = 'provider' -class LauncherTypes(enum.Enum): - """一个请求的发起者类型""" - - PERSON = 'person' - """私聊""" - - GROUP = 'group' - """群聊""" - - class Query(pydantic.BaseModel): """一次请求的信息封装""" query_id: int """请求ID,添加进请求池时生成""" - launcher_type: LauncherTypes + launcher_type: provider_session.LauncherTypes """会话类型,platform处理阶段设置""" launcher_id: typing.Union[int, str] diff --git a/pkg/pipeline/bansess/bansess.py b/pkg/pipeline/bansess/bansess.py index 3b927a55..0cd498f6 100644 --- a/pkg/pipeline/bansess/bansess.py +++ b/pkg/pipeline/bansess/bansess.py @@ -1,7 +1,7 @@ from __future__ import annotations from .. import stage, entities -from ...core import entities as core_entities +import langbot_plugin.api.entities.builtin.pipeline.query as pipeline_query @stage.stage_class('BanSessionCheckStage') @@ -14,7 +14,7 @@ class BanSessionCheckStage(stage.PipelineStage): async def initialize(self, pipeline_config: dict): pass - async def process(self, query: core_entities.Query, stage_inst_name: str) -> entities.StageProcessResult: + async def process(self, query: pipeline_query.Query, stage_inst_name: str) -> entities.StageProcessResult: found = False mode = query.pipeline_config['trigger']['access-control']['mode'] diff --git a/pkg/pipeline/cntfilter/cntfilter.py b/pkg/pipeline/cntfilter/cntfilter.py index fb562a42..1708363a 100644 --- a/pkg/pipeline/cntfilter/cntfilter.py +++ b/pkg/pipeline/cntfilter/cntfilter.py @@ -3,12 +3,11 @@ from __future__ import annotations from ...core import app from .. import stage, entities -from ...core import entities as core_entities from . import filter as filter_model, entities as filter_entities from langbot_plugin.api.entities.builtin.provider import message as provider_message from ...platform.types import message as platform_message from ...utils import importutil - +import langbot_plugin.api.entities.builtin.pipeline.query as pipeline_query from . import filters importutil.import_modules_in_pkg(filters) @@ -58,7 +57,7 @@ class ContentFilterStage(stage.PipelineStage): async def _pre_process( self, message: str, - query: core_entities.Query, + query: pipeline_query.Query, ) -> entities.StageProcessResult: """请求llm前处理消息 只要有一个不通过就不放行,只放行 PASS 的消息 @@ -93,7 +92,7 @@ class ContentFilterStage(stage.PipelineStage): async def _post_process( self, message: str, - query: core_entities.Query, + query: pipeline_query.Query, ) -> entities.StageProcessResult: """请求llm后处理响应 只要是 PASS 或者 MASKED 的就通过此 filter,将其 replacement 设置为message,进入下一个 filter @@ -123,7 +122,7 @@ class ContentFilterStage(stage.PipelineStage): return entities.StageProcessResult(result_type=entities.ResultType.CONTINUE, new_query=query) - async def process(self, query: core_entities.Query, stage_inst_name: str) -> entities.StageProcessResult: + async def process(self, query: pipeline_query.Query, stage_inst_name: str) -> entities.StageProcessResult: """处理""" if stage_inst_name == 'PreContentFilterStage': contain_non_text = False diff --git a/pkg/pipeline/cntfilter/entities.py b/pkg/pipeline/cntfilter/entities.py index 5e804c0d..607eba9a 100644 --- a/pkg/pipeline/cntfilter/entities.py +++ b/pkg/pipeline/cntfilter/entities.py @@ -1,6 +1,6 @@ import enum -import pydantic.v1 as pydantic +import pydantic class ResultLevel(enum.Enum): diff --git a/pkg/pipeline/cntfilter/filter.py b/pkg/pipeline/cntfilter/filter.py index 0a3ceaae..dafc539a 100644 --- a/pkg/pipeline/cntfilter/filter.py +++ b/pkg/pipeline/cntfilter/filter.py @@ -3,9 +3,9 @@ from __future__ import annotations import abc import typing -from ...core import app, entities as core_entities +from ...core import app from . import entities - +import langbot_plugin.api.entities.builtin.pipeline.query as pipeline_query preregistered_filters: list[typing.Type[ContentFilter]] = [] @@ -60,7 +60,7 @@ class ContentFilter(metaclass=abc.ABCMeta): pass @abc.abstractmethod - async def process(self, query: core_entities.Query, message: str = None, image_url=None) -> entities.FilterResult: + async def process(self, query: pipeline_query.Query, message: str = None, image_url=None) -> entities.FilterResult: """处理消息 分为前后阶段,具体取决于 enable_stages 的值。 diff --git a/pkg/pipeline/cntfilter/filters/baiduexamine.py b/pkg/pipeline/cntfilter/filters/baiduexamine.py index 9637aec2..4213e662 100644 --- a/pkg/pipeline/cntfilter/filters/baiduexamine.py +++ b/pkg/pipeline/cntfilter/filters/baiduexamine.py @@ -4,8 +4,7 @@ import aiohttp from .. import entities from .. import filter as filter_model -from ....core import entities as core_entities - +import langbot_plugin.api.entities.builtin.pipeline.query as pipeline_query BAIDU_EXAMINE_URL = 'https://aip.baidubce.com/rest/2.0/solution/v1/text_censor/v2/user_defined?access_token={}' BAIDU_EXAMINE_TOKEN_URL = 'https://aip.baidubce.com/oauth/2.0/token' @@ -27,7 +26,7 @@ class BaiduCloudExamine(filter_model.ContentFilter): ) as resp: return (await resp.json())['access_token'] - async def process(self, query: core_entities.Query, message: str) -> entities.FilterResult: + async def process(self, query: pipeline_query.Query, message: str) -> entities.FilterResult: async with aiohttp.ClientSession() as session: async with session.post( BAIDU_EXAMINE_URL.format(await self._get_token()), diff --git a/pkg/pipeline/cntfilter/filters/banwords.py b/pkg/pipeline/cntfilter/filters/banwords.py index 916a1bc1..e04de8c4 100644 --- a/pkg/pipeline/cntfilter/filters/banwords.py +++ b/pkg/pipeline/cntfilter/filters/banwords.py @@ -3,7 +3,7 @@ import re from .. import filter as filter_model from .. import entities -from ....core import entities as core_entities +import langbot_plugin.api.entities.builtin.pipeline.query as pipeline_query @filter_model.filter_class('ban-word-filter') @@ -13,7 +13,7 @@ class BanWordFilter(filter_model.ContentFilter): async def initialize(self): pass - async def process(self, query: core_entities.Query, message: str) -> entities.FilterResult: + async def process(self, query: pipeline_query.Query, message: str) -> entities.FilterResult: found = False for word in self.ap.sensitive_meta.data['words']: diff --git a/pkg/pipeline/cntfilter/filters/cntignore.py b/pkg/pipeline/cntfilter/filters/cntignore.py index 5e410e31..0a3ef709 100644 --- a/pkg/pipeline/cntfilter/filters/cntignore.py +++ b/pkg/pipeline/cntfilter/filters/cntignore.py @@ -3,7 +3,7 @@ import re from .. import entities from .. import filter as filter_model -from ....core import entities as core_entities +import langbot_plugin.api.entities.builtin.pipeline.query as pipeline_query @filter_model.filter_class('content-ignore') @@ -16,7 +16,7 @@ class ContentIgnore(filter_model.ContentFilter): entities.EnableStage.PRE, ] - async def process(self, query: core_entities.Query, message: str) -> entities.FilterResult: + async def process(self, query: pipeline_query.Query, message: str) -> entities.FilterResult: if 'prefix' in query.pipeline_config['trigger']['ignore-rules']: for rule in query.pipeline_config['trigger']['ignore-rules']['prefix']: if message.startswith(rule): diff --git a/pkg/pipeline/controller.py b/pkg/pipeline/controller.py index 6679bd88..11bd8d46 100644 --- a/pkg/pipeline/controller.py +++ b/pkg/pipeline/controller.py @@ -3,7 +3,10 @@ from __future__ import annotations import asyncio import traceback -from ..core import app, entities +from ..core import app +from ..core import entities as core_entities + +import langbot_plugin.api.entities.builtin.pipeline.query as pipeline_query class Controller: @@ -22,11 +25,11 @@ class Controller: """事件处理循环""" try: while True: - selected_query: entities.Query = None + selected_query: pipeline_query.Query = None # 取请求 async with self.ap.query_pool: - queries: list[entities.Query] = self.ap.query_pool.queries + queries: list[pipeline_query.Query] = self.ap.query_pool.queries for query in queries: session = await self.ap.sess_mgr.get_session(query) @@ -46,7 +49,7 @@ class Controller: if selected_query: - async def _process_query(selected_query: entities.Query): + async def _process_query(selected_query: pipeline_query.Query): async with self.semaphore: # 总并发上限 # find pipeline # Here firstly find the bot, then find the pipeline, in case the bot adapter's config is not the latest one. @@ -68,8 +71,8 @@ class Controller: kind='query', name=f'query-{selected_query.query_id}', scopes=[ - entities.LifecycleControlScope.APPLICATION, - entities.LifecycleControlScope.PLATFORM, + core_entities.LifecycleControlScope.APPLICATION, + core_entities.LifecycleControlScope.PLATFORM, ], ) diff --git a/pkg/pipeline/entities.py b/pkg/pipeline/entities.py index dd6434c0..7e7f23ce 100644 --- a/pkg/pipeline/entities.py +++ b/pkg/pipeline/entities.py @@ -3,10 +3,10 @@ from __future__ import annotations import enum import typing -import pydantic.v1 as pydantic +import pydantic from ..platform.types import message as platform_message -from ..core import entities +import langbot_plugin.api.entities.builtin.pipeline.query as pipeline_query class ResultType(enum.Enum): @@ -20,7 +20,7 @@ class ResultType(enum.Enum): class StageProcessResult(pydantic.BaseModel): result_type: ResultType - new_query: entities.Query + new_query: pipeline_query.Query user_notice: typing.Optional[ typing.Union[ diff --git a/pkg/pipeline/longtext/longtext.py b/pkg/pipeline/longtext/longtext.py index 5be20650..6356a16f 100644 --- a/pkg/pipeline/longtext/longtext.py +++ b/pkg/pipeline/longtext/longtext.py @@ -5,10 +5,9 @@ import traceback from . import strategy from .. import stage, entities -from ...core import entities as core_entities from ...platform.types import message as platform_message from ...utils import importutil - +import langbot_plugin.api.entities.builtin.pipeline.query as pipeline_query from . import strategies importutil.import_modules_in_pkg(strategies) @@ -67,7 +66,7 @@ class LongTextProcessStage(stage.PipelineStage): await self.strategy_impl.initialize() - async def process(self, query: core_entities.Query, stage_inst_name: str) -> entities.StageProcessResult: + async def process(self, query: pipeline_query.Query, stage_inst_name: str) -> entities.StageProcessResult: # 检查是否包含非 Plain 组件 contains_non_plain = False diff --git a/pkg/pipeline/longtext/strategies/forward.py b/pkg/pipeline/longtext/strategies/forward.py index 6228d580..574239b8 100644 --- a/pkg/pipeline/longtext/strategies/forward.py +++ b/pkg/pipeline/longtext/strategies/forward.py @@ -3,9 +3,9 @@ from __future__ import annotations from .. import strategy as strategy_model -from ....core import entities as core_entities -from ....platform.types import message as platform_message +from ....platform.types import message as platform_message +import langbot_plugin.api.entities.builtin.pipeline.query as pipeline_query ForwardMessageDiaplay = platform_message.ForwardMessageDiaplay Forward = platform_message.Forward @@ -13,7 +13,7 @@ Forward = platform_message.Forward @strategy_model.strategy_class('forward') class ForwardComponentStrategy(strategy_model.LongTextStrategy): - async def process(self, message: str, query: core_entities.Query) -> list[platform_message.MessageComponent]: + async def process(self, message: str, query: pipeline_query.Query) -> list[platform_message.MessageComponent]: display = ForwardMessageDiaplay( title='群聊的聊天记录', brief='[聊天记录]', diff --git a/pkg/pipeline/longtext/strategies/image.py b/pkg/pipeline/longtext/strategies/image.py index f96f7265..ba6ddc1b 100644 --- a/pkg/pipeline/longtext/strategies/image.py +++ b/pkg/pipeline/longtext/strategies/image.py @@ -11,7 +11,7 @@ import functools from ....platform.types import message as platform_message from .. import strategy as strategy_model -from ....core import entities as core_entities +import langbot_plugin.api.entities.builtin.pipeline.query as pipeline_query @strategy_model.strategy_class('image') @@ -27,7 +27,7 @@ class Text2ImageStrategy(strategy_model.LongTextStrategy): encoding='utf-8', ) - async def process(self, message: str, query: core_entities.Query) -> list[platform_message.MessageComponent]: + async def process(self, message: str, query: pipeline_query.Query) -> list[platform_message.MessageComponent]: img_path = self.text_to_image( text_str=message, save_as='temp/{}.png'.format(int(time.time())), @@ -131,7 +131,7 @@ class Text2ImageStrategy(strategy_model.LongTextStrategy): text_str: str, save_as='temp.png', width=800, - query: core_entities.Query = None, + query: pipeline_query.Query = None, ): text_str = text_str.replace('\t', ' ') diff --git a/pkg/pipeline/longtext/strategy.py b/pkg/pipeline/longtext/strategy.py index 0ddec0c6..dd69b2bb 100644 --- a/pkg/pipeline/longtext/strategy.py +++ b/pkg/pipeline/longtext/strategy.py @@ -4,8 +4,8 @@ import typing from ...core import app -from ...core import entities as core_entities from ...platform.types import message as platform_message +import langbot_plugin.api.entities.builtin.pipeline.query as pipeline_query preregistered_strategies: list[typing.Type[LongTextStrategy]] = [] @@ -49,7 +49,7 @@ class LongTextStrategy(metaclass=abc.ABCMeta): pass @abc.abstractmethod - async def process(self, message: str, query: core_entities.Query) -> list[platform_message.MessageComponent]: + async def process(self, message: str, query: pipeline_query.Query) -> list[platform_message.MessageComponent]: """处理长文本 在 platform.json 中配置 long-text-process 字段,只要 文本长度超过了 threshold 就会调用此方法 diff --git a/pkg/pipeline/msgtrun/msgtrun.py b/pkg/pipeline/msgtrun/msgtrun.py index c64f67fc..3acd7e5c 100644 --- a/pkg/pipeline/msgtrun/msgtrun.py +++ b/pkg/pipeline/msgtrun/msgtrun.py @@ -1,10 +1,9 @@ from __future__ import annotations from .. import stage, entities -from ...core import entities as core_entities from . import truncator from ...utils import importutil - +import langbot_plugin.api.entities.builtin.pipeline.query as pipeline_query from . import truncators importutil.import_modules_in_pkg(truncators) @@ -29,7 +28,7 @@ class ConversationMessageTruncator(stage.PipelineStage): else: raise ValueError(f'未知的截断器: {use_method}') - async def process(self, query: core_entities.Query, stage_inst_name: str) -> entities.StageProcessResult: + async def process(self, query: pipeline_query.Query, stage_inst_name: str) -> entities.StageProcessResult: """处理""" query = await self.trun.truncate(query) diff --git a/pkg/pipeline/msgtrun/truncator.py b/pkg/pipeline/msgtrun/truncator.py index 9e8b8a6c..180982d3 100644 --- a/pkg/pipeline/msgtrun/truncator.py +++ b/pkg/pipeline/msgtrun/truncator.py @@ -3,8 +3,8 @@ from __future__ import annotations import typing import abc -from ...core import entities as core_entities, app - +from ...core import app +import langbot_plugin.api.entities.builtin.pipeline.query as pipeline_query preregistered_truncators: list[typing.Type[Truncator]] = [] @@ -47,7 +47,7 @@ class Truncator(abc.ABC): pass @abc.abstractmethod - async def truncate(self, query: core_entities.Query) -> core_entities.Query: + async def truncate(self, query: pipeline_query.Query) -> pipeline_query.Query: """截断 一般只需要操作query.messages,也可以扩展操作query.prompt, query.user_message。 diff --git a/pkg/pipeline/msgtrun/truncators/round.py b/pkg/pipeline/msgtrun/truncators/round.py index fa72a0e1..c6b1fba4 100644 --- a/pkg/pipeline/msgtrun/truncators/round.py +++ b/pkg/pipeline/msgtrun/truncators/round.py @@ -1,14 +1,14 @@ from __future__ import annotations from .. import truncator -from ....core import entities as core_entities +import langbot_plugin.api.entities.builtin.pipeline.query as pipeline_query @truncator.truncator_class('round') class RoundTruncator(truncator.Truncator): """前文回合数阶段器""" - async def truncate(self, query: core_entities.Query) -> core_entities.Query: + async def truncate(self, query: pipeline_query.Query) -> pipeline_query.Query: """截断""" max_round = query.pipeline_config['ai']['local-agent']['max-round'] diff --git a/pkg/pipeline/pipelinemgr.py b/pkg/pipeline/pipelinemgr.py index 78cffa73..debdbb93 100644 --- a/pkg/pipeline/pipelinemgr.py +++ b/pkg/pipeline/pipelinemgr.py @@ -5,7 +5,7 @@ import traceback import sqlalchemy -from ..core import app, entities +from ..core import app from . import entities as pipeline_entities from ..entity.persistence import pipeline as persistence_pipeline from . import stage @@ -13,6 +13,9 @@ from ..platform.types import message as platform_message, events as platform_eve from ..plugin import events from ..utils import importutil +import langbot_plugin.api.entities.builtin.provider.session as provider_session +import langbot_plugin.api.entities.builtin.pipeline.query as pipeline_query + from . import ( resprule, bansess, @@ -75,11 +78,11 @@ class RuntimePipeline: self.pipeline_entity = pipeline_entity self.stage_containers = stage_containers - async def run(self, query: entities.Query): + async def run(self, query: pipeline_query.Query): query.pipeline_config = self.pipeline_entity.config await self.process_query(query) - async def _check_output(self, query: entities.Query, result: pipeline_entities.StageProcessResult): + async def _check_output(self, query: pipeline_query.Query, result: pipeline_entities.StageProcessResult): """检查输出""" if result.user_notice: # 处理str类型 @@ -109,7 +112,7 @@ class RuntimePipeline: async def _execute_from_stage( self, stage_index: int, - query: entities.Query, + query: pipeline_query.Query, ): """从指定阶段开始执行,实现了责任链模式和基于生成器的阶段分叉功能。 @@ -169,13 +172,13 @@ class RuntimePipeline: i += 1 - async def process_query(self, query: entities.Query): + async def process_query(self, query: pipeline_query.Query): """处理请求""" try: # ======== 触发 MessageReceived 事件 ======== event_type = ( events.PersonMessageReceived - if query.launcher_type == entities.LauncherTypes.PERSON + if query.launcher_type == provider_session.LauncherTypes.PERSON else events.GroupMessageReceived ) diff --git a/pkg/pipeline/pool.py b/pkg/pipeline/pool.py index 6975e53c..a4313cdd 100644 --- a/pkg/pipeline/pool.py +++ b/pkg/pipeline/pool.py @@ -3,10 +3,11 @@ from __future__ import annotations import asyncio import typing -from ..core import entities from ..platform import adapter as msadapter from ..platform.types import message as platform_message from ..platform.types import events as platform_events +import langbot_plugin.api.entities.builtin.provider.session as provider_session +import langbot_plugin.api.entities.builtin.pipeline.query as pipeline_query class QueryPool: @@ -16,7 +17,7 @@ class QueryPool: pool_lock: asyncio.Lock - queries: list[entities.Query] + queries: list[pipeline_query.Query] condition: asyncio.Condition @@ -29,16 +30,16 @@ class QueryPool: async def add_query( self, bot_uuid: str, - launcher_type: entities.LauncherTypes, + launcher_type: provider_session.LauncherTypes, launcher_id: typing.Union[int, str], sender_id: typing.Union[int, str], message_event: platform_events.MessageEvent, message_chain: platform_message.MessageChain, adapter: msadapter.MessagePlatformAdapter, pipeline_uuid: typing.Optional[str] = None, - ) -> entities.Query: + ) -> pipeline_query.Query: async with self.condition: - query = entities.Query( + query = pipeline_query.Query( bot_uuid=bot_uuid, query_id=self.query_id_counter, launcher_type=launcher_type, diff --git a/pkg/pipeline/preproc/preproc.py b/pkg/pipeline/preproc/preproc.py index da56ca6e..af851c96 100644 --- a/pkg/pipeline/preproc/preproc.py +++ b/pkg/pipeline/preproc/preproc.py @@ -3,10 +3,10 @@ from __future__ import annotations import datetime from .. import stage, entities -from ...core import entities as core_entities from langbot_plugin.api.entities.builtin.provider import message as provider_message from ...plugin import events from ...platform.types import message as platform_message +import langbot_plugin.api.entities.builtin.pipeline.query as pipeline_query @stage.stage_class('PreProcessor') @@ -26,7 +26,7 @@ class PreProcessor(stage.PipelineStage): async def process( self, - query: core_entities.Query, + query: pipeline_query.Query, stage_inst_name: str, ) -> entities.StageProcessResult: """处理""" diff --git a/pkg/pipeline/process/handler.py b/pkg/pipeline/process/handler.py index 8a32bcfb..181d257d 100644 --- a/pkg/pipeline/process/handler.py +++ b/pkg/pipeline/process/handler.py @@ -3,8 +3,8 @@ from __future__ import annotations import abc from ...core import app -from ...core import entities as core_entities from .. import entities +import langbot_plugin.api.entities.builtin.pipeline.query as pipeline_query class MessageHandler(metaclass=abc.ABCMeta): @@ -19,7 +19,7 @@ class MessageHandler(metaclass=abc.ABCMeta): @abc.abstractmethod async def handle( self, - query: core_entities.Query, + query: pipeline_query.Query, ) -> entities.StageProcessResult: raise NotImplementedError diff --git a/pkg/pipeline/process/handlers/chat.py b/pkg/pipeline/process/handlers/chat.py index 35fa1611..b871de81 100644 --- a/pkg/pipeline/process/handlers/chat.py +++ b/pkg/pipeline/process/handlers/chat.py @@ -6,13 +6,15 @@ import traceback from .. import handler from ... import entities -from ....core import entities as core_entities from ....provider import runner as runner_module from ....plugin import events from ....platform.types import message as platform_message from ....utils import importutil from ....provider import runners +import langbot_plugin.api.entities.builtin.provider.session as provider_session +import langbot_plugin.api.entities.builtin.pipeline.query as pipeline_query + importutil.import_modules_in_pkg(runners) @@ -20,7 +22,7 @@ importutil.import_modules_in_pkg(runners) class ChatMessageHandler(handler.MessageHandler): async def handle( self, - query: core_entities.Query, + query: pipeline_query.Query, ) -> typing.AsyncGenerator[entities.StageProcessResult, None]: """处理""" # 调API @@ -29,7 +31,7 @@ class ChatMessageHandler(handler.MessageHandler): # 触发插件事件 event_class = ( events.PersonNormalMessageReceived - if query.launcher_type == core_entities.LauncherTypes.PERSON + if query.launcher_type == provider_session.LauncherTypes.PERSON else events.GroupNormalMessageReceived ) diff --git a/pkg/pipeline/process/handlers/command.py b/pkg/pipeline/process/handlers/command.py index efce5615..15c33ebd 100644 --- a/pkg/pipeline/process/handlers/command.py +++ b/pkg/pipeline/process/handlers/command.py @@ -4,16 +4,17 @@ import typing from .. import handler from ... import entities -from ....core import entities as core_entities from langbot_plugin.api.entities.builtin.provider import message as provider_message from ....plugin import events from ....platform.types import message as platform_message +import langbot_plugin.api.entities.builtin.provider.session as provider_session +import langbot_plugin.api.entities.builtin.pipeline.query as pipeline_query class CommandHandler(handler.MessageHandler): async def handle( self, - query: core_entities.Query, + query: pipeline_query.Query, ) -> typing.AsyncGenerator[entities.StageProcessResult, None]: """处理""" @@ -28,7 +29,7 @@ class CommandHandler(handler.MessageHandler): event_class = ( events.PersonCommandSent - if query.launcher_type == core_entities.LauncherTypes.PERSON + if query.launcher_type == provider_session.LauncherTypes.PERSON else events.GroupCommandSent ) diff --git a/pkg/pipeline/process/process.py b/pkg/pipeline/process/process.py index 64903552..704af5fd 100644 --- a/pkg/pipeline/process/process.py +++ b/pkg/pipeline/process/process.py @@ -1,10 +1,10 @@ from __future__ import annotations -from ...core import entities as core_entities from . import handler from .handlers import chat, command from .. import entities from .. import stage +import langbot_plugin.api.entities.builtin.pipeline.query as pipeline_query @stage.stage_class('MessageProcessor') @@ -30,7 +30,7 @@ class Processor(stage.PipelineStage): async def process( self, - query: core_entities.Query, + query: pipeline_query.Query, stage_inst_name: str, ) -> entities.StageProcessResult: """处理""" diff --git a/pkg/pipeline/ratelimit/algo.py b/pkg/pipeline/ratelimit/algo.py index 3bcc347a..efbc326b 100644 --- a/pkg/pipeline/ratelimit/algo.py +++ b/pkg/pipeline/ratelimit/algo.py @@ -2,7 +2,8 @@ from __future__ import annotations import abc import typing -from ...core import app, entities as core_entities +from ...core import app +import langbot_plugin.api.entities.builtin.pipeline.query as pipeline_query preregistered_algos: list[typing.Type[ReteLimitAlgo]] = [] @@ -33,7 +34,7 @@ class ReteLimitAlgo(metaclass=abc.ABCMeta): @abc.abstractmethod async def require_access( self, - query: core_entities.Query, + query: pipeline_query.Query, launcher_type: str, launcher_id: typing.Union[int, str], ) -> bool: @@ -53,7 +54,7 @@ class ReteLimitAlgo(metaclass=abc.ABCMeta): @abc.abstractmethod async def release_access( self, - query: core_entities.Query, + query: pipeline_query.Query, launcher_type: str, launcher_id: typing.Union[int, str], ): diff --git a/pkg/pipeline/ratelimit/algos/fixedwin.py b/pkg/pipeline/ratelimit/algos/fixedwin.py index cc816f73..6a2a8e97 100644 --- a/pkg/pipeline/ratelimit/algos/fixedwin.py +++ b/pkg/pipeline/ratelimit/algos/fixedwin.py @@ -3,7 +3,7 @@ import asyncio import time import typing from .. import algo -from ....core import entities as core_entities +import langbot_plugin.api.entities.builtin.pipeline.query as pipeline_query # 固定窗口算法 @@ -32,7 +32,7 @@ class FixedWindowAlgo(algo.ReteLimitAlgo): async def require_access( self, - query: core_entities.Query, + query: pipeline_query.Query, launcher_type: str, launcher_id: typing.Union[int, str], ) -> bool: @@ -91,7 +91,7 @@ class FixedWindowAlgo(algo.ReteLimitAlgo): async def release_access( self, - query: core_entities.Query, + query: pipeline_query.Query, launcher_type: str, launcher_id: typing.Union[int, str], ): diff --git a/pkg/pipeline/ratelimit/ratelimit.py b/pkg/pipeline/ratelimit/ratelimit.py index 23de4ec6..cab62b8d 100644 --- a/pkg/pipeline/ratelimit/ratelimit.py +++ b/pkg/pipeline/ratelimit/ratelimit.py @@ -4,9 +4,10 @@ import typing from .. import entities, stage from . import algo -from ...core import entities as core_entities from ...utils import importutil +import langbot_plugin.api.entities.builtin.pipeline.query as pipeline_query + from . import algos importutil.import_modules_in_pkg(algos) @@ -39,7 +40,7 @@ class RateLimit(stage.PipelineStage): async def process( self, - query: core_entities.Query, + query: pipeline_query.Query, stage_inst_name: str, ) -> typing.Union[ entities.StageProcessResult, diff --git a/pkg/pipeline/respback/respback.py b/pkg/pipeline/respback/respback.py index 39d3abb1..b5a1ed74 100644 --- a/pkg/pipeline/respback/respback.py +++ b/pkg/pipeline/respback/respback.py @@ -8,14 +8,14 @@ from ...platform.types import events as platform_events from ...platform.types import message as platform_message from .. import stage, entities -from ...core import entities as core_entities +import langbot_plugin.api.entities.builtin.pipeline.query as pipeline_query @stage.stage_class('SendResponseBackStage') class SendResponseBackStage(stage.PipelineStage): """发送响应消息""" - async def process(self, query: core_entities.Query, stage_inst_name: str) -> entities.StageProcessResult: + async def process(self, query: pipeline_query.Query, stage_inst_name: str) -> entities.StageProcessResult: """处理""" random_range = ( diff --git a/pkg/pipeline/resprule/entities.py b/pkg/pipeline/resprule/entities.py index a0ba7807..c2d964fe 100644 --- a/pkg/pipeline/resprule/entities.py +++ b/pkg/pipeline/resprule/entities.py @@ -1,4 +1,4 @@ -import pydantic.v1 as pydantic +import pydantic from ...platform.types import message as platform_message diff --git a/pkg/pipeline/resprule/resprule.py b/pkg/pipeline/resprule/resprule.py index 0193f2ce..1a3560ff 100644 --- a/pkg/pipeline/resprule/resprule.py +++ b/pkg/pipeline/resprule/resprule.py @@ -4,9 +4,10 @@ from __future__ import annotations from . import rule from .. import stage, entities -from ...core import entities as core_entities from ...utils import importutil +import langbot_plugin.api.entities.builtin.pipeline.query as pipeline_query + from . import rules importutil.import_modules_in_pkg(rules) @@ -32,7 +33,7 @@ class GroupRespondRuleCheckStage(stage.PipelineStage): await rule_inst.initialize() self.rule_matchers.append(rule_inst) - async def process(self, query: core_entities.Query, stage_inst_name: str) -> entities.StageProcessResult: + async def process(self, query: pipeline_query.Query, stage_inst_name: str) -> entities.StageProcessResult: if query.launcher_type.value != 'group': # 只处理群消息 return entities.StageProcessResult(result_type=entities.ResultType.CONTINUE, new_query=query) diff --git a/pkg/pipeline/resprule/rule.py b/pkg/pipeline/resprule/rule.py index 3fdb0386..7c91373f 100644 --- a/pkg/pipeline/resprule/rule.py +++ b/pkg/pipeline/resprule/rule.py @@ -2,10 +2,11 @@ from __future__ import annotations import abc import typing -from ...core import app, entities as core_entities +from ...core import app from . import entities from ...platform.types import message as platform_message +import langbot_plugin.api.entities.builtin.pipeline.query as pipeline_query preregisetered_rules: list[typing.Type[GroupRespondRule]] = [] @@ -39,7 +40,7 @@ class GroupRespondRule(metaclass=abc.ABCMeta): message_text: str, message_chain: platform_message.MessageChain, rule_dict: dict, - query: core_entities.Query, + query: pipeline_query.Query, ) -> entities.RuleJudgeResult: """判断消息是否匹配规则""" raise NotImplementedError diff --git a/pkg/pipeline/resprule/rules/atbot.py b/pkg/pipeline/resprule/rules/atbot.py index 340b92c7..fc3b5510 100644 --- a/pkg/pipeline/resprule/rules/atbot.py +++ b/pkg/pipeline/resprule/rules/atbot.py @@ -3,8 +3,8 @@ from __future__ import annotations from .. import rule as rule_model from .. import entities -from ....core import entities as core_entities from ....platform.types import message as platform_message +import langbot_plugin.api.entities.builtin.pipeline.query as pipeline_query @rule_model.rule_class('at-bot') @@ -14,7 +14,7 @@ class AtBotRule(rule_model.GroupRespondRule): message_text: str, message_chain: platform_message.MessageChain, rule_dict: dict, - query: core_entities.Query, + query: pipeline_query.Query, ) -> entities.RuleJudgeResult: if message_chain.has(platform_message.At(query.adapter.bot_account_id)) and rule_dict['at']: message_chain.remove(platform_message.At(query.adapter.bot_account_id)) diff --git a/pkg/pipeline/resprule/rules/prefix.py b/pkg/pipeline/resprule/rules/prefix.py index c712d3e8..2ae89fe1 100644 --- a/pkg/pipeline/resprule/rules/prefix.py +++ b/pkg/pipeline/resprule/rules/prefix.py @@ -1,7 +1,7 @@ from .. import rule as rule_model from .. import entities -from ....core import entities as core_entities from ....platform.types import message as platform_message +import langbot_plugin.api.entities.builtin.pipeline.query as pipeline_query @rule_model.rule_class('prefix') @@ -11,7 +11,7 @@ class PrefixRule(rule_model.GroupRespondRule): message_text: str, message_chain: platform_message.MessageChain, rule_dict: dict, - query: core_entities.Query, + query: pipeline_query.Query, ) -> entities.RuleJudgeResult: prefixes = rule_dict['prefix'] diff --git a/pkg/pipeline/resprule/rules/random.py b/pkg/pipeline/resprule/rules/random.py index d2f782ab..04818ef0 100644 --- a/pkg/pipeline/resprule/rules/random.py +++ b/pkg/pipeline/resprule/rules/random.py @@ -3,8 +3,8 @@ import random from .. import rule as rule_model from .. import entities -from ....core import entities as core_entities from ....platform.types import message as platform_message +import langbot_plugin.api.entities.builtin.pipeline.query as pipeline_query @rule_model.rule_class('random') @@ -14,7 +14,7 @@ class RandomRespRule(rule_model.GroupRespondRule): message_text: str, message_chain: platform_message.MessageChain, rule_dict: dict, - query: core_entities.Query, + query: pipeline_query.Query, ) -> entities.RuleJudgeResult: random_rate = rule_dict['random'] diff --git a/pkg/pipeline/resprule/rules/regexp.py b/pkg/pipeline/resprule/rules/regexp.py index daac0869..51589e0c 100644 --- a/pkg/pipeline/resprule/rules/regexp.py +++ b/pkg/pipeline/resprule/rules/regexp.py @@ -3,8 +3,8 @@ import re from .. import rule as rule_model from .. import entities -from ....core import entities as core_entities from ....platform.types import message as platform_message +import langbot_plugin.api.entities.builtin.pipeline.query as pipeline_query @rule_model.rule_class('regexp') @@ -14,7 +14,7 @@ class RegExpRule(rule_model.GroupRespondRule): message_text: str, message_chain: platform_message.MessageChain, rule_dict: dict, - query: core_entities.Query, + query: pipeline_query.Query, ) -> entities.RuleJudgeResult: regexps = rule_dict['regexp'] diff --git a/pkg/pipeline/stage.py b/pkg/pipeline/stage.py index 18a94b73..0ff1af7e 100644 --- a/pkg/pipeline/stage.py +++ b/pkg/pipeline/stage.py @@ -3,8 +3,9 @@ from __future__ import annotations import abc import typing -from ..core import app, entities as core_entities +from ..core import app from . import entities +import langbot_plugin.api.entities.builtin.pipeline.query as pipeline_query preregistered_stages: dict[str, type[PipelineStage]] = {} @@ -33,7 +34,7 @@ class PipelineStage(metaclass=abc.ABCMeta): @abc.abstractmethod async def process( self, - query: core_entities.Query, + query: pipeline_query.Query, stage_inst_name: str, ) -> typing.Union[ entities.StageProcessResult, diff --git a/pkg/pipeline/wrapper/wrapper.py b/pkg/pipeline/wrapper/wrapper.py index 3299a226..8063ff36 100644 --- a/pkg/pipeline/wrapper/wrapper.py +++ b/pkg/pipeline/wrapper/wrapper.py @@ -2,12 +2,11 @@ from __future__ import annotations import typing - -from ...core import entities as core_entities from .. import entities from .. import stage from ...plugin import events from ...platform.types import message as platform_message +import langbot_plugin.api.entities.builtin.pipeline.query as pipeline_query @stage.stage_class('ResponseWrapper') @@ -25,7 +24,7 @@ class ResponseWrapper(stage.PipelineStage): async def process( self, - query: core_entities.Query, + query: pipeline_query.Query, stage_inst_name: str, ) -> typing.AsyncGenerator[entities.StageProcessResult, None]: """处理""" diff --git a/pkg/platform/adapter.py b/pkg/platform/adapter.py index f28ad3dc..f27efc75 100644 --- a/pkg/platform/adapter.py +++ b/pkg/platform/adapter.py @@ -3,15 +3,14 @@ from __future__ import annotations # MessageSource的适配器 import typing import abc +import pydantic - -from ..core import app from .types import message as platform_message from .types import events as platform_events from .logger import EventLogger -class MessagePlatformAdapter(metaclass=abc.ABCMeta): +class MessagePlatformAdapter(pydantic.BaseModel, metaclass=abc.ABCMeta): """消息平台适配器基类""" name: str @@ -21,11 +20,9 @@ class MessagePlatformAdapter(metaclass=abc.ABCMeta): config: dict - ap: app.Application + logger: EventLogger = pydantic.Field(exclude=True) - logger: EventLogger - - def __init__(self, config: dict, ap: app.Application, logger: EventLogger): + def __init__(self, config: dict, logger: EventLogger): """初始化适配器 Args: @@ -33,7 +30,6 @@ class MessagePlatformAdapter(metaclass=abc.ABCMeta): ap (app.Application): 应用上下文 """ self.config = config - self.ap = ap self.logger = logger async def send_message(self, target_type: str, target_id: str, message: platform_message.MessageChain): diff --git a/pkg/platform/botmgr.py b/pkg/platform/botmgr.py index 5855525f..8f247ca4 100644 --- a/pkg/platform/botmgr.py +++ b/pkg/platform/botmgr.py @@ -19,6 +19,8 @@ from ..entity.errors import platform as platform_errors from .logger import EventLogger +import langbot_plugin.api.entities.builtin.provider.session as provider_session + # 处理 3.4 移除了 YiriMirai 之后,插件的兼容性问题 from . import types as mirai @@ -73,7 +75,7 @@ class RuntimeBot: await self.ap.query_pool.add_query( bot_uuid=self.bot_entity.uuid, - launcher_type=core_entities.LauncherTypes.PERSON, + launcher_type=provider_session.LauncherTypes.PERSON, launcher_id=event.sender.id, sender_id=event.sender.id, message_event=event, @@ -98,7 +100,7 @@ class RuntimeBot: await self.ap.query_pool.add_query( bot_uuid=self.bot_entity.uuid, - launcher_type=core_entities.LauncherTypes.GROUP, + launcher_type=provider_session.LauncherTypes.GROUP, launcher_id=event.group.id, sender_id=event.sender.id, message_event=event, @@ -172,9 +174,9 @@ class PlatformManager: webchat_logger = EventLogger(name='webchat-adapter', ap=self.ap) webchat_adapter_inst = webchat_adapter_class( {}, - self.ap, webchat_logger, ) + webchat_adapter_inst.ap = self.ap self.webchat_proxy_bot = RuntimeBot( ap=self.ap, @@ -231,7 +233,6 @@ class PlatformManager: adapter_inst = self.adapter_dict[bot_entity.adapter]( bot_entity.adapter_config, - self.ap, logger, ) diff --git a/pkg/platform/sources/aiocqhttp.py b/pkg/platform/sources/aiocqhttp.py index 8cdfd204..b2616bb0 100644 --- a/pkg/platform/sources/aiocqhttp.py +++ b/pkg/platform/sources/aiocqhttp.py @@ -7,7 +7,6 @@ import datetime import aiocqhttp from .. import adapter -from ...core import app from ..types import message as platform_message from ..types import events as platform_events from ..types import entities as platform_entities @@ -273,11 +272,9 @@ class AiocqhttpAdapter(adapter.MessagePlatformAdapter): config: dict - ap: app.Application - on_websocket_connection_event_cache: typing.List[typing.Callable[[aiocqhttp.Event], None]] = [] - def __init__(self, config: dict, ap: app.Application, logger: EventLogger): + def __init__(self, config: dict, logger: EventLogger): self.config = config self.logger = logger @@ -287,7 +284,6 @@ class AiocqhttpAdapter(adapter.MessagePlatformAdapter): self.config['shutdown_trigger'] = shutdown_trigger_placeholder - self.ap = ap self.on_websocket_connection_event_cache = [] if 'access-token' in config: diff --git a/pkg/platform/sources/dingtalk.py b/pkg/platform/sources/dingtalk.py index 3147c984..1727a771 100644 --- a/pkg/platform/sources/dingtalk.py +++ b/pkg/platform/sources/dingtalk.py @@ -4,7 +4,6 @@ from libs.dingtalk_api.dingtalkevent import DingTalkEvent from pkg.platform.types import message as platform_message from pkg.platform.adapter import MessagePlatformAdapter from .. import adapter -from ...core import app from ..types import events as platform_events from ..types import entities as platform_entities from libs.dingtalk_api.api import DingTalkClient @@ -94,15 +93,13 @@ class DingTalkEventConverter(adapter.EventConverter): class DingTalkAdapter(adapter.MessagePlatformAdapter): bot: DingTalkClient - ap: app.Application bot_account_id: str message_converter: DingTalkMessageConverter = DingTalkMessageConverter() event_converter: DingTalkEventConverter = DingTalkEventConverter() config: dict - def __init__(self, config: dict, ap: app.Application, logger: EventLogger): + def __init__(self, config: dict, logger: EventLogger): self.config = config - self.ap = ap self.logger = logger required_keys = [ 'client_id', diff --git a/pkg/platform/sources/discord.py b/pkg/platform/sources/discord.py index f159c628..52bd5e5b 100644 --- a/pkg/platform/sources/discord.py +++ b/pkg/platform/sources/discord.py @@ -12,7 +12,6 @@ import datetime import aiohttp from .. import adapter -from ...core import app from ..types import message as platform_message from ..types import events as platform_events from ..types import entities as platform_entities @@ -161,8 +160,6 @@ class DiscordAdapter(adapter.MessagePlatformAdapter): config: dict - ap: app.Application - message_converter: DiscordMessageConverter = DiscordMessageConverter() event_converter: DiscordEventConverter = DiscordEventConverter() @@ -171,9 +168,8 @@ class DiscordAdapter(adapter.MessagePlatformAdapter): typing.Callable[[platform_events.Event, adapter.MessagePlatformAdapter], None], ] = {} - def __init__(self, config: dict, ap: app.Application, logger: EventLogger): + def __init__(self, config: dict, logger: EventLogger): self.config = config - self.ap = ap self.logger = logger self.bot_account_id = self.config['client_id'] diff --git a/pkg/platform/sources/lark.py b/pkg/platform/sources/lark.py index f8faf522..9e727ad3 100644 --- a/pkg/platform/sources/lark.py +++ b/pkg/platform/sources/lark.py @@ -19,7 +19,6 @@ import quart from lark_oapi.api.im.v1 import * from .. import adapter -from ...core import app from ..types import message as platform_message from ..types import events as platform_events from ..types import entities as platform_entities @@ -337,11 +336,9 @@ class LarkAdapter(adapter.MessagePlatformAdapter): config: dict quart_app: quart.Quart - ap: app.Application - def __init__(self, config: dict, ap: app.Application, logger: EventLogger): + def __init__(self, config: dict, logger: EventLogger): self.config = config - self.ap = ap self.logger = logger self.quart_app = quart.Quart(__name__) self.listeners = {} @@ -351,8 +348,6 @@ class LarkAdapter(adapter.MessagePlatformAdapter): try: data = await quart.request.json - self.ap.logger.debug(f'Lark callback event: {data}') - if 'encrypt' in data: cipher = AESCipher(self.config['encrypt-key']) data = cipher.decrypt_string(data['encrypt']) diff --git a/pkg/platform/sources/gewechat.png b/pkg/platform/sources/legacy/gewechat.png similarity index 100% rename from pkg/platform/sources/gewechat.png rename to pkg/platform/sources/legacy/gewechat.png diff --git a/pkg/platform/sources/gewechat.py b/pkg/platform/sources/legacy/gewechat.py similarity index 98% rename from pkg/platform/sources/gewechat.py rename to pkg/platform/sources/legacy/gewechat.py index 01d9f946..7e7b7715 100644 --- a/pkg/platform/sources/gewechat.py +++ b/pkg/platform/sources/legacy/gewechat.py @@ -11,16 +11,16 @@ import threading import quart import aiohttp -from .. import adapter -from ...core import app -from ..types import message as platform_message -from ..types import events as platform_events -from ..types import entities as platform_entities -from ...utils import image +from ... import adapter +from ....core import app +from ...types import message as platform_message +from ...types import events as platform_events +from ...types import entities as platform_entities +from ....utils import image import xml.etree.ElementTree as ET from typing import Optional, Tuple from functools import partial -from ..logger import EventLogger +from ...logger import EventLogger class GewechatMessageConverter(adapter.MessageConverter): @@ -491,7 +491,7 @@ class GeWeChatAdapter(adapter.MessagePlatformAdapter): async def gewechat_callback(): data = await quart.request.json # print(json.dumps(data, indent=4, ensure_ascii=False)) - self.ap.logger.debug(f'Gewechat callback event: {data}') + await self.logger.debug(f'Gewechat callback event: {data}') if 'data' in data: data['Data'] = data['data'] @@ -601,7 +601,7 @@ class GeWeChatAdapter(adapter.MessagePlatformAdapter): if handler := handler_map.get(msg['type']): handler(msg) else: - self.ap.logger.warning(f'未处理的消息类型: {msg["type"]}') + await self.logger.warning(f'未处理的消息类型: {msg["type"]}') continue async def send_message(self, target_type: str, target_id: str, message: platform_message.MessageChain): @@ -656,9 +656,7 @@ class GeWeChatAdapter(adapter.MessagePlatformAdapter): self.config['app_id'] = app_id - self.ap.logger.info(f'Gewechat 登录成功,app_id: {app_id}') - - self.ap.platform_mgr.write_back_config('gewechat', self, self.config) + print(f'Gewechat 登录成功,app_id: {app_id}') # 获取 nickname profile = self.bot.get_profile(self.config['app_id']) diff --git a/pkg/platform/sources/gewechat.yaml b/pkg/platform/sources/legacy/gewechat.yaml similarity index 100% rename from pkg/platform/sources/gewechat.yaml rename to pkg/platform/sources/legacy/gewechat.yaml diff --git a/pkg/platform/sources/nakuru.png b/pkg/platform/sources/legacy/nakuru.png similarity index 100% rename from pkg/platform/sources/nakuru.png rename to pkg/platform/sources/legacy/nakuru.png diff --git a/pkg/platform/sources/nakuru.py b/pkg/platform/sources/legacy/nakuru.py similarity index 97% rename from pkg/platform/sources/nakuru.py rename to pkg/platform/sources/legacy/nakuru.py index 16ad54db..5afb6356 100644 --- a/pkg/platform/sources/nakuru.py +++ b/pkg/platform/sources/legacy/nakuru.py @@ -9,12 +9,12 @@ import traceback import nakuru import nakuru.entities.components as nkc -from .. import adapter as adapter_model -from ...pipeline.longtext.strategies import forward -from ...platform.types import message as platform_message -from ...platform.types import entities as platform_entities -from ...platform.types import events as platform_events -from ..logger import EventLogger +from ... import adapter as adapter_model +from ....pipeline.longtext.strategies import forward +from ...types import message as platform_message +from ...types import entities as platform_entities +from ...types import events as platform_events +from ...logger import EventLogger class NakuruProjectMessageConverter(adapter_model.MessageConverter): @@ -262,7 +262,7 @@ class NakuruAdapter(adapter_model.MessagePlatformAdapter): source_cls = NakuruProjectEventConverter.yiri2target(event_type) # 包装函数 - async def listener_wrapper(app: nakuru.CQHTTP, source: source_cls): + async def listener_wrapper(app: nakuru.CQHTTP, source: source_cls): # type: ignore await callback(self.event_converter.target2yiri(source), self) # 将包装函数和原函数的对应关系存入列表 @@ -322,7 +322,6 @@ class NakuruAdapter(adapter_model.MessagePlatformAdapter): except Exception: raise Exception('获取go-cqhttp账号信息失败, 请检查是否已启动go-cqhttp并配置正确') await self.bot._run() - self.ap.logger.info('运行 Nakuru 适配器') while True: await asyncio.sleep(1) diff --git a/pkg/platform/sources/nakuru.yaml b/pkg/platform/sources/legacy/nakuru.yaml similarity index 100% rename from pkg/platform/sources/nakuru.yaml rename to pkg/platform/sources/legacy/nakuru.yaml diff --git a/pkg/platform/sources/qqbotpy.py b/pkg/platform/sources/legacy/qqbotpy.py similarity index 97% rename from pkg/platform/sources/qqbotpy.py rename to pkg/platform/sources/legacy/qqbotpy.py index d4a4d526..7e8fb125 100644 --- a/pkg/platform/sources/qqbotpy.py +++ b/pkg/platform/sources/legacy/qqbotpy.py @@ -10,14 +10,14 @@ import botpy import botpy.message as botpy_message import botpy.types.message as botpy_message_type -from .. import adapter as adapter_model -from ...pipeline.longtext.strategies import forward -from ...core import app -from ...config import manager as cfg_mgr -from ...platform.types import entities as platform_entities -from ...platform.types import events as platform_events -from ...platform.types import message as platform_message -from ..logger import EventLogger +from ... import adapter as adapter_model +from ....pipeline.longtext.strategies import forward +from ....core import app +from ....config import manager as cfg_mgr +from ...types import entities as platform_entities +from ...types import events as platform_events +from ...types import message as platform_message +from ...logger import EventLogger class OfficialGroupMessage(platform_events.GroupMessage): @@ -519,7 +519,7 @@ class OfficialAdapter(adapter_model.MessagePlatformAdapter): self.cfg['ret_coro'] = True - self.ap.logger.info('运行 QQ 官方适配器') + await self.logger.info('运行 QQ 官方适配器') await (await self.bot.start(**self.cfg)) async def kill(self) -> bool: diff --git a/pkg/platform/sources/qqbotpy.svg b/pkg/platform/sources/legacy/qqbotpy.svg similarity index 100% rename from pkg/platform/sources/qqbotpy.svg rename to pkg/platform/sources/legacy/qqbotpy.svg diff --git a/pkg/platform/sources/qqbotpy.yaml b/pkg/platform/sources/legacy/qqbotpy.yaml similarity index 100% rename from pkg/platform/sources/qqbotpy.yaml rename to pkg/platform/sources/legacy/qqbotpy.yaml diff --git a/pkg/platform/sources/officialaccount.py b/pkg/platform/sources/officialaccount.py index 3fc1e393..925b0ee4 100644 --- a/pkg/platform/sources/officialaccount.py +++ b/pkg/platform/sources/officialaccount.py @@ -10,7 +10,6 @@ from libs.official_account_api.oaevent import OAEvent from libs.official_account_api.api import OAClient from libs.official_account_api.api import OAClientForLongerResponse from .. import adapter -from ...core import app from ..types import entities as platform_entities from ...command.errors import ParamNotEnoughError from ..logger import EventLogger @@ -58,15 +57,13 @@ class OAEventConverter(adapter.EventConverter): class OfficialAccountAdapter(adapter.MessagePlatformAdapter): bot: OAClient | OAClientForLongerResponse - ap: app.Application bot_account_id: str message_converter: OAMessageConverter = OAMessageConverter() event_converter: OAEventConverter = OAEventConverter() config: dict - def __init__(self, config: dict, ap: app.Application, logger: EventLogger): + def __init__(self, config: dict, logger: EventLogger): self.config = config - self.ap = ap self.logger = logger required_keys = [ diff --git a/pkg/platform/sources/qqofficial.py b/pkg/platform/sources/qqofficial.py index 63ab531f..cd7beb31 100644 --- a/pkg/platform/sources/qqofficial.py +++ b/pkg/platform/sources/qqofficial.py @@ -8,7 +8,6 @@ import datetime from pkg.platform.adapter import MessagePlatformAdapter from pkg.platform.types import events as platform_events, message as platform_message from .. import adapter -from ...core import app from ..types import entities as platform_entities from ...command.errors import ParamNotEnoughError from libs.qq_official_api.api import QQOfficialClient @@ -134,15 +133,13 @@ class QQOfficialEventConverter(adapter.EventConverter): class QQOfficialAdapter(adapter.MessagePlatformAdapter): bot: QQOfficialClient - ap: app.Application config: dict bot_account_id: str message_converter: QQOfficialMessageConverter = QQOfficialMessageConverter() event_converter: QQOfficialEventConverter = QQOfficialEventConverter() - def __init__(self, config: dict, ap: app.Application, logger: EventLogger): + def __init__(self, config: dict, logger: EventLogger): self.config = config - self.ap = ap self.logger = logger required_keys = [ diff --git a/pkg/platform/sources/slack.py b/pkg/platform/sources/slack.py index 1bd5aa2d..ff14ce1c 100644 --- a/pkg/platform/sources/slack.py +++ b/pkg/platform/sources/slack.py @@ -9,7 +9,6 @@ from libs.slack_api.api import SlackClient from pkg.platform.adapter import MessagePlatformAdapter from pkg.platform.types import events as platform_events, message as platform_message from libs.slack_api.slackevent import SlackEvent -from pkg.core import app from .. import adapter from ..types import entities as platform_entities from ...command.errors import ParamNotEnoughError @@ -86,15 +85,13 @@ class SlackEventConverter(adapter.EventConverter): class SlackAdapter(adapter.MessagePlatformAdapter): bot: SlackClient - ap: app.Application bot_account_id: str message_converter: SlackMessageConverter = SlackMessageConverter() event_converter: SlackEventConverter = SlackEventConverter() config: dict - def __init__(self, config: dict, ap: app.Application, logger: EventLogger): + def __init__(self, config: dict, logger: EventLogger): self.config = config - self.ap = ap self.logger = logger required_keys = [ 'bot_token', diff --git a/pkg/platform/sources/telegram.py b/pkg/platform/sources/telegram.py index c2fcc22e..52d79853 100644 --- a/pkg/platform/sources/telegram.py +++ b/pkg/platform/sources/telegram.py @@ -10,10 +10,7 @@ import traceback import base64 import aiohttp -from lark_oapi.api.im.v1 import * - from .. import adapter -from ...core import app from ..types import message as platform_message from ..types import events as platform_events from ..types import entities as platform_entities @@ -141,16 +138,14 @@ class TelegramAdapter(adapter.MessagePlatformAdapter): event_converter: TelegramEventConverter = TelegramEventConverter() config: dict - ap: app.Application listeners: typing.Dict[ typing.Type[platform_events.Event], typing.Callable[[platform_events.Event, adapter.MessagePlatformAdapter], None], ] = {} - def __init__(self, config: dict, ap: app.Application, logger: EventLogger): + def __init__(self, config: dict, logger: EventLogger): self.config = config - self.ap = ap self.logger = logger async def telegram_callback(update: Update, context: ContextTypes.DEFAULT_TYPE): diff --git a/pkg/platform/sources/webchat.py b/pkg/platform/sources/webchat.py index 51b0479f..0a35c1ac 100644 --- a/pkg/platform/sources/webchat.py +++ b/pkg/platform/sources/webchat.py @@ -44,13 +44,14 @@ class WebChatAdapter(msadapter.MessagePlatformAdapter): webchat_person_session: WebChatSession webchat_group_session: WebChatSession + ap: app.Application # set by bot manager + listeners: typing.Dict[ typing.Type[platform_events.Event], typing.Callable[[platform_events.Event, msadapter.MessagePlatformAdapter], None], ] = {} - def __init__(self, config: dict, ap: app.Application, logger: EventLogger): - self.ap = ap + def __init__(self, config: dict, logger: EventLogger): self.logger = logger self.config = config diff --git a/pkg/platform/sources/wechatpad.py b/pkg/platform/sources/wechatpad.py index 88ec9bd9..0188d788 100644 --- a/pkg/platform/sources/wechatpad.py +++ b/pkg/platform/sources/wechatpad.py @@ -488,6 +488,8 @@ class WeChatPadAdapter(adapter.MessagePlatformAdapter): ap: app.Application + logger: EventLogger + message_converter: WeChatPadMessageConverter event_converter: WeChatPadEventConverter @@ -507,8 +509,6 @@ class WeChatPadAdapter(adapter.MessagePlatformAdapter): async def ws_message(self, data): """处理接收到的消息""" - # self.ap.logger.debug(f"Gewechat callback event: {data}") - # print(data) try: event = await self.event_converter.target2yiri(data.copy(), self.bot_account_id) @@ -571,9 +571,8 @@ class WeChatPadAdapter(adapter.MessagePlatformAdapter): if handler := handler_map.get(msg['type']): handler(msg) - # self.ap.logger.warning(f"未处理的消息类型: {ret}") else: - self.ap.logger.warning(f'未处理的消息类型: {msg["type"]}') + print(f'未处理的消息类型: {msg["type"]}') continue async def send_message(self, target_type: str, target_id: str, message: platform_message.MessageChain): @@ -615,7 +614,6 @@ class WeChatPadAdapter(adapter.MessagePlatformAdapter): if self.config['token']: self.bot = WeChatPadClient(self.config['wechatpad_url'], self.config['token']) data = self.bot.get_login_status() - self.ap.logger.info(data) if data['Code'] == 300 and data['Text'] == '你已退出微信': response = requests.post( f'{self.config["wechatpad_url"]}/admin/GenAuthKey1?key={self.config["admin_key"]}', @@ -635,7 +633,7 @@ class WeChatPadAdapter(adapter.MessagePlatformAdapter): self.config['token'] = response.json()['Data'][0] self.bot = WeChatPadClient(self.config['wechatpad_url'], self.config['token'], logger=self.logger) - self.ap.logger.info(self.config['token']) + await self.logger.info(self.config['token']) thread_1 = threading.Event() def wechat_login_process(): @@ -643,10 +641,9 @@ class WeChatPadAdapter(adapter.MessagePlatformAdapter): # login_data =self.bot.get_login_qr() # url = login_data['Data']["QrCodeUrl"] - # self.ap.logger.info(login_data) profile = self.bot.get_profile() - self.ap.logger.info(profile) + self.logger.info(profile) self.bot_account_id = profile['Data']['userInfo']['nickName']['str'] self.config['wxid'] = profile['Data']['userInfo']['userName']['str'] @@ -658,27 +655,26 @@ class WeChatPadAdapter(adapter.MessagePlatformAdapter): def connect_websocket_sync() -> None: thread_1.wait() uri = f'{self.config["wechatpad_ws"]}/GetSyncMsg?key={self.config["token"]}' - self.ap.logger.info(f'Connecting to WebSocket: {uri}') + print(f'Connecting to WebSocket: {uri}') def on_message(ws, message): try: data = json.loads(message) - self.ap.logger.debug(f'Received message: {data}') # 这里需要确保ws_message是同步的,或者使用asyncio.run调用异步方法 asyncio.run(self.ws_message(data)) except json.JSONDecodeError: - self.ap.logger.error(f'Non-JSON message: {message[:100]}...') + print(f'Non-JSON message: {message[:100]}...') def on_error(ws, error): - self.ap.logger.error(f'WebSocket error: {str(error)[:200]}') + print(f'WebSocket error: {str(error)[:200]}') def on_close(ws, close_status_code, close_msg): - self.ap.logger.info('WebSocket closed, reconnecting...') + print('WebSocket closed, reconnecting...') time.sleep(5) connect_websocket_sync() # 自动重连 def on_open(ws): - self.ap.logger.info('WebSocket connected successfully!') + print('WebSocket connected successfully!') ws = websocket.WebSocketApp( uri, on_message=on_message, on_error=on_error, on_close=on_close, on_open=on_open @@ -689,10 +685,9 @@ class WeChatPadAdapter(adapter.MessagePlatformAdapter): # connect_websocket_sync() # 这行代码会在WebSocket连接断开后才会执行 - # self.ap.logger.info("WebSocket client thread started") thread = threading.Thread(target=connect_websocket_sync, name='WebSocketClientThread', daemon=True) thread.start() - self.ap.logger.info('WebSocket client thread started') + self.logger.info('WebSocket client thread started') async def kill(self) -> bool: pass diff --git a/pkg/platform/sources/wecom.py b/pkg/platform/sources/wecom.py index 7be05a85..7bb0a757 100644 --- a/pkg/platform/sources/wecom.py +++ b/pkg/platform/sources/wecom.py @@ -10,7 +10,6 @@ from pkg.platform.adapter import MessagePlatformAdapter from pkg.platform.types import events as platform_events, message as platform_message from libs.wecom_api.wecomevent import WecomEvent from .. import adapter -from ...core import app from ..types import entities as platform_entities from ...command.errors import ParamNotEnoughError from ...utils import image @@ -129,15 +128,13 @@ class WecomEventConverter: class WecomAdapter(adapter.MessagePlatformAdapter): bot: WecomClient - ap: app.Application bot_account_id: str message_converter: WecomMessageConverter = WecomMessageConverter() event_converter: WecomEventConverter = WecomEventConverter() config: dict - def __init__(self, config: dict, ap: app.Application, logger: EventLogger): + def __init__(self, config: dict, logger: EventLogger): self.config = config - self.ap = ap self.logger = logger required_keys = [ diff --git a/pkg/platform/sources/wecomcs.py b/pkg/platform/sources/wecomcs.py index da84ac6d..fcd5378e 100644 --- a/pkg/platform/sources/wecomcs.py +++ b/pkg/platform/sources/wecomcs.py @@ -9,7 +9,6 @@ from libs.wecom_customer_service_api.api import WecomCSClient from pkg.platform.adapter import MessagePlatformAdapter from pkg.platform.types import events as platform_events, message as platform_message from libs.wecom_customer_service_api.wecomcsevent import WecomCSEvent -from pkg.core import app from .. import adapter from ..types import entities as platform_entities from ...command.errors import ParamNotEnoughError @@ -119,15 +118,13 @@ class WecomEventConverter: class WecomCSAdapter(adapter.MessagePlatformAdapter): bot: WecomCSClient - ap: app.Application bot_account_id: str message_converter: WecomMessageConverter = WecomMessageConverter() event_converter: WecomEventConverter = WecomEventConverter() config: dict - def __init__(self, config: dict, ap: app.Application, logger: EventLogger): + def __init__(self, config: dict, logger: EventLogger): self.config = config - self.ap = ap self.logger = logger required_keys = [ diff --git a/pkg/plugin/events.py b/pkg/plugin/events.py index 777b61d6..e6e2dccb 100644 --- a/pkg/plugin/events.py +++ b/pkg/plugin/events.py @@ -4,16 +4,16 @@ import typing import pydantic.v1 as pydantic -from ..core import entities as core_entities from ..provider import entities as llm_entities from ..platform.types import message as platform_message import langbot_plugin.api.entities.builtin.provider.session as provider_session +import langbot_plugin.api.entities.builtin.pipeline.query as pipeline_query class BaseEventModel(pydantic.BaseModel): """事件模型基类""" - query: typing.Union[core_entities.Query, None] + query: typing.Union[pipeline_query.Query, None] """此次请求的query对象,非请求过程的事件时为None""" class Config: diff --git a/pkg/plugin/loaders/classic.py b/pkg/plugin/loaders/classic.py index c94b0d7d..6613bb63 100644 --- a/pkg/plugin/loaders/classic.py +++ b/pkg/plugin/loaders/classic.py @@ -6,10 +6,10 @@ import importlib import traceback from .. import loader, events, context, models -from ...core import entities as core_entities from langbot_plugin.api.entities.builtin.resource import tool as resource_tool from ...utils import funcschema from ...discover import engine as discover_engine +import langbot_plugin.api.entities.builtin.pipeline.query as pipeline_query class PluginLoader(loader.PluginLoader): @@ -98,7 +98,7 @@ class PluginLoader(loader.PluginLoader): function_schema = funcschema.get_func_schema(func) function_name = self._current_container.plugin_name + '-' + (func.__name__ if name is None else name) - async def handler(plugin: context.BasePlugin, query: core_entities.Query, *args, **kwargs): + async def handler(plugin: context.BasePlugin, query: pipeline_query.Query, *args, **kwargs): return func(*args, **kwargs) llm_function = resource_tool.LLMTool( diff --git a/pkg/provider/entities.py b/pkg/provider/entities.py index 94b812d9..1f38ca01 100644 --- a/pkg/provider/entities.py +++ b/pkg/provider/entities.py @@ -1,7 +1,7 @@ from __future__ import annotations import typing -import pydantic.v1 as pydantic +import pydantic from pkg.provider import entities diff --git a/pkg/provider/modelmgr/entities.py b/pkg/provider/modelmgr/entities.py index cf856894..91d1d6e9 100644 --- a/pkg/provider/modelmgr/entities.py +++ b/pkg/provider/modelmgr/entities.py @@ -2,7 +2,7 @@ from __future__ import annotations import typing -import pydantic.v1 as pydantic +import pydantic from . import requester from . import token diff --git a/pkg/provider/modelmgr/requester.py b/pkg/provider/modelmgr/requester.py index 4008ca16..b8443b2c 100644 --- a/pkg/provider/modelmgr/requester.py +++ b/pkg/provider/modelmgr/requester.py @@ -4,11 +4,11 @@ import abc import typing from ...core import app -from ...core import entities as core_entities from .. import entities as llm_entities from ...entity.persistence import model as persistence_model import langbot_plugin.api.entities.builtin.resource.tool as resource_tool from . import token +import langbot_plugin.api.entities.builtin.pipeline.query as pipeline_query class RuntimeLLMModel: @@ -56,7 +56,7 @@ class LLMAPIRequester(metaclass=abc.ABCMeta): @abc.abstractmethod async def invoke_llm( self, - query: core_entities.Query, + query: pipeline_query.Query, model: RuntimeLLMModel, messages: typing.List[llm_entities.Message], funcs: typing.List[resource_tool.LLMTool] = None, diff --git a/pkg/provider/modelmgr/requesters/anthropicmsgs.py b/pkg/provider/modelmgr/requesters/anthropicmsgs.py index 4655b3e0..1a100ca3 100644 --- a/pkg/provider/modelmgr/requesters/anthropicmsgs.py +++ b/pkg/provider/modelmgr/requesters/anthropicmsgs.py @@ -9,10 +9,10 @@ import httpx from .. import errors, requester -from ....core import entities as core_entities from ... import entities as llm_entities from ....utils import image import langbot_plugin.api.entities.builtin.resource.tool as resource_tool +import langbot_plugin.api.entities.builtin.pipeline.query as pipeline_query class AnthropicMessages(requester.LLMAPIRequester): @@ -48,7 +48,7 @@ class AnthropicMessages(requester.LLMAPIRequester): async def invoke_llm( self, - query: core_entities.Query, + query: pipeline_query.Query, model: requester.RuntimeLLMModel, messages: typing.List[llm_entities.Message], funcs: typing.List[resource_tool.LLMTool] = None, diff --git a/pkg/provider/modelmgr/requesters/chatcmpl.py b/pkg/provider/modelmgr/requesters/chatcmpl.py index 00ff0a41..944e0eef 100644 --- a/pkg/provider/modelmgr/requesters/chatcmpl.py +++ b/pkg/provider/modelmgr/requesters/chatcmpl.py @@ -8,9 +8,9 @@ import openai.types.chat.chat_completion as chat_completion import httpx from .. import errors, requester -from ....core import entities as core_entities from ... import entities as llm_entities import langbot_plugin.api.entities.builtin.resource.tool as resource_tool +import langbot_plugin.api.entities.builtin.pipeline.query as pipeline_query class OpenAIChatCompletions(requester.LLMAPIRequester): @@ -60,7 +60,7 @@ class OpenAIChatCompletions(requester.LLMAPIRequester): async def _closure( self, - query: core_entities.Query, + query: pipeline_query.Query, req_messages: list[dict], use_model: requester.RuntimeLLMModel, use_funcs: list[resource_tool.LLMTool] = None, @@ -101,7 +101,7 @@ class OpenAIChatCompletions(requester.LLMAPIRequester): async def invoke_llm( self, - query: core_entities.Query, + query: pipeline_query.Query, model: requester.RuntimeLLMModel, messages: typing.List[llm_entities.Message], funcs: typing.List[resource_tool.LLMTool] = None, diff --git a/pkg/provider/modelmgr/requesters/deepseekchatcmpl.py b/pkg/provider/modelmgr/requesters/deepseekchatcmpl.py index 6dced3c9..ecf7a697 100644 --- a/pkg/provider/modelmgr/requesters/deepseekchatcmpl.py +++ b/pkg/provider/modelmgr/requesters/deepseekchatcmpl.py @@ -4,9 +4,9 @@ import typing from . import chatcmpl from .. import errors, requester -from ....core import entities as core_entities from ... import entities as llm_entities import langbot_plugin.api.entities.builtin.resource.tool as resource_tool +import langbot_plugin.api.entities.builtin.pipeline.query as pipeline_query class DeepseekChatCompletions(chatcmpl.OpenAIChatCompletions): @@ -19,7 +19,7 @@ class DeepseekChatCompletions(chatcmpl.OpenAIChatCompletions): async def _closure( self, - query: core_entities.Query, + query: pipeline_query.Query, req_messages: list[dict], use_model: requester.RuntimeLLMModel, use_funcs: list[resource_tool.LLMTool] = None, diff --git a/pkg/provider/modelmgr/requesters/giteeaichatcmpl.py b/pkg/provider/modelmgr/requesters/giteeaichatcmpl.py index 26da7d6d..9828e2ca 100644 --- a/pkg/provider/modelmgr/requesters/giteeaichatcmpl.py +++ b/pkg/provider/modelmgr/requesters/giteeaichatcmpl.py @@ -5,9 +5,9 @@ import typing from . import chatcmpl from .. import requester -from ....core import entities as core_entities from ... import entities as llm_entities import langbot_plugin.api.entities.builtin.resource.tool as resource_tool +import langbot_plugin.api.entities.builtin.pipeline.query as pipeline_query class GiteeAIChatCompletions(chatcmpl.OpenAIChatCompletions): @@ -20,7 +20,7 @@ class GiteeAIChatCompletions(chatcmpl.OpenAIChatCompletions): async def _closure( self, - query: core_entities.Query, + query: pipeline_query.Query, req_messages: list[dict], use_model: requester.RuntimeLLMModel, use_funcs: list[resource_tool.LLMTool] = None, diff --git a/pkg/provider/modelmgr/requesters/modelscopechatcmpl.py b/pkg/provider/modelmgr/requesters/modelscopechatcmpl.py index e46d102e..68eb7399 100644 --- a/pkg/provider/modelmgr/requesters/modelscopechatcmpl.py +++ b/pkg/provider/modelmgr/requesters/modelscopechatcmpl.py @@ -9,9 +9,9 @@ import openai.types.chat.chat_completion_message_tool_call as chat_completion_me import httpx from .. import entities, errors, requester -from ....core import entities as core_entities from ... import entities as llm_entities import langbot_plugin.api.entities.builtin.resource.tool as resource_tool +import langbot_plugin.api.entities.builtin.pipeline.query as pipeline_query class ModelScopeChatCompletions(requester.LLMAPIRequester): @@ -125,7 +125,7 @@ class ModelScopeChatCompletions(requester.LLMAPIRequester): async def _closure( self, - query: core_entities.Query, + query: pipeline_query.Query, req_messages: list[dict], use_model: requester.RuntimeLLMModel, use_funcs: list[resource_tool.LLMTool] = None, @@ -166,7 +166,7 @@ class ModelScopeChatCompletions(requester.LLMAPIRequester): async def invoke_llm( self, - query: core_entities.Query, + query: pipeline_query.Query, model: entities.LLMModelInfo, messages: typing.List[llm_entities.Message], funcs: typing.List[resource_tool.LLMTool] = None, diff --git a/pkg/provider/modelmgr/requesters/moonshotchatcmpl.py b/pkg/provider/modelmgr/requesters/moonshotchatcmpl.py index e5019426..20c3427c 100644 --- a/pkg/provider/modelmgr/requesters/moonshotchatcmpl.py +++ b/pkg/provider/modelmgr/requesters/moonshotchatcmpl.py @@ -5,9 +5,9 @@ import typing from . import chatcmpl from .. import requester -from ....core import entities as core_entities from ... import entities as llm_entities import langbot_plugin.api.entities.builtin.resource.tool as resource_tool +import langbot_plugin.api.entities.builtin.pipeline.query as pipeline_query class MoonshotChatCompletions(chatcmpl.OpenAIChatCompletions): @@ -20,7 +20,7 @@ class MoonshotChatCompletions(chatcmpl.OpenAIChatCompletions): async def _closure( self, - query: core_entities.Query, + query: pipeline_query.Query, req_messages: list[dict], use_model: requester.RuntimeLLMModel, use_funcs: list[resource_tool.LLMTool] = None, diff --git a/pkg/provider/modelmgr/requesters/ollamachat.py b/pkg/provider/modelmgr/requesters/ollamachat.py index 2afe34b3..b22895a6 100644 --- a/pkg/provider/modelmgr/requesters/ollamachat.py +++ b/pkg/provider/modelmgr/requesters/ollamachat.py @@ -12,7 +12,7 @@ import ollama from .. import errors, requester from ... import entities as llm_entities import langbot_plugin.api.entities.builtin.resource.tool as resource_tool -from ....core import entities as core_entities +import langbot_plugin.api.entities.builtin.pipeline.query as pipeline_query REQUESTER_NAME: str = 'ollama-chat' @@ -39,7 +39,7 @@ class OllamaChatCompletions(requester.LLMAPIRequester): async def _closure( self, - query: core_entities.Query, + query: pipeline_query.Query, req_messages: list[dict], use_model: requester.RuntimeLLMModel, use_funcs: list[resource_tool.LLMTool] = None, @@ -105,7 +105,7 @@ class OllamaChatCompletions(requester.LLMAPIRequester): async def invoke_llm( self, - query: core_entities.Query, + query: pipeline_query.Query, model: requester.RuntimeLLMModel, messages: typing.List[llm_entities.Message], funcs: typing.List[resource_tool.LLMTool] = None, diff --git a/pkg/provider/runner.py b/pkg/provider/runner.py index a74a2dc5..42f702f8 100644 --- a/pkg/provider/runner.py +++ b/pkg/provider/runner.py @@ -3,8 +3,9 @@ from __future__ import annotations import abc import typing -from ..core import app, entities as core_entities +from ..core import app from . import entities as llm_entities +import langbot_plugin.api.entities.builtin.pipeline.query as pipeline_query preregistered_runners: list[typing.Type[RequestRunner]] = [] @@ -35,6 +36,6 @@ class RequestRunner(abc.ABC): self.pipeline_config = pipeline_config @abc.abstractmethod - async def run(self, query: core_entities.Query) -> typing.AsyncGenerator[llm_entities.Message, None]: + async def run(self, query: pipeline_query.Query) -> typing.AsyncGenerator[llm_entities.Message, None]: """运行请求""" pass diff --git a/pkg/provider/runners/dashscopeapi.py b/pkg/provider/runners/dashscopeapi.py index 02cb0b51..7c71d6b3 100644 --- a/pkg/provider/runners/dashscopeapi.py +++ b/pkg/provider/runners/dashscopeapi.py @@ -6,8 +6,9 @@ import re import dashscope from .. import runner -from ...core import app, entities as core_entities +from ...core import app from .. import entities as llm_entities +import langbot_plugin.api.entities.builtin.pipeline.query as pipeline_query class DashscopeAPIError(Exception): @@ -65,7 +66,7 @@ class DashScopeAPIRunner(runner.RequestRunner): # 使用 re.sub() 进行替换 return pattern.sub(replacement, text) - async def _preprocess_user_message(self, query: core_entities.Query) -> tuple[str, list[str]]: + async def _preprocess_user_message(self, query: pipeline_query.Query) -> tuple[str, list[str]]: """预处理用户消息,提取纯文本,阿里云提供的上传文件方法过于复杂,暂不支持上传文件(包括图片)""" plain_text = '' image_ids = [] @@ -89,7 +90,7 @@ class DashScopeAPIRunner(runner.RequestRunner): return plain_text, image_ids - async def _agent_messages(self, query: core_entities.Query) -> typing.AsyncGenerator[llm_entities.Message, None]: + async def _agent_messages(self, query: pipeline_query.Query) -> typing.AsyncGenerator[llm_entities.Message, None]: """Dashscope 智能体对话请求""" # 局部变量 @@ -147,7 +148,9 @@ class DashScopeAPIRunner(runner.RequestRunner): content=pending_content, ) - async def _workflow_messages(self, query: core_entities.Query) -> typing.AsyncGenerator[llm_entities.Message, None]: + async def _workflow_messages( + self, query: pipeline_query.Query + ) -> typing.AsyncGenerator[llm_entities.Message, None]: """Dashscope 工作流对话请求""" # 局部变量 @@ -210,7 +213,7 @@ class DashScopeAPIRunner(runner.RequestRunner): content=pending_content, ) - async def run(self, query: core_entities.Query) -> typing.AsyncGenerator[llm_entities.Message, None]: + async def run(self, query: pipeline_query.Query) -> typing.AsyncGenerator[llm_entities.Message, None]: """运行""" if self.app_type == 'agent': async for msg in self._agent_messages(query): diff --git a/pkg/provider/runners/difysvapi.py b/pkg/provider/runners/difysvapi.py index b2542491..c5819de3 100644 --- a/pkg/provider/runners/difysvapi.py +++ b/pkg/provider/runners/difysvapi.py @@ -8,10 +8,10 @@ import base64 from .. import runner -from ...core import app, entities as core_entities +from ...core import app from .. import entities as llm_entities from ...utils import image - +import langbot_plugin.api.entities.builtin.pipeline.query as pipeline_query from libs.dify_service_api.v1 import client, errors @@ -62,7 +62,7 @@ class DifyServiceAPIRunner(runner.RequestRunner): content_text = re.sub(pattern, '', resp_text, flags=re.DOTALL) return f'{thinking_text.group(1)}\n{content_text}' - async def _preprocess_user_message(self, query: core_entities.Query) -> tuple[str, list[str]]: + async def _preprocess_user_message(self, query: pipeline_query.Query) -> tuple[str, list[str]]: """预处理用户消息,提取纯文本,并将图片上传到 Dify 服务 Returns: @@ -90,7 +90,7 @@ class DifyServiceAPIRunner(runner.RequestRunner): return plain_text, image_ids - async def _chat_messages(self, query: core_entities.Query) -> typing.AsyncGenerator[llm_entities.Message, None]: + async def _chat_messages(self, query: pipeline_query.Query) -> typing.AsyncGenerator[llm_entities.Message, None]: """调用聊天助手""" cov_id = query.session.using_conversation.uuid or '' query.variables['conversation_id'] = cov_id @@ -152,7 +152,7 @@ class DifyServiceAPIRunner(runner.RequestRunner): query.session.using_conversation.uuid = chunk['conversation_id'] async def _agent_chat_messages( - self, query: core_entities.Query + self, query: pipeline_query.Query ) -> typing.AsyncGenerator[llm_entities.Message, None]: """调用聊天助手""" cov_id = query.session.using_conversation.uuid or '' @@ -244,7 +244,9 @@ class DifyServiceAPIRunner(runner.RequestRunner): query.session.using_conversation.uuid = chunk['conversation_id'] - async def _workflow_messages(self, query: core_entities.Query) -> typing.AsyncGenerator[llm_entities.Message, None]: + async def _workflow_messages( + self, query: pipeline_query.Query + ) -> typing.AsyncGenerator[llm_entities.Message, None]: """调用工作流""" if not query.session.using_conversation.uuid: @@ -316,7 +318,7 @@ class DifyServiceAPIRunner(runner.RequestRunner): yield msg - async def run(self, query: core_entities.Query) -> typing.AsyncGenerator[llm_entities.Message, None]: + async def run(self, query: pipeline_query.Query) -> typing.AsyncGenerator[llm_entities.Message, None]: """运行请求""" if self.pipeline_config['ai']['dify-service-api']['app-type'] == 'chat': async for msg in self._chat_messages(query): diff --git a/pkg/provider/runners/localagent.py b/pkg/provider/runners/localagent.py index e87ee81d..5a879bcb 100644 --- a/pkg/provider/runners/localagent.py +++ b/pkg/provider/runners/localagent.py @@ -4,15 +4,15 @@ import json import typing from .. import runner -from ...core import entities as core_entities from .. import entities as llm_entities +import langbot_plugin.api.entities.builtin.pipeline.query as pipeline_query @runner.runner_class('local-agent') class LocalAgentRunner(runner.RequestRunner): """本地Agent请求运行器""" - async def run(self, query: core_entities.Query) -> typing.AsyncGenerator[llm_entities.Message, None]: + async def run(self, query: pipeline_query.Query) -> typing.AsyncGenerator[llm_entities.Message, None]: """运行请求""" pending_tool_calls = [] diff --git a/pkg/provider/runners/n8nsvapi.py b/pkg/provider/runners/n8nsvapi.py index 7044cce1..37567d15 100644 --- a/pkg/provider/runners/n8nsvapi.py +++ b/pkg/provider/runners/n8nsvapi.py @@ -6,8 +6,9 @@ import uuid import aiohttp from .. import runner -from ...core import app, entities as core_entities +from ...core import app from .. import entities as llm_entities +import langbot_plugin.api.entities.builtin.pipeline.query as pipeline_query class N8nAPIError(Exception): @@ -49,7 +50,7 @@ class N8nServiceAPIRunner(runner.RequestRunner): self.header_name = self.pipeline_config['ai']['n8n-service-api'].get('header-name', '') self.header_value = self.pipeline_config['ai']['n8n-service-api'].get('header-value', '') - async def _preprocess_user_message(self, query: core_entities.Query) -> str: + async def _preprocess_user_message(self, query: pipeline_query.Query) -> str: """预处理用户消息,提取纯文本 Returns: @@ -67,7 +68,7 @@ class N8nServiceAPIRunner(runner.RequestRunner): return plain_text - async def _call_webhook(self, query: core_entities.Query) -> typing.AsyncGenerator[llm_entities.Message, None]: + async def _call_webhook(self, query: pipeline_query.Query) -> typing.AsyncGenerator[llm_entities.Message, None]: """调用n8n webhook""" # 生成会话ID(如果不存在) if not query.session.using_conversation.uuid: @@ -153,7 +154,7 @@ class N8nServiceAPIRunner(runner.RequestRunner): self.ap.logger.error(f'n8n webhook call exception: {str(e)}') raise N8nAPIError(f'n8n webhook call exception: {str(e)}') - async def run(self, query: core_entities.Query) -> typing.AsyncGenerator[llm_entities.Message, None]: + async def run(self, query: pipeline_query.Query) -> typing.AsyncGenerator[llm_entities.Message, None]: """运行请求""" async for msg in self._call_webhook(query): yield msg diff --git a/pkg/provider/session/sessionmgr.py b/pkg/provider/session/sessionmgr.py index 500ab49c..03465e0b 100644 --- a/pkg/provider/session/sessionmgr.py +++ b/pkg/provider/session/sessionmgr.py @@ -2,9 +2,10 @@ from __future__ import annotations import asyncio -from ...core import app, entities as core_entities +from ...core import app from langbot_plugin.api.entities.builtin.provider import message as provider_message, prompt as provider_prompt import langbot_plugin.api.entities.builtin.provider.session as provider_session +import langbot_plugin.api.entities.builtin.pipeline.query as pipeline_query class SessionManager: @@ -21,7 +22,7 @@ class SessionManager: async def initialize(self): pass - async def get_session(self, query: core_entities.Query) -> provider_session.Session: + async def get_session(self, query: pipeline_query.Query) -> provider_session.Session: """获取会话""" for session in self.session_list: if query.launcher_type == session.launcher_type and query.launcher_id == session.launcher_id: @@ -39,7 +40,7 @@ class SessionManager: async def get_conversation( self, - query: core_entities.Query, + query: pipeline_query.Query, session: provider_session.Session, prompt_config: list[dict], pipeline_uuid: str, diff --git a/pkg/provider/tools/loader.py b/pkg/provider/tools/loader.py index fca9aa93..658fdeb6 100644 --- a/pkg/provider/tools/loader.py +++ b/pkg/provider/tools/loader.py @@ -3,8 +3,9 @@ from __future__ import annotations import abc import typing -from ...core import app, entities as core_entities +from ...core import app import langbot_plugin.api.entities.builtin.resource.tool as resource_tool +import langbot_plugin.api.entities.builtin.pipeline.query as pipeline_query preregistered_loaders: list[typing.Type[ToolLoader]] = [] @@ -45,7 +46,7 @@ class ToolLoader(abc.ABC): pass @abc.abstractmethod - async def invoke_tool(self, query: core_entities.Query, name: str, parameters: dict) -> typing.Any: + async def invoke_tool(self, query: pipeline_query.Query, name: str, parameters: dict) -> typing.Any: """执行工具调用""" pass diff --git a/pkg/provider/tools/loaders/mcp.py b/pkg/provider/tools/loaders/mcp.py index bf35990e..577c704e 100644 --- a/pkg/provider/tools/loaders/mcp.py +++ b/pkg/provider/tools/loaders/mcp.py @@ -8,8 +8,9 @@ from mcp.client.stdio import stdio_client from mcp.client.sse import sse_client from .. import loader -from ....core import app, entities as core_entities +from ....core import app import langbot_plugin.api.entities.builtin.resource.tool as resource_tool +import langbot_plugin.api.entities.builtin.pipeline.query as pipeline_query class RuntimeMCPSession: @@ -83,7 +84,7 @@ class RuntimeMCPSession: for tool in tools.tools: - async def func(query: core_entities.Query, *, _tool=tool, **kwargs): + async def func(query: pipeline_query.Query, *, _tool=tool, **kwargs): result = await self.session.call_tool(_tool.name, kwargs) if result.isError: raise Exception(result.content[0].text) @@ -144,7 +145,7 @@ class MCPLoader(loader.ToolLoader): async def has_tool(self, name: str) -> bool: return name in [f.name for f in self._last_listed_functions] - async def invoke_tool(self, query: core_entities.Query, name: str, parameters: dict) -> typing.Any: + async def invoke_tool(self, query: pipeline_query.Query, name: str, parameters: dict) -> typing.Any: for server_name, session in self.sessions.items(): for function in session.functions: if function.name == name: diff --git a/pkg/provider/tools/loaders/plugin.py b/pkg/provider/tools/loaders/plugin.py index c6ecda7d..7dfaea97 100644 --- a/pkg/provider/tools/loaders/plugin.py +++ b/pkg/provider/tools/loaders/plugin.py @@ -4,9 +4,9 @@ import typing import traceback from .. import loader -from ....core import entities as core_entities from ....plugin import context as plugin_context import langbot_plugin.api.entities.builtin.resource.tool as resource_tool +import langbot_plugin.api.entities.builtin.pipeline.query as pipeline_query @loader.loader_class('plugin-tool-loader') @@ -49,7 +49,7 @@ class PluginToolLoader(loader.ToolLoader): return function, plugin.plugin_inst return None, None - async def invoke_tool(self, query: core_entities.Query, name: str, parameters: dict) -> typing.Any: + async def invoke_tool(self, query: pipeline_query.Query, name: str, parameters: dict) -> typing.Any: try: function, plugin = await self._get_function_and_plugin(name) if function is None: diff --git a/pkg/provider/tools/toolmgr.py b/pkg/provider/tools/toolmgr.py index 5f0cbdbf..e1105750 100644 --- a/pkg/provider/tools/toolmgr.py +++ b/pkg/provider/tools/toolmgr.py @@ -2,11 +2,12 @@ from __future__ import annotations import typing -from ...core import app, entities as core_entities +from ...core import app from . import loader as tools_loader from ...utils import importutil from . import loaders import langbot_plugin.api.entities.builtin.resource.tool as resource_tool +import langbot_plugin.api.entities.builtin.pipeline.query as pipeline_query importutil.import_modules_in_pkg(loaders) @@ -90,7 +91,7 @@ class ToolManager: return tools - async def execute_func_call(self, query: core_entities.Query, name: str, parameters: dict) -> typing.Any: + async def execute_func_call(self, query: pipeline_query.Query, name: str, parameters: dict) -> typing.Any: """执行函数调用""" for loader in self.loaders: diff --git a/pkg/utils/announce.py b/pkg/utils/announce.py index 7108a08c..56de579d 100644 --- a/pkg/utils/announce.py +++ b/pkg/utils/announce.py @@ -6,7 +6,7 @@ import os import base64 import logging -import pydantic.v1 as pydantic +import pydantic import requests from ..core import app