From 9f15ab50009df2a87ea87122f52d9fbc8b6c1b1c Mon Sep 17 00:00:00 2001 From: Junyan Qin Date: Sat, 29 Mar 2025 17:50:45 +0800 Subject: [PATCH] feat: preliminarily implement pipeline invoking --- pkg/api/http/service/pipeline.py | 19 +++- pkg/core/app.py | 13 +-- pkg/core/entities.py | 16 ++- pkg/core/stages/build_app.py | 16 +-- pkg/core/taskmgr.py | 7 ++ pkg/pipeline/bansess/bansess.py | 4 +- pkg/pipeline/cntfilter/cntfilter.py | 4 +- pkg/pipeline/controller.py | 10 +- pkg/pipeline/longtext/longtext.py | 12 +- pkg/pipeline/longtext/strategies/image.py | 16 ++- pkg/pipeline/msgtrun/msgtrun.py | 4 +- pkg/pipeline/pipelinemgr.py | 54 +++++++-- pkg/pipeline/pool.py | 4 +- pkg/pipeline/preproc/preproc.py | 15 +-- pkg/pipeline/process/handlers/chat.py | 21 ++-- pkg/pipeline/process/process.py | 4 +- pkg/pipeline/ratelimit/ratelimit.py | 4 +- pkg/pipeline/respback/respback.py | 26 +++-- pkg/pipeline/resprule/resprule.py | 4 +- pkg/pipeline/stage.py | 2 +- pkg/pipeline/stagemgr.py | 71 ------------ pkg/pipeline/wrapper/wrapper.py | 6 +- pkg/platform/manager.py | 74 +++++------- pkg/platform/sources/aiocqhttp.py | 2 + pkg/plugin/context.py | 8 +- pkg/provider/modelmgr/modelmgr.py | 34 ++---- pkg/provider/modelmgr/requester.py | 37 +++--- .../modelmgr/requesters/anthropicmsgs.py | 15 +-- .../modelmgr/requesters/anthropicmsgs.yaml | 2 +- .../modelmgr/requesters/bailianchatcmpl.py | 2 +- .../modelmgr/requesters/bailianchatcmpl.yaml | 2 +- pkg/provider/modelmgr/requesters/chatcmpl.py | 16 ++- .../modelmgr/requesters/chatcmpl.yaml | 2 +- .../modelmgr/requesters/deepseekchatcmpl.py | 8 +- .../modelmgr/requesters/deepseekchatcmpl.yaml | 2 +- .../modelmgr/requesters/giteeaichatcmpl.py | 8 +- .../modelmgr/requesters/giteeaichatcmpl.yaml | 2 +- .../modelmgr/requesters/lmstudiochatcmpl.py | 2 +- .../modelmgr/requesters/lmstudiochatcmpl.yaml | 2 +- .../modelmgr/requesters/moonshotchatcmpl.py | 8 +- .../modelmgr/requesters/moonshotchatcmpl.yaml | 2 +- .../modelmgr/requesters/ollamachat.py | 107 ++++++++++-------- .../modelmgr/requesters/ollamachat.yaml | 2 +- .../requesters/siliconflowchatcmpl.py | 2 +- .../requesters/siliconflowchatcmpl.yaml | 2 +- .../modelmgr/requesters/volcarkchatcmpl.py | 2 +- .../modelmgr/requesters/volcarkchatcmpl.yaml | 2 +- .../modelmgr/requesters/xaichatcmpl.py | 2 +- .../modelmgr/requesters/xaichatcmpl.yaml | 2 +- .../modelmgr/requesters/zhipuaichatcmpl.py | 2 +- .../modelmgr/requesters/zhipuaichatcmpl.yaml | 2 +- pkg/provider/runner.py | 8 +- pkg/provider/runnermgr.py | 30 ----- pkg/provider/runners/dashscopeapi.py | 26 ++--- pkg/provider/runners/difysvapi.py | 44 ++++--- pkg/provider/runners/localagent.py | 6 +- pkg/provider/session/sessionmgr.py | 6 +- 57 files changed, 384 insertions(+), 421 deletions(-) delete mode 100644 pkg/pipeline/stagemgr.py delete mode 100644 pkg/provider/runnermgr.py diff --git a/pkg/api/http/service/pipeline.py b/pkg/api/http/service/pipeline.py index 7920c4c9..f1bcaa75 100644 --- a/pkg/api/http/service/pipeline.py +++ b/pkg/api/http/service/pipeline.py @@ -5,10 +5,25 @@ import datetime import sqlalchemy from ....core import app -from ....pipeline import stagemgr from ....entity.persistence import pipeline as persistence_pipeline +default_stage_order = [ + "GroupRespondRuleCheckStage", # 群响应规则检查 + "BanSessionCheckStage", # 封禁会话检查 + "PreContentFilterStage", # 内容过滤前置阶段 + "PreProcessor", # 预处理器 + "ConversationMessageTruncator", # 会话消息截断器 + "RequireRateLimitOccupancy", # 请求速率限制占用 + "MessageProcessor", # 处理器 + "ReleaseRateLimitOccupancy", # 释放速率限制占用 + "PostContentFilterStage", # 内容过滤后置阶段 + "ResponseWrapper", # 响应包装器 + "LongTextProcessStage", # 长文本处理 + "SendResponseBackStage", # 发送响应 +] + + class PipelineService: ap: app.Application @@ -49,7 +64,7 @@ class PipelineService: async def create_pipeline(self, pipeline_data: dict) -> str: pipeline_data['uuid'] = str(uuid.uuid4()) pipeline_data['for_version'] = self.ap.ver_mgr.get_current_version() - pipeline_data['stages'] = stagemgr.stage_order.copy() + pipeline_data['stages'] = default_stage_order.copy() # TODO: 检查pipeline config是否完整 diff --git a/pkg/core/app.py b/pkg/core/app.py index 0191cc02..126b165a 100644 --- a/pkg/core/app.py +++ b/pkg/core/app.py @@ -13,14 +13,13 @@ from ..provider.session import sessionmgr as llm_session_mgr from ..provider.modelmgr import modelmgr as llm_model_mgr from ..provider.sysprompt import sysprompt as llm_prompt_mgr from ..provider.tools import toolmgr as llm_tool_mgr -from ..provider import runnermgr from ..config import manager as config_mgr from ..config import settings as settings_mgr from ..audit.center import v2 as center_mgr from ..command import cmdmgr from ..plugin import manager as plugin_mgr from ..pipeline import pool -from ..pipeline import controller, stagemgr, pipelinemgr +from ..pipeline import controller, pipelinemgr from ..utils import version as version_mgr, proxy as proxy_mgr, announce as announce_mgr from ..persistence import mgr as persistencemgr from ..api.http.controller import main as http_controller @@ -53,12 +52,12 @@ class Application: model_mgr: llm_model_mgr.ModelManager = None + # TODO 移动到 pipeline 里 prompt_mgr: llm_prompt_mgr.PromptManager = None + # TODO 移动到 pipeline 里 tool_mgr: llm_tool_mgr.ToolManager = None - runner_mgr: runnermgr.RunnerManager = None - settings_mgr: settings_mgr.SettingsManager = None # ======= 配置管理器 ======= @@ -100,8 +99,6 @@ class Application: ctrl: controller.Controller = None - stage_mgr: stagemgr.StageManager = None - pipeline_mgr: pipelinemgr.PipelineManager = None ver_mgr: version_mgr.VersionManager = None @@ -239,9 +236,5 @@ class Application: llm_tool_mgr_inst = llm_tool_mgr.ToolManager(self) await llm_tool_mgr_inst.initialize() self.tool_mgr = llm_tool_mgr_inst - - runner_mgr_inst = runnermgr.RunnerManager(self) - await runner_mgr_inst.initialize() - self.runner_mgr = runner_mgr_inst case _: pass \ No newline at end of file diff --git a/pkg/core/entities.py b/pkg/core/entities.py index a34eb082..d8768494 100644 --- a/pkg/core/entities.py +++ b/pkg/core/entities.py @@ -8,7 +8,7 @@ import asyncio import pydantic.v1 as pydantic from ..provider import entities as llm_entities -from ..provider.modelmgr import entities +from ..provider.modelmgr import entities, modelmgr, requester from ..provider.sysprompt import entities as sysprompt_entities from ..provider.tools import entities as tools_entities from ..platform import adapter as msadapter @@ -57,6 +57,9 @@ class Query(pydantic.BaseModel): message_chain: platform_message.MessageChain """消息链,platform收到的原始消息链""" + bot_uuid: typing.Optional[str] = None + """机器人UUID。""" + pipeline_uuid: typing.Optional[str] = None """流水线UUID。""" @@ -81,8 +84,8 @@ class Query(pydantic.BaseModel): variables: typing.Optional[dict[str, typing.Any]] = None """变量,由前置处理器阶段设置。在prompt中嵌入或由 Runner 传递到 LLMOps 平台。""" - use_model: typing.Optional[entities.LLMModelInfo] = None - """使用的模型,由前置处理器阶段设置""" + use_llm_model: typing.Optional[requester.RuntimeLLMModel] = None + """使用的对话模型,由前置处理器阶段设置""" use_funcs: typing.Optional[list[tools_entities.LLMFunction]] = None """使用的函数,由前置处理器阶段设置""" @@ -94,7 +97,7 @@ class Query(pydantic.BaseModel): """回复消息链,从resp_messages包装而得""" # ======= 内部保留 ======= - current_stage: "pkg.pipeline.stagemgr.StageInstContainer" = None + current_stage: "pkg.pipeline.pipelinemgr.StageInstContainer" = None """当前所处阶段""" class Config: @@ -132,13 +135,16 @@ class Conversation(pydantic.BaseModel): update_time: typing.Optional[datetime.datetime] = pydantic.Field(default_factory=datetime.datetime.now) - use_model: entities.LLMModelInfo + use_llm_model: requester.RuntimeLLMModel use_funcs: typing.Optional[list[tools_entities.LLMFunction]] uuid: typing.Optional[str] = None """该对话的 uuid,在创建时不会自动生成。而是当使用 Dify API 等由外部管理对话信息的服务时,用于绑定外部的会话。具体如何使用,取决于 Runner。""" + class Config: + arbitrary_types_allowed = True + class Session(pydantic.BaseModel): """会话,一个 Session 对应一个 {launcher_type.value}_{launcher_id}""" diff --git a/pkg/core/stages/build_app.py b/pkg/core/stages/build_app.py index 0bd0d8a5..fcd930a3 100644 --- a/pkg/core/stages/build_app.py +++ b/pkg/core/stages/build_app.py @@ -6,14 +6,13 @@ from .. import stage, app from ...utils import version, proxy, announce, platform from ...audit.center import v2 as center_v2 from ...audit import identifier -from ...pipeline import pool, controller, stagemgr, pipelinemgr +from ...pipeline import pool, controller, pipelinemgr from ...plugin import manager as plugin_mgr from ...command import cmdmgr from ...provider.session import sessionmgr as llm_session_mgr from ...provider.modelmgr import modelmgr as llm_model_mgr from ...provider.sysprompt import sysprompt as llm_prompt_mgr from ...provider.tools import toolmgr as llm_tool_mgr -from ...provider import runnermgr from ...platform import manager as im_mgr from ...persistence import mgr as persistencemgr from ...api.http.controller import main as http_controller @@ -61,10 +60,7 @@ class BuildAppStage(stage.BootingStage): }, runtime_info={ "admin_id": "{}".format(ap.system_cfg.data["admin-sessions"]), - "msg_source": str([ - adapter_cfg['adapter'] if 'adapter' in adapter_cfg else 'unknown' - for adapter_cfg in ap.platform_cfg.data['platform-adapters'] if adapter_cfg['enable'] - ]), + "msg_source": str([]), }, ) ap.ctr_mgr = center_v2_api @@ -107,18 +103,10 @@ class BuildAppStage(stage.BootingStage): await llm_tool_mgr_inst.initialize() ap.tool_mgr = llm_tool_mgr_inst - runner_mgr_inst = runnermgr.RunnerManager(ap) - await runner_mgr_inst.initialize() - ap.runner_mgr = runner_mgr_inst - im_mgr_inst = im_mgr.PlatformManager(ap=ap) await im_mgr_inst.initialize() ap.platform_mgr = im_mgr_inst - stage_mgr = stagemgr.StageManager(ap) - await stage_mgr.initialize() - ap.stage_mgr = stage_mgr - pipeline_mgr = pipelinemgr.PipelineManager(ap) await pipeline_mgr.initialize() ap.pipeline_mgr = pipeline_mgr diff --git a/pkg/core/taskmgr.py b/pkg/core/taskmgr.py index 2c029c03..d5887019 100644 --- a/pkg/core/taskmgr.py +++ b/pkg/core/taskmgr.py @@ -233,3 +233,10 @@ class AsyncTaskManager: if not wrapper.task.done() and scope in wrapper.scopes: wrapper.task.cancel() + + def cancel_task(self, task_id: int): + for wrapper in self.tasks: + if wrapper.id == task_id: + if not wrapper.task.done(): + wrapper.task.cancel() + return diff --git a/pkg/pipeline/bansess/bansess.py b/pkg/pipeline/bansess/bansess.py index 9c041385..1ca42397 100644 --- a/pkg/pipeline/bansess/bansess.py +++ b/pkg/pipeline/bansess/bansess.py @@ -1,7 +1,7 @@ from __future__ import annotations import re -from .. import stage, entities, stagemgr +from .. import stage, entities from ...core import entities as core_entities from ...config import manager as cfg_mgr @@ -13,7 +13,7 @@ class BanSessionCheckStage(stage.PipelineStage): 仅检查query中群号或个人号是否在访问控制列表中。 """ - async def initialize(self): + async def initialize(self, pipeline_config: dict): pass async def process( diff --git a/pkg/pipeline/cntfilter/cntfilter.py b/pkg/pipeline/cntfilter/cntfilter.py index f7376b61..6a0c3776 100644 --- a/pkg/pipeline/cntfilter/cntfilter.py +++ b/pkg/pipeline/cntfilter/cntfilter.py @@ -2,7 +2,7 @@ from __future__ import annotations from ...core import app -from .. import stage, entities, stagemgr +from .. import stage, entities from ...core import entities as core_entities from ...config import manager as cfg_mgr from . import filter as filter_model, entities as filter_entities @@ -35,7 +35,7 @@ class ContentFilterStage(stage.PipelineStage): self.filter_chain = [] super().__init__(ap) - async def initialize(self): + async def initialize(self, pipeline_config: dict): filters_required = [ "content-ignore", diff --git a/pkg/pipeline/controller.py b/pkg/pipeline/controller.py index 5d66f49e..64d4e8f4 100644 --- a/pkg/pipeline/controller.py +++ b/pkg/pipeline/controller.py @@ -54,9 +54,13 @@ class Controller: async def _process_query(selected_query: entities.Query): async with self.semaphore: # 总并发上限 # find pipeline - pipeline = await self.ap.pipeline_mgr.get_pipeline_by_uuid(selected_query.pipeline_uuid) - if pipeline: - await pipeline.run(selected_query) + # Here firstly find the bot, then find the pipeline, in case the bot adapter's config is not the latest one. + # Like aiocqhttp, once a client is connected, even the adapter was updated and restarted, the existing client connection will not be affected. + bot = await self.ap.platform_mgr.get_bot_by_uuid(selected_query.bot_uuid) + if bot: + pipeline = await self.ap.pipeline_mgr.get_pipeline_by_uuid(bot.bot_entity.use_pipeline_uuid) + if pipeline: + await pipeline.run(selected_query) async with self.ap.query_pool: (await self.ap.sess_mgr.get_session(selected_query)).semaphore.release() diff --git a/pkg/pipeline/longtext/longtext.py b/pkg/pipeline/longtext/longtext.py index ecb745d0..ac03ad42 100644 --- a/pkg/pipeline/longtext/longtext.py +++ b/pkg/pipeline/longtext/longtext.py @@ -7,7 +7,7 @@ from PIL import Image, ImageDraw, ImageFont from ...core import app from . import strategy from .strategies import image, forward -from .. import stage, entities, stagemgr +from .. import stage, entities from ...core import entities as core_entities from ...config import manager as cfg_mgr from ...platform.types import message as platform_message @@ -23,8 +23,8 @@ class LongTextProcessStage(stage.PipelineStage): strategy_impl: strategy.LongTextStrategy - async def initialize(self): - config = self.ap.platform_cfg.data['long-text-process'] + async def initialize(self, pipeline_config: dict): + config = pipeline_config['output']['long-text-processing'] if config['strategy'] == 'image': use_font = config['font-path'] try: @@ -42,12 +42,12 @@ class LongTextProcessStage(stage.PipelineStage): else: self.ap.logger.warn("未找到字体文件,且无法使用系统自带字体,更换为转发消息组件以发送长消息,您可以在配置文件中调整相关设置。") - self.ap.platform_cfg.data['long-text-process']['strategy'] = "forward" + pipeline_config['output']['long-text-processing']['strategy'] = "forward" except: traceback.print_exc() self.ap.logger.error("加载字体文件失败({}),更换为转发消息组件以发送长消息,您可以在配置文件中调整相关设置。".format(use_font)) - self.ap.platform_cfg.data['long-text-process']['strategy'] = "forward" + pipeline_config['output']['long-text-processing']['strategy'] = "forward" for strategy_cls in strategy.preregistered_strategies: if strategy_cls.name == config['strategy']: @@ -69,7 +69,7 @@ class LongTextProcessStage(stage.PipelineStage): if contains_non_plain: self.ap.logger.debug("消息中包含非 Plain 组件,跳过长消息处理。") - elif len(str(query.resp_message_chain[-1])) > self.ap.platform_cfg.data['long-text-process']['threshold']: + elif len(str(query.resp_message_chain[-1])) > query.pipeline_config['output']['long-text-processing']['threshold']: query.resp_message_chain[-1] = platform_message.MessageChain(await self.strategy_impl.process(str(query.resp_message_chain[-1]), query)) return entities.StageProcessResult( diff --git a/pkg/pipeline/longtext/strategies/image.py b/pkg/pipeline/longtext/strategies/image.py index b9675074..b30d3a81 100644 --- a/pkg/pipeline/longtext/strategies/image.py +++ b/pkg/pipeline/longtext/strategies/image.py @@ -8,6 +8,7 @@ import re from PIL import Image, ImageDraw, ImageFont +import functools from ....platform.types import message as platform_message from .. import strategy as strategy_model @@ -17,15 +18,18 @@ from ....core import entities as core_entities @strategy_model.strategy_class("image") class Text2ImageStrategy(strategy_model.LongTextStrategy): - text_render_font: ImageFont.FreeTypeFont - async def initialize(self): - self.text_render_font = ImageFont.truetype(self.ap.platform_cfg.data['long-text-process']['font-path'], 32, encoding="utf-8") + pass + + @functools.lru_cache(maxsize=16) + def get_font(self, query: core_entities.Query): + return ImageFont.truetype(query.pipeline_config['output']['long-text-processing']['font-path'], 32, encoding="utf-8") async def process(self, message: str, query: core_entities.Query) -> list[platform_message.MessageComponent]: img_path = self.text_to_image( text_str=message, - save_as='temp/{}.png'.format(int(time.time())) + save_as='temp/{}.png'.format(int(time.time())), + query=query ) compressed_path, size = self.compress_image( @@ -127,7 +131,7 @@ class Text2ImageStrategy(strategy_model.LongTextStrategy): return outfile, self.get_size(outfile) - def text_to_image(self, text_str: str, save_as="temp.png", width=800): + def text_to_image(self, text_str: str, save_as="temp.png", width=800, query: core_entities.Query = None): text_str = text_str.replace("\t", " ") @@ -142,7 +146,7 @@ class Text2ImageStrategy(strategy_model.LongTextStrategy): self.ap.logger.debug("lines: {}, text_width: {}".format(lines, text_width)) for line in lines: # 如果长了就分割 - line_width = self.text_render_font.getlength(line) + line_width = self.get_font(query).getlength(line) self.ap.logger.debug("line_width: {}".format(line_width)) if line_width < text_width: final_lines.append(line) diff --git a/pkg/pipeline/msgtrun/msgtrun.py b/pkg/pipeline/msgtrun/msgtrun.py index e56c551f..a1116eb4 100644 --- a/pkg/pipeline/msgtrun/msgtrun.py +++ b/pkg/pipeline/msgtrun/msgtrun.py @@ -1,6 +1,6 @@ from __future__ import annotations -from .. import stage, entities, stagemgr +from .. import stage, entities from ...core import entities as core_entities from . import truncator from .truncators import round @@ -14,7 +14,7 @@ class ConversationMessageTruncator(stage.PipelineStage): """ trun: truncator.Truncator - async def initialize(self): + async def initialize(self, pipeline_config: dict): use_method = self.ap.pipeline_cfg.data['msg-truncate']['method'] for trun in truncator.preregistered_truncators: diff --git a/pkg/pipeline/pipelinemgr.py b/pkg/pipeline/pipelinemgr.py index 9f41ba2d..b7eaaab4 100644 --- a/pkg/pipeline/pipelinemgr.py +++ b/pkg/pipeline/pipelinemgr.py @@ -8,10 +8,35 @@ import sqlalchemy from ..core import app, entities from . import entities as pipeline_entities from ..entity.persistence import pipeline as persistence_pipeline -from . import stagemgr, stage +from . import stage from ..platform.types import message as platform_message, events as platform_events from ..plugin import events +from .resprule import resprule +from .bansess import bansess +from .cntfilter import cntfilter +from .process import process +from .longtext import longtext +from .respback import respback +from .wrapper import wrapper +from .preproc import preproc +from .ratelimit import ratelimit +from .msgtrun import msgtrun + + +class StageInstContainer(): + """阶段实例容器 + """ + + inst_name: str + + inst: stage.PipelineStage + + def __init__(self, inst_name: str, inst: stage.PipelineStage): + self.inst_name = inst_name + self.inst = inst + + class RuntimePipeline: """运行时流水线""" @@ -20,10 +45,10 @@ class RuntimePipeline: pipeline_entity: persistence_pipeline.LegacyPipeline """流水线实体""" - stage_containers: list[stagemgr.StageInstContainer] + stage_containers: list[StageInstContainer] """阶段实例容器""" - def __init__(self, ap: app.Application, pipeline_entity: persistence_pipeline.LegacyPipeline, stage_containers: list[stagemgr.StageInstContainer]): + def __init__(self, ap: app.Application, pipeline_entity: persistence_pipeline.LegacyPipeline, stage_containers: list[StageInstContainer]): self.ap = ap self.pipeline_entity = pipeline_entity self.stage_containers = stage_containers @@ -47,10 +72,18 @@ class RuntimePipeline: *result.user_notice ) - await self.ap.platform_mgr.send( - query.message_event, - result.user_notice, - query.adapter + if query.pipeline_config['output']['misc']['at-sender'] and isinstance(query.message_event, platform_events.GroupMessage): + result.user_notice.insert( + 0, + platform_message.At( + query.message_event.sender.id + ) + ) + + await query.adapter.reply_message( + message_source=query.message_event, + message=result.user_notice, + quote_origin=query.pipeline_config['output']['misc']['quote-origin'] ) if result.debug_notice: self.ap.logger.debug(result.debug_notice) @@ -195,12 +228,15 @@ class PipelineManager: pipeline_entity = persistence_pipeline.LegacyPipeline(**pipeline_entity) # initialize stage containers according to pipeline_entity.stages - stage_containers = [] + stage_containers: list[StageInstContainer] = [] for stage_name in pipeline_entity.stages: - stage_containers.append(stagemgr.StageInstContainer( + stage_containers.append(StageInstContainer( inst_name=stage_name, inst=self.stage_dict[stage_name](self.ap) )) + + for stage_container in stage_containers: + await stage_container.inst.initialize(pipeline_entity.config) runtime_pipeline = RuntimePipeline(self.ap, pipeline_entity, stage_containers) self.pipelines.append(runtime_pipeline) diff --git a/pkg/pipeline/pool.py b/pkg/pipeline/pool.py index d0c86e31..df4d0741 100644 --- a/pkg/pipeline/pool.py +++ b/pkg/pipeline/pool.py @@ -28,23 +28,23 @@ class QueryPool: async def add_query( self, + bot_uuid: str, launcher_type: entities.LauncherTypes, launcher_id: typing.Union[int, str], sender_id: typing.Union[int, str], message_event: platform_events.MessageEvent, message_chain: platform_message.MessageChain, adapter: msadapter.MessagePlatformAdapter, - pipeline_uuid: str ) -> entities.Query: async with self.condition: query = entities.Query( + bot_uuid=bot_uuid, query_id=self.query_id_counter, launcher_type=launcher_type, launcher_id=launcher_id, sender_id=sender_id, message_event=message_event, message_chain=message_chain, - pipeline_uuid=pipeline_uuid, resp_messages=[], resp_message_chain=[], adapter=adapter diff --git a/pkg/pipeline/preproc/preproc.py b/pkg/pipeline/preproc/preproc.py index 299aea5e..9958466a 100644 --- a/pkg/pipeline/preproc/preproc.py +++ b/pkg/pipeline/preproc/preproc.py @@ -2,7 +2,7 @@ from __future__ import annotations import datetime -from .. import stage, entities, stagemgr +from .. import stage, entities from ...core import entities as core_entities from ...provider import entities as llm_entities from ...plugin import events @@ -33,16 +33,16 @@ class PreProcessor(stage.PipelineStage): """ session = await self.ap.sess_mgr.get_session(query) - conversation = await self.ap.sess_mgr.get_conversation(session) + conversation = await self.ap.sess_mgr.get_conversation(query, session) # 设置query query.session = session query.prompt = conversation.prompt.copy() query.messages = conversation.messages.copy() - query.use_model = conversation.use_model + query.use_llm_model = conversation.use_llm_model - query.use_funcs = conversation.use_funcs if query.use_model.tool_call_supported else None + query.use_funcs = conversation.use_funcs if query.use_llm_model.model_entity.abilities.__contains__('tool_call') else None query.variables = { "session_id": f"{query.session.launcher_type.value}_{query.session.launcher_id}", @@ -50,8 +50,9 @@ class PreProcessor(stage.PipelineStage): "msg_create_time": int(query.message_event.time) if query.message_event.time else int(datetime.datetime.now().timestamp()), } - # 检查vision是否启用,没启用就删除所有图片 - if not self.ap.provider_cfg.data['enable-vision'] or (self.ap.provider_cfg.data['runner'] == 'local-agent' and not query.use_model.vision_supported): + # Check if this model supports vision, if not, remove all images + # TODO this checking should be performed in runner, and in this stage, the image should be reserved + if query.pipeline_config['ai']['runner']['runner'] == 'local-agent' and not query.use_llm_model.model_entity.abilities.__contains__('vision'): for msg in query.messages: if isinstance(msg.content, list): for me in msg.content: @@ -69,7 +70,7 @@ class PreProcessor(stage.PipelineStage): ) plain_text += me.text elif isinstance(me, platform_message.Image): - if self.ap.provider_cfg.data['enable-vision'] and (self.ap.provider_cfg.data['runner'] != 'local-agent' or query.use_model.vision_supported): + if query.pipeline_config['ai']['runner']['runner'] != 'local-agent' or query.use_llm_model.model_entity.abilities.__contains__('vision'): if me.base64 is not None: content_list.append( llm_entities.ContentElement.from_image_base64(me.base64) diff --git a/pkg/pipeline/process/handlers/chat.py b/pkg/pipeline/process/handlers/chat.py index 83bb3335..9d231dda 100644 --- a/pkg/pipeline/process/handlers/chat.py +++ b/pkg/pipeline/process/handlers/chat.py @@ -9,7 +9,9 @@ import json from .. import handler from ... import entities from ....core import entities as core_entities -from ....provider import entities as llm_entities, runnermgr +from ....provider import entities as llm_entities +from ....provider import runner as runner_module +from ....provider.runners import localagent, difysvapi, dashscopeapi from ....plugin import events from ....platform.types import message as platform_message @@ -56,12 +58,6 @@ class ChatMessageHandler(handler.MessageHandler): ) else: - if not self.ap.provider_cfg.data['enable-chat']: - yield entities.StageProcessResult( - result_type=entities.ResultType.INTERRUPT, - new_query=query, - ) - if event_ctx.event.alter is not None: # if isinstance(event_ctx.event, str): # 现在暂时不考虑多模态alter query.user_message.content = event_ctx.event.alter @@ -72,7 +68,12 @@ class ChatMessageHandler(handler.MessageHandler): try: - runner = self.ap.runner_mgr.get_runner() + for r in runner_module.preregistered_runners: + if r.name == query.pipeline_config["ai"]["runner"]["runner"]: + runner = r(self.ap, query.pipeline_config) + break + else: + raise ValueError(f"未找到请求运行器: {query.pipeline_config['ai']['runner']['runner']}") async for result in runner.run(query): query.resp_messages.append(result) @@ -93,10 +94,12 @@ class ChatMessageHandler(handler.MessageHandler): self.ap.logger.error(f'对话({query.query_id})请求失败: {type(e).__name__} {str(e)}') + hide_exception_info = query.pipeline_config['output']['misc']['hide-exception'] + yield entities.StageProcessResult( result_type=entities.ResultType.INTERRUPT, new_query=query, - user_notice='请求失败' if self.ap.platform_cfg.data['hide-exception-info'] else f'{e}', + user_notice='请求失败' if hide_exception_info else f'{e}', error_notice=f'{e}', debug_notice=traceback.format_exc() ) diff --git a/pkg/pipeline/process/process.py b/pkg/pipeline/process/process.py index 362ece01..ea4d7e7f 100644 --- a/pkg/pipeline/process/process.py +++ b/pkg/pipeline/process/process.py @@ -4,7 +4,7 @@ from ...core import app, entities as core_entities from . import handler from .handlers import chat, command from .. import entities -from .. import stage, entities, stagemgr +from .. import stage, entities from ...core import entities as core_entities from ...config import manager as cfg_mgr @@ -23,7 +23,7 @@ class Processor(stage.PipelineStage): chat_handler: handler.MessageHandler - async def initialize(self): + async def initialize(self, pipeline_config: dict): self.cmd_handler = command.CommandHandler(self.ap) self.chat_handler = chat.ChatMessageHandler(self.ap) diff --git a/pkg/pipeline/ratelimit/ratelimit.py b/pkg/pipeline/ratelimit/ratelimit.py index cd39b85c..01bde395 100644 --- a/pkg/pipeline/ratelimit/ratelimit.py +++ b/pkg/pipeline/ratelimit/ratelimit.py @@ -2,7 +2,7 @@ from __future__ import annotations import typing -from .. import entities, stagemgr, stage +from .. import entities, stage from . import algo from .algos import fixedwin from ...core import entities as core_entities @@ -18,7 +18,7 @@ class RateLimit(stage.PipelineStage): algo: algo.ReteLimitAlgo - async def initialize(self): + async def initialize(self, pipeline_config: dict): algo_name = self.ap.pipeline_cfg.data['rate-limit']['algo'] diff --git a/pkg/pipeline/respback/respback.py b/pkg/pipeline/respback/respback.py index 08b335d5..8c074d89 100644 --- a/pkg/pipeline/respback/respback.py +++ b/pkg/pipeline/respback/respback.py @@ -5,8 +5,10 @@ import asyncio from ...core import app +from ...platform.types import events as platform_events +from ...platform.types import message as platform_message -from .. import stage, entities, stagemgr +from .. import stage, entities from ...core import entities as core_entities from ...config import manager as cfg_mgr @@ -19,8 +21,8 @@ class SendResponseBackStage(stage.PipelineStage): async def process(self, query: core_entities.Query, stage_inst_name: str) -> entities.StageProcessResult: """处理 """ - - random_range = (self.ap.platform_cfg.data['force-delay']['min'], self.ap.platform_cfg.data['force-delay']['max']) + + random_range = (query.pipeline_config['output']['force-delay']['min'], query.pipeline_config['output']['force-delay']['max']) random_delay = random.uniform(*random_range) @@ -31,10 +33,20 @@ class SendResponseBackStage(stage.PipelineStage): await asyncio.sleep(random_delay) - await self.ap.platform_mgr.send( - query.message_event, - query.resp_message_chain[-1], - adapter=query.adapter + if query.pipeline_config['output']['misc']['at-sender'] and isinstance(query.message_event, platform_events.GroupMessage): + query.resp_message_chain[-1].insert( + 0, + platform_message.At( + query.message_event.sender.id + ) + ) + + quote_origin = query.pipeline_config['output']['misc']['quote-origin'] + + await query.adapter.reply_message( + message_source=query.message_event, + message=query.resp_message_chain[-1], + quote_origin=quote_origin ) return entities.StageProcessResult( diff --git a/pkg/pipeline/resprule/resprule.py b/pkg/pipeline/resprule/resprule.py index 77858f0d..7e4b8f99 100644 --- a/pkg/pipeline/resprule/resprule.py +++ b/pkg/pipeline/resprule/resprule.py @@ -5,7 +5,7 @@ from ...core import app from . import entities as rule_entities, rule from .rules import atbot, prefix, regexp, random -from .. import stage, entities, stagemgr +from .. import stage, entities from ...core import entities as core_entities from ...config import manager as cfg_mgr @@ -20,7 +20,7 @@ class GroupRespondRuleCheckStage(stage.PipelineStage): rule_matchers: list[rule.GroupRespondRule] """检查器实例""" - async def initialize(self): + async def initialize(self, pipeline_config: dict): """初始化检查器 """ diff --git a/pkg/pipeline/stage.py b/pkg/pipeline/stage.py index 206f2bdf..859286d9 100644 --- a/pkg/pipeline/stage.py +++ b/pkg/pipeline/stage.py @@ -28,7 +28,7 @@ class PipelineStage(metaclass=abc.ABCMeta): def __init__(self, ap: app.Application): self.ap = ap - async def initialize(self): + async def initialize(self, pipeline_config: dict): """初始化 """ pass diff --git a/pkg/pipeline/stagemgr.py b/pkg/pipeline/stagemgr.py deleted file mode 100644 index 19fce2d6..00000000 --- a/pkg/pipeline/stagemgr.py +++ /dev/null @@ -1,71 +0,0 @@ -from __future__ import annotations - -from ..core import app -from . import stage -from .resprule import resprule -from .bansess import bansess -from .cntfilter import cntfilter -from .process import process -from .longtext import longtext -from .respback import respback -from .wrapper import wrapper -from .preproc import preproc -from .ratelimit import ratelimit -from .msgtrun import msgtrun - - -# 请求处理阶段顺序 -stage_order = [ - "GroupRespondRuleCheckStage", # 群响应规则检查 - "BanSessionCheckStage", # 封禁会话检查 - "PreContentFilterStage", # 内容过滤前置阶段 - "PreProcessor", # 预处理器 - "ConversationMessageTruncator", # 会话消息截断器 - "RequireRateLimitOccupancy", # 请求速率限制占用 - "MessageProcessor", # 处理器 - "ReleaseRateLimitOccupancy", # 释放速率限制占用 - "PostContentFilterStage", # 内容过滤后置阶段 - "ResponseWrapper", # 响应包装器 - "LongTextProcessStage", # 长文本处理 - "SendResponseBackStage", # 发送响应 -] - - -class StageInstContainer(): - """阶段实例容器 - """ - - inst_name: str - - inst: stage.PipelineStage - - def __init__(self, inst_name: str, inst: stage.PipelineStage): - self.inst_name = inst_name - self.inst = inst - - -class StageManager: - ap: app.Application - - stage_containers: list[StageInstContainer] - - def __init__(self, ap: app.Application): - self.ap = ap - - self.stage_containers = [] - - async def initialize(self): - """初始化 - """ - - for name, cls in stage.preregistered_stages.items(): - self.stage_containers.append(StageInstContainer( - inst_name=name, - inst=cls(self.ap) - )) - - for stage_containers in self.stage_containers: - await stage_containers.inst.initialize() - - # 按照 stage_order 排序 - self.stage_containers.sort(key=lambda x: stage_order.index(x.inst_name)) diff --git a/pkg/pipeline/wrapper/wrapper.py b/pkg/pipeline/wrapper/wrapper.py index a06e4a80..6b12ca65 100644 --- a/pkg/pipeline/wrapper/wrapper.py +++ b/pkg/pipeline/wrapper/wrapper.py @@ -5,7 +5,7 @@ import typing from ...core import app, entities as core_entities from .. import entities -from .. import stage, entities, stagemgr +from .. import stage, entities from ...core import entities as core_entities from ...config import manager as cfg_mgr from ...plugin import events @@ -22,7 +22,7 @@ class ResponseWrapper(stage.PipelineStage): - resp_message_chain """ - async def initialize(self): + async def initialize(self, pipeline_config: dict): pass async def process( @@ -110,7 +110,7 @@ class ResponseWrapper(stage.PipelineStage): query.resp_message_chain.append(platform_message.MessageChain([platform_message.Plain(reply_text)])) - if self.ap.platform_cfg.data['track-function-calls']: + if query.pipeline_config['output']['misc']['track-function-calls']: event_ctx = await self.ap.plugin_mgr.emit_event( event=events.NormalMessageResponded( diff --git a/pkg/platform/manager.py b/pkg/platform/manager.py index 81f15655..360f7588 100644 --- a/pkg/platform/manager.py +++ b/pkg/platform/manager.py @@ -55,25 +55,25 @@ class RuntimeBot: async def on_friend_message(event: platform_events.FriendMessage, adapter: msadapter.MessagePlatformAdapter): await self.ap.query_pool.add_query( + bot_uuid=self.bot_entity.uuid, launcher_type=core_entities.LauncherTypes.PERSON, launcher_id=event.sender.id, sender_id=event.sender.id, message_event=event, message_chain=event.message_chain, adapter=adapter, - pipeline_uuid=self.bot_entity.use_pipeline_uuid ) async def on_group_message(event: platform_events.GroupMessage, adapter: msadapter.MessagePlatformAdapter): await self.ap.query_pool.add_query( + bot_uuid=self.bot_entity.uuid, launcher_type=core_entities.LauncherTypes.GROUP, launcher_id=event.group.id, sender_id=event.sender.id, message_event=event, message_chain=event.message_chain, adapter=adapter, - pipeline_uuid=self.bot_entity.use_pipeline_uuid ) self.adapter.register_listener( @@ -113,14 +113,16 @@ class RuntimeBot: async def shutdown(self): await self.adapter.kill() + self.ap.task_mgr.cancel_task(self.task_wrapper.id) + # 控制QQ消息输入输出的类 class PlatformManager: # adapter: msadapter.MessageSourceAdapter = None - adapters: list[msadapter.MessagePlatformAdapter] = [] + adapters: list[msadapter.MessagePlatformAdapter] = [] # deprecated - message_platform_adapter_components: list[engine.Component] = [] + message_platform_adapter_components: list[engine.Component] = [] # deprecated # ====== 4.0 ====== ap: app.Application = None @@ -215,50 +217,36 @@ class PlatformManager: return None async def write_back_config(self, adapter_name: str, adapter_inst: msadapter.MessagePlatformAdapter, config: dict): - index = -2 + # index = -2 - for i, adapter in enumerate(self.adapters): - if adapter == adapter_inst: - index = i - break + # for i, adapter in enumerate(self.adapters): + # if adapter == adapter_inst: + # index = i + # break - if index == -2: - raise Exception('平台适配器未找到') + # if index == -2: + # raise Exception('平台适配器未找到') - # 只修改启用的适配器 - real_index = -1 + # # 只修改启用的适配器 + # real_index = -1 - for i, adapter in enumerate(self.ap.platform_cfg.data['platform-adapters']): - if adapter['enable']: - index -= 1 - if index == -1: - real_index = i - break + # for i, adapter in enumerate(self.ap.platform_cfg.data['platform-adapters']): + # if adapter['enable']: + # index -= 1 + # if index == -1: + # real_index = i + # break - new_cfg = { - 'adapter': adapter_name, - 'enable': True, - **config - } - self.ap.platform_cfg.data['platform-adapters'][real_index] = new_cfg - await self.ap.platform_cfg.dump_config() + # new_cfg = { + # 'adapter': adapter_name, + # 'enable': True, + # **config + # } + # self.ap.platform_cfg.data['platform-adapters'][real_index] = new_cfg + # await self.ap.platform_cfg.dump_config() - async def send(self, event: platform_events.MessageEvent, msg: platform_message.MessageChain, adapter: msadapter.MessagePlatformAdapter): - - if self.ap.platform_cfg.data['at-sender'] and isinstance(event, platform_events.GroupMessage): - - msg.insert( - 0, - platform_message.At( - event.sender.id - ) - ) - - await adapter.reply_message( - event, - msg, - quote_origin=True if self.ap.platform_cfg.data['quote-origin'] else False - ) + # TODO implement this + pass async def run(self): # This method will only be called when the application launching @@ -270,4 +258,4 @@ class PlatformManager: for bot in self.bots: if bot.enable: await bot.shutdown() - self.ap.task_mgr.cancel_by_scope(core_entities.LifecycleControlScope.PLATFORM) \ No newline at end of file + self.ap.task_mgr.cancel_by_scope(core_entities.LifecycleControlScope.PLATFORM) diff --git a/pkg/platform/sources/aiocqhttp.py b/pkg/platform/sources/aiocqhttp.py index af14372a..9149e427 100644 --- a/pkg/platform/sources/aiocqhttp.py +++ b/pkg/platform/sources/aiocqhttp.py @@ -238,4 +238,6 @@ class AiocqhttpAdapter(adapter.MessagePlatformAdapter): await self.bot._server_app.run_task(**self.config) async def kill(self) -> bool: + # Current issue: existing connection will not be closed + # self.should_shutdown = True return False diff --git a/pkg/plugin/context.py b/pkg/plugin/context.py index 76a49bf4..7a9be2a1 100644 --- a/pkg/plugin/context.py +++ b/pkg/plugin/context.py @@ -222,10 +222,10 @@ class EventContext: Args: message_chain (platform.types.MessageChain): 源平台的消息链,若用户使用的不是源平台适配器,程序也能自动转换为目标平台消息链 """ - await self.host.ap.platform_mgr.send( - event=self.event.query.message_event, - msg=message_chain, - adapter=self.event.query.adapter, + # TODO 添加 at_sender 和 quote_origin 参数 + await self.event.query.adapter.reply_message( + message_source=self.event.query.message_event, + message=message_chain ) async def send_message( diff --git a/pkg/provider/modelmgr/modelmgr.py b/pkg/provider/modelmgr/modelmgr.py index 7db7a040..41e97f3e 100644 --- a/pkg/provider/modelmgr/modelmgr.py +++ b/pkg/provider/modelmgr/modelmgr.py @@ -2,6 +2,7 @@ from __future__ import annotations import typing import sqlalchemy +import pydantic.v1 as pydantic from . import entities, requester from ...core import app @@ -16,23 +17,6 @@ from .requesters import bailianchatcmpl, chatcmpl, anthropicmsgs, moonshotchatcm FETCH_MODEL_LIST_URL = "https://api.qchatgpt.rockchin.top/api/v2/fetch/model_list" -class RuntimeLLMModel: - """运行时模型""" - - model_entity: persistence_model.LLMModel - """模型数据""" - - token_mgr: token.TokenManager - """api key管理器""" - - requester: requester.LLMAPIRequester - """请求器实例""" - - def __init__(self, model_entity: persistence_model.LLMModel, token_mgr: token.TokenManager, requester: requester.LLMAPIRequester): - self.model_entity = model_entity - self.token_mgr = token_mgr - self.requester = requester - class ModelManager: """模型管理器""" @@ -47,7 +31,7 @@ class ModelManager: ap: app.Application - llm_models: list[RuntimeLLMModel] + llm_models: list[requester.RuntimeLLMModel] requester_components: list[engine.Component] @@ -99,16 +83,20 @@ class ModelManager: elif isinstance(model_info, dict): model_info = persistence_model.LLMModel(**model_info) - runtime_llm_model = RuntimeLLMModel( + requester_inst = self.requester_dict[model_info.requester]( + ap=self.ap, + config=model_info.requester_config + ) + + await requester_inst.initialize() + + runtime_llm_model = requester.RuntimeLLMModel( model_entity=model_info, token_mgr=token.TokenManager( name=model_info.uuid, tokens=model_info.api_keys, ), - requester=self.requester_dict[model_info.requester]( - ap=self.ap, - config=model_info.requester_config - ) + requester=requester_inst ) self.llm_models.append(runtime_llm_model) diff --git a/pkg/provider/modelmgr/requester.py b/pkg/provider/modelmgr/requester.py index 7f13c58b..5ea8d23f 100644 --- a/pkg/provider/modelmgr/requester.py +++ b/pkg/provider/modelmgr/requester.py @@ -6,8 +6,27 @@ import typing from ...core import app from ...core import entities as core_entities from .. import entities as llm_entities -from . import entities as modelmgr_entities from ..tools import entities as tools_entities +from ...entity.persistence import model as persistence_model +from . import token + + +class RuntimeLLMModel: + """运行时模型""" + + model_entity: persistence_model.LLMModel + """模型数据""" + + token_mgr: token.TokenManager + """api key管理器""" + + requester: LLMAPIRequester + """请求器实例""" + + def __init__(self, model_entity: persistence_model.LLMModel, token_mgr: token.TokenManager, requester: LLMAPIRequester): + self.model_entity = model_entity + self.token_mgr = token_mgr + self.requester = requester class LLMAPIRequester(metaclass=abc.ABCMeta): @@ -31,21 +50,11 @@ class LLMAPIRequester(metaclass=abc.ABCMeta): async def initialize(self): pass - async def preprocess( - self, - query: core_entities.Query, - ): - """预处理 - - 在这里处理特定API对Query对象的兼容性问题。 - """ - pass - @abc.abstractmethod - async def call( + async def invoke_llm( self, query: core_entities.Query, - model: modelmgr_entities.LLMModelInfo, + model: RuntimeLLMModel, messages: typing.List[llm_entities.Message], funcs: typing.List[tools_entities.LLMFunction] = None, extra_args: dict[str, typing.Any] = {}, @@ -53,7 +62,7 @@ class LLMAPIRequester(metaclass=abc.ABCMeta): """调用API Args: - model (modelmgr_entities.LLMModelInfo): 使用的模型信息 + model (RuntimeLLMModel): 使用的模型信息 messages (typing.List[llm_entities.Message]): 消息对象列表 funcs (typing.List[tools_entities.LLMFunction], optional): 使用的工具函数列表. Defaults to None. extra_args (dict[str, typing.Any], optional): 额外的参数. Defaults to {}. diff --git a/pkg/provider/modelmgr/requesters/anthropicmsgs.py b/pkg/provider/modelmgr/requesters/anthropicmsgs.py index 937f5107..7edc4405 100644 --- a/pkg/provider/modelmgr/requesters/anthropicmsgs.py +++ b/pkg/provider/modelmgr/requesters/anthropicmsgs.py @@ -24,16 +24,16 @@ class AnthropicMessages(requester.LLMAPIRequester): client: anthropic.AsyncAnthropic default_config: dict[str, typing.Any] = { - 'base-url': 'https://api.anthropic.com/v1', + 'base_url': 'https://api.anthropic.com/v1', 'timeout': 120, } async def initialize(self): httpx_client = anthropic._base_client.AsyncHttpxClientWrapper( - base_url=self.ap.provider_cfg.data['requester']['anthropic-messages']['base-url'], + base_url=self.requester_cfg['base_url'], # cast to a valid type because mypy doesn't understand our type narrowing - timeout=typing.cast(httpx.Timeout, self.ap.provider_cfg.data['requester']['anthropic-messages']['timeout']), + timeout=typing.cast(httpx.Timeout, self.requester_cfg['timeout']), limits=anthropic._constants.DEFAULT_CONNECTION_LIMITS, follow_redirects=True, trust_env=True, @@ -44,17 +44,18 @@ class AnthropicMessages(requester.LLMAPIRequester): http_client=httpx_client, ) - async def call( + async def invoke_llm( self, query: core_entities.Query, - model: entities.LLMModelInfo, + model: requester.RuntimeLLMModel, messages: typing.List[llm_entities.Message], funcs: typing.List[tools_entities.LLMFunction] = None, + extra_args: dict[str, typing.Any] = {}, ) -> llm_entities.Message: self.client.api_key = model.token_mgr.get_token() - args = self.ap.provider_cfg.data['requester']['anthropic-messages']['args'].copy() - args["model"] = model.name if model.model_name is None else model.model_name + args = extra_args.copy() + args["model"] = model.model_entity.name # 处理消息 diff --git a/pkg/provider/modelmgr/requesters/anthropicmsgs.yaml b/pkg/provider/modelmgr/requesters/anthropicmsgs.yaml index 80380857..6d1a53cf 100644 --- a/pkg/provider/modelmgr/requesters/anthropicmsgs.yaml +++ b/pkg/provider/modelmgr/requesters/anthropicmsgs.yaml @@ -7,7 +7,7 @@ metadata: zh_CN: Anthropic spec: config: - - name: base-url + - name: base_url label: en_US: Base URL zh_CN: 基础 URL diff --git a/pkg/provider/modelmgr/requesters/bailianchatcmpl.py b/pkg/provider/modelmgr/requesters/bailianchatcmpl.py index ce718c4c..e20e3376 100644 --- a/pkg/provider/modelmgr/requesters/bailianchatcmpl.py +++ b/pkg/provider/modelmgr/requesters/bailianchatcmpl.py @@ -14,6 +14,6 @@ class BailianChatCompletions(chatcmpl.OpenAIChatCompletions): client: openai.AsyncClient default_config: dict[str, typing.Any] = { - 'base-url': 'https://dashscope.aliyuncs.com/compatible-mode/v1', + 'base_url': 'https://dashscope.aliyuncs.com/compatible-mode/v1', 'timeout': 120, } diff --git a/pkg/provider/modelmgr/requesters/bailianchatcmpl.yaml b/pkg/provider/modelmgr/requesters/bailianchatcmpl.yaml index 74d197ca..136e903f 100644 --- a/pkg/provider/modelmgr/requesters/bailianchatcmpl.yaml +++ b/pkg/provider/modelmgr/requesters/bailianchatcmpl.yaml @@ -7,7 +7,7 @@ metadata: zh_CN: 阿里云百炼 spec: config: - - name: base-url + - name: base_url label: en_US: Base URL zh_CN: 基础 URL diff --git a/pkg/provider/modelmgr/requesters/chatcmpl.py b/pkg/provider/modelmgr/requesters/chatcmpl.py index 7cf255c0..d1d9767f 100644 --- a/pkg/provider/modelmgr/requesters/chatcmpl.py +++ b/pkg/provider/modelmgr/requesters/chatcmpl.py @@ -26,7 +26,7 @@ class OpenAIChatCompletions(requester.LLMAPIRequester): client: openai.AsyncClient default_config: dict[str, typing.Any] = { - "base-url": "https://api.openai.com/v1", + "base_url": "https://api.openai.com/v1", "timeout": 120, } @@ -34,7 +34,7 @@ class OpenAIChatCompletions(requester.LLMAPIRequester): self.client = openai.AsyncClient( api_key="", - base_url=self.requester_cfg["base-url"], + base_url=self.requester_cfg["base_url"], timeout=self.requester_cfg["timeout"], http_client=httpx.AsyncClient( trust_env=True, timeout=self.requester_cfg["timeout"] @@ -65,16 +65,14 @@ class OpenAIChatCompletions(requester.LLMAPIRequester): self, query: core_entities.Query, req_messages: list[dict], - use_model: entities.LLMModelInfo, + use_model: requester.RuntimeLLMModel, use_funcs: list[tools_entities.LLMFunction] = None, extra_args: dict[str, typing.Any] = {}, # TODO: 所有的args都改为从此参数读取 ) -> llm_entities.Message: self.client.api_key = use_model.token_mgr.get_token() - args = self.requester_cfg["args"].copy() - args["model"] = ( - use_model.name if use_model.model_name is None else use_model.model_name - ) + args = extra_args.copy() + args["model"] = use_model.model_entity.name if use_funcs: tools = await self.ap.tool_mgr.generate_tools_for_openai(use_funcs) @@ -104,10 +102,10 @@ class OpenAIChatCompletions(requester.LLMAPIRequester): return message - async def call( + async def invoke_llm( self, query: core_entities.Query, - model: entities.LLMModelInfo, + model: requester.RuntimeLLMModel, messages: typing.List[llm_entities.Message], funcs: typing.List[tools_entities.LLMFunction] = None, extra_args: dict[str, typing.Any] = {}, diff --git a/pkg/provider/modelmgr/requesters/chatcmpl.yaml b/pkg/provider/modelmgr/requesters/chatcmpl.yaml index fe4d3cb5..a67e23d1 100644 --- a/pkg/provider/modelmgr/requesters/chatcmpl.yaml +++ b/pkg/provider/modelmgr/requesters/chatcmpl.yaml @@ -7,7 +7,7 @@ metadata: zh_CN: OpenAI spec: config: - - name: base-url + - name: base_url label: en_US: Base URL zh_CN: 基础 URL diff --git a/pkg/provider/modelmgr/requesters/deepseekchatcmpl.py b/pkg/provider/modelmgr/requesters/deepseekchatcmpl.py index e82d0d81..ee17ac05 100644 --- a/pkg/provider/modelmgr/requesters/deepseekchatcmpl.py +++ b/pkg/provider/modelmgr/requesters/deepseekchatcmpl.py @@ -13,7 +13,7 @@ class DeepseekChatCompletions(chatcmpl.OpenAIChatCompletions): """Deepseek ChatCompletion API 请求器""" default_config: dict[str, typing.Any] = { - 'base-url': 'https://api.deepseek.com', + 'base_url': 'https://api.deepseek.com', 'timeout': 120, } @@ -21,14 +21,14 @@ class DeepseekChatCompletions(chatcmpl.OpenAIChatCompletions): self, query: core_entities.Query, req_messages: list[dict], - use_model: entities.LLMModelInfo, + use_model: requester.RuntimeLLMModel, use_funcs: list[tools_entities.LLMFunction] = None, extra_args: dict[str, typing.Any] = {}, ) -> llm_entities.Message: self.client.api_key = use_model.token_mgr.get_token() - args = self.requester_cfg['args'].copy() - args["model"] = use_model.name if use_model.model_name is None else use_model.model_name + args = extra_args.copy() + args["model"] = use_model.model_entity.name if use_funcs: tools = await self.ap.tool_mgr.generate_tools_for_openai(use_funcs) diff --git a/pkg/provider/modelmgr/requesters/deepseekchatcmpl.yaml b/pkg/provider/modelmgr/requesters/deepseekchatcmpl.yaml index 2ef91aa2..9890e21e 100644 --- a/pkg/provider/modelmgr/requesters/deepseekchatcmpl.yaml +++ b/pkg/provider/modelmgr/requesters/deepseekchatcmpl.yaml @@ -7,7 +7,7 @@ metadata: zh_CN: 深度求索 spec: config: - - name: base-url + - name: base_url label: en_US: Base URL zh_CN: 基础 URL diff --git a/pkg/provider/modelmgr/requesters/giteeaichatcmpl.py b/pkg/provider/modelmgr/requesters/giteeaichatcmpl.py index dec0b8d1..35052682 100644 --- a/pkg/provider/modelmgr/requesters/giteeaichatcmpl.py +++ b/pkg/provider/modelmgr/requesters/giteeaichatcmpl.py @@ -18,7 +18,7 @@ class GiteeAIChatCompletions(chatcmpl.OpenAIChatCompletions): """Gitee AI ChatCompletions API 请求器""" default_config: dict[str, typing.Any] = { - 'base-url': 'https://ai.gitee.com/v1', + 'base_url': 'https://ai.gitee.com/v1', 'timeout': 120, } @@ -26,14 +26,14 @@ class GiteeAIChatCompletions(chatcmpl.OpenAIChatCompletions): self, query: core_entities.Query, req_messages: list[dict], - use_model: entities.LLMModelInfo, + use_model: requester.RuntimeLLMModel, use_funcs: list[tools_entities.LLMFunction] = None, extra_args: dict[str, typing.Any] = {}, ) -> llm_entities.Message: self.client.api_key = use_model.token_mgr.get_token() - args = self.requester_cfg['args'].copy() - args["model"] = use_model.name if use_model.model_name is None else use_model.model_name + args = extra_args.copy() + args["model"] = use_model.model_entity.name if use_funcs: tools = await self.ap.tool_mgr.generate_tools_for_openai(use_funcs) diff --git a/pkg/provider/modelmgr/requesters/giteeaichatcmpl.yaml b/pkg/provider/modelmgr/requesters/giteeaichatcmpl.yaml index 11f7e06e..3e4efd61 100644 --- a/pkg/provider/modelmgr/requesters/giteeaichatcmpl.yaml +++ b/pkg/provider/modelmgr/requesters/giteeaichatcmpl.yaml @@ -7,7 +7,7 @@ metadata: zh_CN: Gitee AI spec: config: - - name: base-url + - name: base_url label: en_US: Base URL zh_CN: 基础 URL diff --git a/pkg/provider/modelmgr/requesters/lmstudiochatcmpl.py b/pkg/provider/modelmgr/requesters/lmstudiochatcmpl.py index c00be372..6be76051 100644 --- a/pkg/provider/modelmgr/requesters/lmstudiochatcmpl.py +++ b/pkg/provider/modelmgr/requesters/lmstudiochatcmpl.py @@ -14,6 +14,6 @@ class LmStudioChatCompletions(chatcmpl.OpenAIChatCompletions): client: openai.AsyncClient default_config: dict[str, typing.Any] = { - 'base-url': 'http://127.0.0.1:1234/v1', + 'base_url': 'http://127.0.0.1:1234/v1', 'timeout': 120, } diff --git a/pkg/provider/modelmgr/requesters/lmstudiochatcmpl.yaml b/pkg/provider/modelmgr/requesters/lmstudiochatcmpl.yaml index 959d4151..219e5839 100644 --- a/pkg/provider/modelmgr/requesters/lmstudiochatcmpl.yaml +++ b/pkg/provider/modelmgr/requesters/lmstudiochatcmpl.yaml @@ -7,7 +7,7 @@ metadata: zh_CN: LM Studio spec: config: - - name: base-url + - name: base_url label: en_US: Base URL zh_CN: 基础 URL diff --git a/pkg/provider/modelmgr/requesters/moonshotchatcmpl.py b/pkg/provider/modelmgr/requesters/moonshotchatcmpl.py index 3cbe8837..1ef7d9c9 100644 --- a/pkg/provider/modelmgr/requesters/moonshotchatcmpl.py +++ b/pkg/provider/modelmgr/requesters/moonshotchatcmpl.py @@ -15,7 +15,7 @@ class MoonshotChatCompletions(chatcmpl.OpenAIChatCompletions): """Moonshot ChatCompletion API 请求器""" default_config: dict[str, typing.Any] = { - 'base-url': 'https://api.moonshot.cn/v1', + 'base_url': 'https://api.moonshot.cn/v1', 'timeout': 120, } @@ -23,14 +23,14 @@ class MoonshotChatCompletions(chatcmpl.OpenAIChatCompletions): self, query: core_entities.Query, req_messages: list[dict], - use_model: entities.LLMModelInfo, + use_model: requester.RuntimeLLMModel, use_funcs: list[tools_entities.LLMFunction] = None, extra_args: dict[str, typing.Any] = {}, ) -> llm_entities.Message: self.client.api_key = use_model.token_mgr.get_token() - args = self.requester_cfg['args'].copy() - args["model"] = use_model.name if use_model.model_name is None else use_model.model_name + args = extra_args.copy() + args["model"] = use_model.model_entity.name if use_funcs: tools = await self.ap.tool_mgr.generate_tools_for_openai(use_funcs) diff --git a/pkg/provider/modelmgr/requesters/moonshotchatcmpl.yaml b/pkg/provider/modelmgr/requesters/moonshotchatcmpl.yaml index 56deb1df..7290784f 100644 --- a/pkg/provider/modelmgr/requesters/moonshotchatcmpl.yaml +++ b/pkg/provider/modelmgr/requesters/moonshotchatcmpl.yaml @@ -7,7 +7,7 @@ metadata: zh_CN: 月之暗面 spec: config: - - name: base-url + - name: base_url label: en_US: Base URL zh_CN: 基础 URL diff --git a/pkg/provider/modelmgr/requesters/ollamachat.py b/pkg/provider/modelmgr/requesters/ollamachat.py index fa99cfe5..ee331036 100644 --- a/pkg/provider/modelmgr/requesters/ollamachat.py +++ b/pkg/provider/modelmgr/requesters/ollamachat.py @@ -22,35 +22,38 @@ REQUESTER_NAME: str = "ollama-chat" class OllamaChatCompletions(requester.LLMAPIRequester): """Ollama平台 ChatCompletion API请求器""" + client: ollama.AsyncClient default_config: dict[str, typing.Any] = { - 'base-url': 'http://127.0.0.1:11434', - 'timeout': 120, + "base_url": "http://127.0.0.1:11434", + "timeout": 120, } async def initialize(self): - os.environ['OLLAMA_HOST'] = self.requester_cfg['base-url'] - self.client = ollama.AsyncClient( - timeout=self.requester_cfg['timeout'] - ) + os.environ["OLLAMA_HOST"] = self.requester_cfg["base_url"] + self.client = ollama.AsyncClient(timeout=self.requester_cfg["timeout"]) - async def _req(self, - args: dict, - ) -> Union[Mapping[str, Any], AsyncIterator[Mapping[str, Any]]]: - return await self.client.chat( - **args - ) + async def _req( + self, + args: dict, + ) -> Union[Mapping[str, Any], AsyncIterator[Mapping[str, Any]]]: + return await self.client.chat(**args) - async def _closure(self, query: core_entities.Query, req_messages: list[dict], use_model: entities.LLMModelInfo, - user_funcs: list[tools_entities.LLMFunction] = None, - extra_args: dict[str, typing.Any] = {}) -> llm_entities.Message: - args: Any = self.requester_cfg['args'].copy() - args["model"] = use_model.name if use_model.model_name is None else use_model.model_name + async def _closure( + self, + query: core_entities.Query, + req_messages: list[dict], + use_model: requester.RuntimeLLMModel, + user_funcs: list[tools_entities.LLMFunction] = None, + extra_args: dict[str, typing.Any] = {}, + ) -> llm_entities.Message: + args = extra_args.copy() + args["model"] = use_model.model_entity.name messages: list[dict] = req_messages.copy() for msg in messages: - if 'content' in msg and isinstance(msg["content"], list): + if "content" in msg and isinstance(msg["content"], list): text_content: list = [] image_urls: list = [] for me in msg["content"]: @@ -58,12 +61,16 @@ class OllamaChatCompletions(requester.LLMAPIRequester): text_content.append(me["text"]) elif me["type"] == "image_base64": image_urls.append(me["image_base64"]) - + msg["content"] = "\n".join(text_content) - msg["images"] = [url.split(',')[1] for url in image_urls] - if 'tool_calls' in msg: # LangBot 内部以 str 存储 tool_calls 的参数,这里需要转换为 dict - for tool_call in msg['tool_calls']: - tool_call['function']['arguments'] = json.loads(tool_call['function']['arguments']) + msg["images"] = [url.split(",")[1] for url in image_urls] + if ( + "tool_calls" in msg + ): # LangBot 内部以 str 存储 tool_calls 的参数,这里需要转换为 dict + for tool_call in msg["tool_calls"]: + tool_call["function"]["arguments"] = json.loads( + tool_call["function"]["arguments"] + ) args["messages"] = messages args["tools"] = [] @@ -77,8 +84,8 @@ class OllamaChatCompletions(requester.LLMAPIRequester): return message async def _make_msg( - self, - chat_completions: ollama.ChatResponse) -> llm_entities.Message: + self, chat_completions: ollama.ChatResponse + ) -> llm_entities.Message: message: ollama.Message = chat_completions.message if message is None: raise ValueError("chat_completions must contain a 'message' field") @@ -86,43 +93,51 @@ class OllamaChatCompletions(requester.LLMAPIRequester): ret_msg: llm_entities.Message = None if message.content is not None: - ret_msg = llm_entities.Message( - role="assistant", - content=message.content - ) + ret_msg = llm_entities.Message(role="assistant", content=message.content) if message.tool_calls is not None and len(message.tool_calls) > 0: tool_calls: list[llm_entities.ToolCall] = [] for tool_call in message.tool_calls: - tool_calls.append(llm_entities.ToolCall( - id=uuid.uuid4().hex, - type="function", - function=llm_entities.FunctionCall( - name=tool_call.function.name, - arguments=json.dumps(tool_call.function.arguments) + tool_calls.append( + llm_entities.ToolCall( + id=uuid.uuid4().hex, + type="function", + function=llm_entities.FunctionCall( + name=tool_call.function.name, + arguments=json.dumps(tool_call.function.arguments), + ), ) - )) + ) ret_msg.tool_calls = tool_calls return ret_msg - async def call( - self, - query: core_entities.Query, - model: entities.LLMModelInfo, - messages: typing.List[llm_entities.Message], - funcs: typing.List[tools_entities.LLMFunction] = None, - extra_args: dict[str, typing.Any] = {}, + async def invoke_llm( + self, + query: core_entities.Query, + model: requester.RuntimeLLMModel, + messages: typing.List[llm_entities.Message], + funcs: typing.List[tools_entities.LLMFunction] = None, + extra_args: dict[str, typing.Any] = {}, ) -> llm_entities.Message: req_messages: list = [] for m in messages: msg_dict: dict = m.dict(exclude_none=True) content: Any = msg_dict.get("content") if isinstance(content, list): - if all(isinstance(part, dict) and part.get('type') == 'text' for part in content): + if all( + isinstance(part, dict) and part.get("type") == "text" + for part in content + ): msg_dict["content"] = "\n".join(part["text"] for part in content) req_messages.append(msg_dict) try: - return await self._closure(query, req_messages, model, funcs, extra_args) + return await self._closure( + query=query, + req_messages=req_messages, + use_model=model, + use_funcs=funcs, + extra_args=extra_args, + ) except asyncio.TimeoutError: - raise errors.RequesterError('请求超时') + raise errors.RequesterError("请求超时") diff --git a/pkg/provider/modelmgr/requesters/ollamachat.yaml b/pkg/provider/modelmgr/requesters/ollamachat.yaml index b162e5db..ba915aeb 100644 --- a/pkg/provider/modelmgr/requesters/ollamachat.yaml +++ b/pkg/provider/modelmgr/requesters/ollamachat.yaml @@ -7,7 +7,7 @@ metadata: zh_CN: Ollama spec: config: - - name: base-url + - name: base_url label: en_US: Base URL zh_CN: 基础 URL diff --git a/pkg/provider/modelmgr/requesters/siliconflowchatcmpl.py b/pkg/provider/modelmgr/requesters/siliconflowchatcmpl.py index a990f809..dd5b9a14 100644 --- a/pkg/provider/modelmgr/requesters/siliconflowchatcmpl.py +++ b/pkg/provider/modelmgr/requesters/siliconflowchatcmpl.py @@ -14,6 +14,6 @@ class SiliconFlowChatCompletions(chatcmpl.OpenAIChatCompletions): client: openai.AsyncClient default_config: dict[str, typing.Any] = { - 'base-url': 'https://api.siliconflow.cn/v1', + 'base_url': 'https://api.siliconflow.cn/v1', 'timeout': 120, } diff --git a/pkg/provider/modelmgr/requesters/siliconflowchatcmpl.yaml b/pkg/provider/modelmgr/requesters/siliconflowchatcmpl.yaml index 28d534f6..c938b21c 100644 --- a/pkg/provider/modelmgr/requesters/siliconflowchatcmpl.yaml +++ b/pkg/provider/modelmgr/requesters/siliconflowchatcmpl.yaml @@ -7,7 +7,7 @@ metadata: zh_CN: 硅基流动 spec: config: - - name: base-url + - name: base_url label: en_US: Base URL zh_CN: 基础 URL diff --git a/pkg/provider/modelmgr/requesters/volcarkchatcmpl.py b/pkg/provider/modelmgr/requesters/volcarkchatcmpl.py index fbf88826..9b5505e1 100644 --- a/pkg/provider/modelmgr/requesters/volcarkchatcmpl.py +++ b/pkg/provider/modelmgr/requesters/volcarkchatcmpl.py @@ -14,6 +14,6 @@ class VolcArkChatCompletions(chatcmpl.OpenAIChatCompletions): client: openai.AsyncClient default_config: dict[str, typing.Any] = { - 'base-url': 'https://ark.cn-beijing.volces.com/api/v3', + 'base_url': 'https://ark.cn-beijing.volces.com/api/v3', 'timeout': 120, } diff --git a/pkg/provider/modelmgr/requesters/volcarkchatcmpl.yaml b/pkg/provider/modelmgr/requesters/volcarkchatcmpl.yaml index f18c7b2c..56347bc5 100644 --- a/pkg/provider/modelmgr/requesters/volcarkchatcmpl.yaml +++ b/pkg/provider/modelmgr/requesters/volcarkchatcmpl.yaml @@ -7,7 +7,7 @@ metadata: zh_CN: 火山方舟 spec: config: - - name: base-url + - name: base_url label: en_US: Base URL zh_CN: 基础 URL diff --git a/pkg/provider/modelmgr/requesters/xaichatcmpl.py b/pkg/provider/modelmgr/requesters/xaichatcmpl.py index 47c2939a..e08af875 100644 --- a/pkg/provider/modelmgr/requesters/xaichatcmpl.py +++ b/pkg/provider/modelmgr/requesters/xaichatcmpl.py @@ -14,6 +14,6 @@ class XaiChatCompletions(chatcmpl.OpenAIChatCompletions): client: openai.AsyncClient default_config: dict[str, typing.Any] = { - 'base-url': 'https://api.x.ai/v1', + 'base_url': 'https://api.x.ai/v1', 'timeout': 120, } diff --git a/pkg/provider/modelmgr/requesters/xaichatcmpl.yaml b/pkg/provider/modelmgr/requesters/xaichatcmpl.yaml index ceda8c0d..604b88c6 100644 --- a/pkg/provider/modelmgr/requesters/xaichatcmpl.yaml +++ b/pkg/provider/modelmgr/requesters/xaichatcmpl.yaml @@ -7,7 +7,7 @@ metadata: zh_CN: xAI spec: config: - - name: base-url + - name: base_url label: en_US: Base URL zh_CN: 基础 URL diff --git a/pkg/provider/modelmgr/requesters/zhipuaichatcmpl.py b/pkg/provider/modelmgr/requesters/zhipuaichatcmpl.py index 1e24a5ef..7bbca164 100644 --- a/pkg/provider/modelmgr/requesters/zhipuaichatcmpl.py +++ b/pkg/provider/modelmgr/requesters/zhipuaichatcmpl.py @@ -14,6 +14,6 @@ class ZhipuAIChatCompletions(chatcmpl.OpenAIChatCompletions): client: openai.AsyncClient default_config: dict[str, typing.Any] = { - 'base-url': 'https://open.bigmodel.cn/api/paas/v4', + 'base_url': 'https://open.bigmodel.cn/api/paas/v4', 'timeout': 120, } diff --git a/pkg/provider/modelmgr/requesters/zhipuaichatcmpl.yaml b/pkg/provider/modelmgr/requesters/zhipuaichatcmpl.yaml index 3d112ca1..20b8b496 100644 --- a/pkg/provider/modelmgr/requesters/zhipuaichatcmpl.yaml +++ b/pkg/provider/modelmgr/requesters/zhipuaichatcmpl.yaml @@ -7,7 +7,7 @@ metadata: zh_CN: 智谱 AI spec: config: - - name: base-url + - name: base_url label: en_US: Base URL zh_CN: 基础 URL diff --git a/pkg/provider/runner.py b/pkg/provider/runner.py index 5a5cf6ef..1762e546 100644 --- a/pkg/provider/runner.py +++ b/pkg/provider/runner.py @@ -27,11 +27,11 @@ class RequestRunner(abc.ABC): ap: app.Application - def __init__(self, ap: app.Application): - self.ap = ap + pipeline_config: dict - async def initialize(self): - pass + def __init__(self, ap: app.Application, pipeline_config: dict): + self.ap = ap + self.pipeline_config = pipeline_config @abc.abstractmethod async def run(self, query: core_entities.Query) -> typing.AsyncGenerator[llm_entities.Message, None]: diff --git a/pkg/provider/runnermgr.py b/pkg/provider/runnermgr.py deleted file mode 100644 index 52e1d8d2..00000000 --- a/pkg/provider/runnermgr.py +++ /dev/null @@ -1,30 +0,0 @@ -from __future__ import annotations - -from . import runner -from ..core import app - -from .runners import localagent -from .runners import difysvapi -from .runners import dashscopeapi - -class RunnerManager: - - ap: app.Application - - using_runner: runner.RequestRunner - - def __init__(self, ap: app.Application): - self.ap = ap - - async def initialize(self): - - for r in runner.preregistered_runners: - if r.name == self.ap.provider_cfg.data['runner']: - self.using_runner = r(self.ap) - await self.using_runner.initialize() - break - else: - raise ValueError(f"未找到请求运行器: {self.ap.provider_cfg.data['runner']}") - - def get_runner(self) -> runner.RequestRunner: - return self.using_runner diff --git a/pkg/provider/runners/dashscopeapi.py b/pkg/provider/runners/dashscopeapi.py index 0bb09822..d5e6a83d 100644 --- a/pkg/provider/runners/dashscopeapi.py +++ b/pkg/provider/runners/dashscopeapi.py @@ -8,7 +8,7 @@ import re import dashscope from .. import runner -from ...core import entities as core_entities +from ...core import app, entities as core_entities from .. import entities as llm_entities from ...utils import image @@ -29,12 +29,14 @@ class DashScopeAPIRunner(runner.RequestRunner): app_id: str # 应用ID api_key: str # API Key references_quote: str # 引用资料提示(当展示回答来源功能开启时,这个变量会作为引用资料名前的提示,可在provider.json中配置) - biz_params: dict = {} # 工作流应用参数(仅在工作流应用中生效) - async def initialize(self): + def __init__(self, ap: app.Application, pipeline_config: dict): """初始化""" + self.ap = ap + self.pipeline_config = pipeline_config + valid_app_types = ["agent", "workflow"] - self.app_type = self.ap.provider_cfg.data["dashscope-app-api"]["app-type"] + self.app_type = self.pipeline_config["ai"]["dashscope-app-api"]["app-type"] #检查配置文件中使用的应用类型是否支持 if (self.app_type not in valid_app_types): raise DashscopeAPIError( @@ -42,10 +44,9 @@ class DashScopeAPIRunner(runner.RequestRunner): ) #初始化Dashscope 参数配置 - self.app_id = self.ap.provider_cfg.data["dashscope-app-api"][self.app_type]["app-id"] - self.api_key = self.ap.provider_cfg.data["dashscope-app-api"]["api-key"] - self.references_quote = self.ap.provider_cfg.data["dashscope-app-api"][self.app_type]["references_quote"] - self.biz_params = self.ap.provider_cfg.data["dashscope-app-api"]["workflow"]["biz_params"] + self.app_id = self.pipeline_config["ai"]["dashscope-app-api"]["app-id"] + self.api_key = self.pipeline_config["ai"]["dashscope-app-api"]["api-key"] + self.references_quote = self.pipeline_config["ai"]["dashscope-app-api"]["references_quote"] def _replace_references(self, text, references_dict): """阿里云百炼平台的自定义应用支持资料引用,此函数可以将引用标签替换为参考资料""" @@ -169,7 +170,6 @@ class DashScopeAPIRunner(runner.RequestRunner): plain_text, image_ids = await self._preprocess_user_message(query) biz_params = {} - biz_params.update(self.biz_params) biz_params.update(query.variables) #发送对话请求 @@ -220,21 +220,19 @@ class DashScopeAPIRunner(runner.RequestRunner): content=pending_content, ) - - async def run( self, query: core_entities.Query ) -> typing.AsyncGenerator[llm_entities.Message, None]: """运行""" - if self.ap.provider_cfg.data["dashscope-app-api"]["app-type"] == "agent": + if self.app_type == "agent": async for msg in self._agent_messages(query): yield msg - elif self.ap.provider_cfg.data["dashscope-app-api"]["app-type"] == "workflow": + elif self.app_type == "workflow": async for msg in self._workflow_messages(query): yield msg else: raise DashscopeAPIError( - f"不支持的 Dashscope 应用类型: {self.ap.provider_cfg.data['dashscope-app-api']['app-type']}" + f"不支持的 Dashscope 应用类型: {self.app_type}" ) diff --git a/pkg/provider/runners/difysvapi.py b/pkg/provider/runners/difysvapi.py index 81ceddee..f48cbd57 100644 --- a/pkg/provider/runners/difysvapi.py +++ b/pkg/provider/runners/difysvapi.py @@ -10,7 +10,7 @@ import datetime import aiohttp from .. import runner -from ...core import entities as core_entities +from ...core import app, entities as core_entities from .. import entities as llm_entities from ...utils import image @@ -23,24 +23,24 @@ class DifyServiceAPIRunner(runner.RequestRunner): dify_client: client.AsyncDifyServiceClient - async def initialize(self): - """初始化""" + def __init__(self, ap: app.Application, pipeline_config: dict): + self.ap = ap + self.pipeline_config = pipeline_config + valid_app_types = ["chat", "agent", "workflow"] if ( - self.ap.provider_cfg.data["dify-service-api"]["app-type"] + self.pipeline_config["ai"]["dify-service-api"]["app-type"] not in valid_app_types ): raise errors.DifyAPIError( - f"不支持的 Dify 应用类型: {self.ap.provider_cfg.data['dify-service-api']['app-type']}" + f"不支持的 Dify 应用类型: {self.pipeline_config['ai']['dify-service-api']['app-type']}" ) - api_key = self.ap.provider_cfg.data["dify-service-api"][ - self.ap.provider_cfg.data["dify-service-api"]["app-type"] - ]["api-key"] + api_key = self.pipeline_config["ai"]["dify-service-api"]["api-key"] self.dify_client = client.AsyncDifyServiceClient( api_key=api_key, - base_url=self.ap.provider_cfg.data["dify-service-api"]["base-url"], + base_url=self.pipeline_config["ai"]["dify-service-api"]["base-url"], ) def _try_convert_thinking(self, resp_text: str) -> str: @@ -48,13 +48,13 @@ class DifyServiceAPIRunner(runner.RequestRunner): if not resp_text.startswith("
Thinking... "): return resp_text - if self.ap.provider_cfg.data["dify-service-api"]["options"]["convert-thinking-tips"] == "original": + if self.pipeline_config["ai"]["dify-service-api"]["thinking-convert"] == "original": return resp_text - if self.ap.provider_cfg.data["dify-service-api"]["options"]["convert-thinking-tips"] == "remove": + if self.pipeline_config["ai"]["dify-service-api"]["thinking-convert"] == "remove": return re.sub(r'
Thinking... .*?
', '', resp_text, flags=re.DOTALL) - if self.ap.provider_cfg.data["dify-service-api"]["options"]["convert-thinking-tips"] == "plain": + if self.pipeline_config["ai"]["dify-service-api"]["thinking-convert"] == "plain": pattern = r'
Thinking... (.*?)
' thinking_text = re.search(pattern, resp_text, flags=re.DOTALL) content_text = re.sub(pattern, '', resp_text, flags=re.DOTALL) @@ -121,7 +121,7 @@ class DifyServiceAPIRunner(runner.RequestRunner): user=f"{query.session.launcher_type.value}_{query.session.launcher_id}", conversation_id=cov_id, files=files, - timeout=self.ap.provider_cfg.data["dify-service-api"]["chat"]["timeout"], + timeout=self.pipeline_config["ai"]["dify-service-api"]["timeout"], ): self.ap.logger.debug("dify-chat-chunk: " + str(chunk)) @@ -177,7 +177,7 @@ class DifyServiceAPIRunner(runner.RequestRunner): response_mode="streaming", conversation_id=cov_id, files=files, - timeout=self.ap.provider_cfg.data["dify-service-api"]["chat"]["timeout"], + timeout=self.pipeline_config["ai"]["dify-service-api"]["timeout"], ): self.ap.logger.debug("dify-agent-chunk: " + str(chunk)) @@ -264,7 +264,7 @@ class DifyServiceAPIRunner(runner.RequestRunner): inputs=inputs, user=f"{query.session.launcher_type.value}_{query.session.launcher_id}", files=files, - timeout=self.ap.provider_cfg.data["dify-service-api"]["workflow"]["timeout"], + timeout=self.pipeline_config["ai"]["dify-service-api"]["timeout"], ): self.ap.logger.debug("dify-workflow-chunk: " + str(chunk)) if chunk["event"] in ignored_events: @@ -301,11 +301,7 @@ class DifyServiceAPIRunner(runner.RequestRunner): msg = llm_entities.Message( role="assistant", - content=chunk["data"]["outputs"][ - self.ap.provider_cfg.data["dify-service-api"]["workflow"][ - "output-key" - ] - ], + content=chunk["data"]["outputs"]["summary"], ) yield msg @@ -314,16 +310,16 @@ class DifyServiceAPIRunner(runner.RequestRunner): self, query: core_entities.Query ) -> typing.AsyncGenerator[llm_entities.Message, None]: """运行请求""" - if self.ap.provider_cfg.data["dify-service-api"]["app-type"] == "chat": + if self.pipeline_config["ai"]["dify-service-api"]["app-type"] == "chat": async for msg in self._chat_messages(query): yield msg - elif self.ap.provider_cfg.data["dify-service-api"]["app-type"] == "agent": + elif self.pipeline_config["ai"]["dify-service-api"]["app-type"] == "agent": async for msg in self._agent_chat_messages(query): yield msg - elif self.ap.provider_cfg.data["dify-service-api"]["app-type"] == "workflow": + elif self.pipeline_config["ai"]["dify-service-api"]["app-type"] == "workflow": async for msg in self._workflow_messages(query): yield msg else: raise errors.DifyAPIError( - f"不支持的 Dify 应用类型: {self.ap.provider_cfg.data['dify-service-api']['app-type']}" + f"不支持的 Dify 应用类型: {self.pipeline_config['ai']['dify-service-api']['app-type']}" ) diff --git a/pkg/provider/runners/localagent.py b/pkg/provider/runners/localagent.py index f05c82e3..68bb2b4f 100644 --- a/pkg/provider/runners/localagent.py +++ b/pkg/provider/runners/localagent.py @@ -16,14 +16,12 @@ class LocalAgentRunner(runner.RequestRunner): async def run(self, query: core_entities.Query) -> typing.AsyncGenerator[llm_entities.Message, None]: """运行请求 """ - await query.use_model.requester.preprocess(query) - pending_tool_calls = [] req_messages = query.prompt.messages.copy() + query.messages.copy() + [query.user_message] # 首次请求 - msg = await query.use_model.requester.call(query, query.use_model, req_messages, query.use_funcs) + msg = await query.use_llm_model.requester.invoke_llm(query, query.use_llm_model, req_messages, query.use_funcs) yield msg @@ -61,7 +59,7 @@ class LocalAgentRunner(runner.RequestRunner): req_messages.append(err_msg) # 处理完所有调用,再次请求 - msg = await query.use_model.requester.call(query, query.use_model, req_messages, query.use_funcs) + msg = await query.use_llm_model.requester.invoke_llm(query, query.use_llm_model, req_messages, query.use_funcs) yield msg diff --git a/pkg/provider/session/sessionmgr.py b/pkg/provider/session/sessionmgr.py index 00523472..83691e4c 100644 --- a/pkg/provider/session/sessionmgr.py +++ b/pkg/provider/session/sessionmgr.py @@ -41,7 +41,7 @@ class SessionManager: self.session_list.append(session) return session - async def get_conversation(self, session: core_entities.Session) -> core_entities.Conversation: + async def get_conversation(self, query: core_entities.Query, session: core_entities.Session) -> core_entities.Conversation: """获取对话或创建对话""" if not session.conversations: @@ -51,7 +51,9 @@ class SessionManager: conversation = core_entities.Conversation( prompt=await self.ap.prompt_mgr.get_prompt(session.use_prompt_name), messages=[], - use_model=await self.ap.model_mgr.get_model_by_name(self.ap.provider_cfg.data['model']), + use_llm_model=await self.ap.model_mgr.get_model_by_uuid( + query.pipeline_config['ai']['local-agent']['model'] + ), use_funcs=await self.ap.tool_mgr.get_all_functions( plugin_enabled=True, ),