diff --git a/pkg/api/http/service/pipeline.py b/pkg/api/http/service/pipeline.py index 7920c4c9..72b7daf7 100644 --- a/pkg/api/http/service/pipeline.py +++ b/pkg/api/http/service/pipeline.py @@ -5,10 +5,25 @@ import datetime import sqlalchemy from ....core import app -from ....pipeline import stagemgr from ....entity.persistence import pipeline as persistence_pipeline +default_stage_order = [ + "GroupRespondRuleCheckStage", # 群响应规则检查 + "BanSessionCheckStage", # 封禁会话检查 + "PreContentFilterStage", # 内容过滤前置阶段 + "PreProcessor", # 预处理器 + "ConversationMessageTruncator", # 会话消息截断器 + "RequireRateLimitOccupancy", # 请求速率限制占用 + "MessageProcessor", # 处理器 + "ReleaseRateLimitOccupancy", # 释放速率限制占用 + "PostContentFilterStage", # 内容过滤后置阶段 + "ResponseWrapper", # 响应包装器 + "LongTextProcessStage", # 长文本处理 + "SendResponseBackStage", # 发送响应 +] + + class PipelineService: ap: app.Application @@ -49,7 +64,7 @@ class PipelineService: async def create_pipeline(self, pipeline_data: dict) -> str: pipeline_data['uuid'] = str(uuid.uuid4()) pipeline_data['for_version'] = self.ap.ver_mgr.get_current_version() - pipeline_data['stages'] = stagemgr.stage_order.copy() + pipeline_data['stages'] = default_stage_order.copy() # TODO: 检查pipeline config是否完整 @@ -64,9 +79,12 @@ class PipelineService: return pipeline_data['uuid'] async def update_pipeline(self, pipeline_uuid: str, pipeline_data: dict) -> None: - del pipeline_data['uuid'] - del pipeline_data['for_version'] - del pipeline_data['stages'] + if 'uuid' in pipeline_data: + del pipeline_data['uuid'] + if 'for_version' in pipeline_data: + del pipeline_data['for_version'] + if 'stages' in pipeline_data: + del pipeline_data['stages'] await self.ap.persistence_mgr.execute_async( sqlalchemy.update(persistence_pipeline.LegacyPipeline).where(persistence_pipeline.LegacyPipeline.uuid == pipeline_uuid).values(**pipeline_data) ) diff --git a/pkg/command/cmdmgr.py b/pkg/command/cmdmgr.py index 8d442fdb..ea4e1a9b 100644 --- a/pkg/command/cmdmgr.py +++ b/pkg/command/cmdmgr.py @@ -8,7 +8,7 @@ from . import entities, operator, errors from ..config import manager as cfg_mgr # 引入所有算子以便注册 -from .operators import func, plugin, default, reset, list as list_cmd, last, next, delc, resend, prompt, cmd, help, version, update, ollama, model +from .operators import func, plugin, reset, list as list_cmd, last, next, delc, resend, prompt, cmd, help, version, update, ollama, model class CommandManager: diff --git a/pkg/command/operators/default.py b/pkg/command/operators/default.py deleted file mode 100644 index ee46c7d0..00000000 --- a/pkg/command/operators/default.py +++ /dev/null @@ -1,62 +0,0 @@ -from __future__ import annotations - -import typing -import traceback - -from .. import operator, entities, cmdmgr, errors - - -@operator.operator_class( - name="default", - help="操作情景预设", - usage='!default\n!default set <指定情景预设为默认>' -) -class DefaultOperator(operator.CommandOperator): - - async def execute( - self, - context: entities.ExecuteContext - ) -> typing.AsyncGenerator[entities.CommandReturn, None]: - - reply_str = "当前所有情景预设: \n\n" - - for prompt in self.ap.prompt_mgr.get_all_prompts(): - - content = "" - for msg in prompt.messages: - content += f" {msg.readable_str()}\n" - - reply_str += f"名称: {prompt.name}\n内容: \n{content}\n\n" - - reply_str += f"当前会话使用的是: {context.session.use_prompt_name}" - - yield entities.CommandReturn(text=reply_str.strip()) - - -@operator.operator_class( - name="set", - help="设置当前会话默认情景预设", - parent_class=DefaultOperator -) -class DefaultSetOperator(operator.CommandOperator): - - async def execute( - self, - context: entities.ExecuteContext - ) -> typing.AsyncGenerator[entities.CommandReturn, None]: - - if len(context.crt_params) == 0: - yield entities.CommandReturn(error=errors.ParamNotEnoughError('请提供情景预设名称')) - else: - prompt_name = context.crt_params[0] - - try: - prompt = await self.ap.prompt_mgr.get_prompt_by_prefix(prompt_name) - if prompt is None: - yield entities.CommandReturn(error=errors.CommandError("设置当前会话默认情景预设失败: 未找到情景预设 {}".format(prompt_name))) - else: - context.session.use_prompt_name = prompt.name - yield entities.CommandReturn(text=f"已设置当前会话默认情景预设为 {prompt_name}, !reset 后生效") - except Exception as e: - traceback.print_exc() - yield entities.CommandReturn(error=errors.CommandError("设置当前会话默认情景预设失败: "+str(e))) diff --git a/pkg/core/app.py b/pkg/core/app.py index 0191cc02..bfea617b 100644 --- a/pkg/core/app.py +++ b/pkg/core/app.py @@ -11,16 +11,14 @@ import os from ..platform import manager as im_mgr from ..provider.session import sessionmgr as llm_session_mgr from ..provider.modelmgr import modelmgr as llm_model_mgr -from ..provider.sysprompt import sysprompt as llm_prompt_mgr from ..provider.tools import toolmgr as llm_tool_mgr -from ..provider import runnermgr from ..config import manager as config_mgr from ..config import settings as settings_mgr from ..audit.center import v2 as center_mgr from ..command import cmdmgr from ..plugin import manager as plugin_mgr from ..pipeline import pool -from ..pipeline import controller, stagemgr, pipelinemgr +from ..pipeline import controller, pipelinemgr from ..utils import version as version_mgr, proxy as proxy_mgr, announce as announce_mgr from ..persistence import mgr as persistencemgr from ..api.http.controller import main as http_controller @@ -53,12 +51,9 @@ class Application: model_mgr: llm_model_mgr.ModelManager = None - prompt_mgr: llm_prompt_mgr.PromptManager = None - + # TODO 移动到 pipeline 里 tool_mgr: llm_tool_mgr.ToolManager = None - runner_mgr: runnermgr.RunnerManager = None - settings_mgr: settings_mgr.SettingsManager = None # ======= 配置管理器 ======= @@ -100,8 +95,6 @@ class Application: ctrl: controller.Controller = None - stage_mgr: stagemgr.StageManager = None - pipeline_mgr: pipelinemgr.PipelineManager = None ver_mgr: version_mgr.VersionManager = None @@ -232,16 +225,8 @@ class Application: await llm_session_mgr_inst.initialize() self.sess_mgr = llm_session_mgr_inst - llm_prompt_mgr_inst = llm_prompt_mgr.PromptManager(self) - await llm_prompt_mgr_inst.initialize() - self.prompt_mgr = llm_prompt_mgr_inst - llm_tool_mgr_inst = llm_tool_mgr.ToolManager(self) await llm_tool_mgr_inst.initialize() self.tool_mgr = llm_tool_mgr_inst - - runner_mgr_inst = runnermgr.RunnerManager(self) - await runner_mgr_inst.initialize() - self.runner_mgr = runner_mgr_inst case _: pass \ No newline at end of file diff --git a/pkg/core/entities.py b/pkg/core/entities.py index 71ec995b..1753495b 100644 --- a/pkg/core/entities.py +++ b/pkg/core/entities.py @@ -8,8 +8,7 @@ import asyncio import pydantic.v1 as pydantic from ..provider import entities as llm_entities -from ..provider.modelmgr import entities -from ..provider.sysprompt import entities as sysprompt_entities +from ..provider.modelmgr import entities, modelmgr, requester from ..provider.tools import entities as tools_entities from ..platform import adapter as msadapter from ..platform.types import message as platform_message @@ -57,6 +56,15 @@ class Query(pydantic.BaseModel): message_chain: platform_message.MessageChain """消息链,platform收到的原始消息链""" + bot_uuid: typing.Optional[str] = None + """机器人UUID。""" + + pipeline_uuid: typing.Optional[str] = None + """流水线UUID。""" + + pipeline_config: typing.Optional[dict[str, typing.Any]] = None + """流水线配置,由 Pipeline 在运行开始时设置。""" + adapter: msadapter.MessagePlatformAdapter """消息平台适配器对象,单个app中可能启用了多个消息平台适配器,此对象表明发起此query的适配器""" @@ -66,7 +74,7 @@ class Query(pydantic.BaseModel): messages: typing.Optional[list[llm_entities.Message]] = [] """历史消息列表,由前置处理器阶段设置""" - prompt: typing.Optional[sysprompt_entities.Prompt] = None + prompt: typing.Optional[llm_entities.Prompt] = None """情景预设内容,由前置处理器阶段设置""" user_message: typing.Optional[llm_entities.Message] = None @@ -75,8 +83,8 @@ class Query(pydantic.BaseModel): variables: typing.Optional[dict[str, typing.Any]] = None """变量,由前置处理器阶段设置。在prompt中嵌入或由 Runner 传递到 LLMOps 平台。""" - use_model: typing.Optional[entities.LLMModelInfo] = None - """使用的模型,由前置处理器阶段设置""" + use_llm_model: typing.Optional[requester.RuntimeLLMModel] = None + """使用的对话模型,由前置处理器阶段设置""" use_funcs: typing.Optional[list[tools_entities.LLMFunction]] = None """使用的函数,由前置处理器阶段设置""" @@ -88,7 +96,7 @@ class Query(pydantic.BaseModel): """回复消息链,从resp_messages包装而得""" # ======= 内部保留 ======= - current_stage: "pkg.pipeline.stagemgr.StageInstContainer" = None + current_stage: "pkg.pipeline.pipelinemgr.StageInstContainer" = None """当前所处阶段""" class Config: @@ -118,7 +126,7 @@ class Query(pydantic.BaseModel): class Conversation(pydantic.BaseModel): """对话,包含于 Session 中,一个 Session 可以有多个历史 Conversation,但只有一个当前使用的 Conversation""" - prompt: sysprompt_entities.Prompt + prompt: llm_entities.Prompt messages: list[llm_entities.Message] @@ -126,13 +134,16 @@ class Conversation(pydantic.BaseModel): update_time: typing.Optional[datetime.datetime] = pydantic.Field(default_factory=datetime.datetime.now) - use_model: entities.LLMModelInfo + use_llm_model: requester.RuntimeLLMModel use_funcs: typing.Optional[list[tools_entities.LLMFunction]] uuid: typing.Optional[str] = None """该对话的 uuid,在创建时不会自动生成。而是当使用 Dify API 等由外部管理对话信息的服务时,用于绑定外部的会话。具体如何使用,取决于 Runner。""" + class Config: + arbitrary_types_allowed = True + class Session(pydantic.BaseModel): """会话,一个 Session 对应一个 {launcher_type.value}_{launcher_id}""" diff --git a/pkg/core/stages/build_app.py b/pkg/core/stages/build_app.py index 0bd0d8a5..fc049d9c 100644 --- a/pkg/core/stages/build_app.py +++ b/pkg/core/stages/build_app.py @@ -6,14 +6,12 @@ from .. import stage, app from ...utils import version, proxy, announce, platform from ...audit.center import v2 as center_v2 from ...audit import identifier -from ...pipeline import pool, controller, stagemgr, pipelinemgr +from ...pipeline import pool, controller, pipelinemgr from ...plugin import manager as plugin_mgr from ...command import cmdmgr from ...provider.session import sessionmgr as llm_session_mgr from ...provider.modelmgr import modelmgr as llm_model_mgr -from ...provider.sysprompt import sysprompt as llm_prompt_mgr from ...provider.tools import toolmgr as llm_tool_mgr -from ...provider import runnermgr from ...platform import manager as im_mgr from ...persistence import mgr as persistencemgr from ...api.http.controller import main as http_controller @@ -61,10 +59,7 @@ class BuildAppStage(stage.BootingStage): }, runtime_info={ "admin_id": "{}".format(ap.system_cfg.data["admin-sessions"]), - "msg_source": str([ - adapter_cfg['adapter'] if 'adapter' in adapter_cfg else 'unknown' - for adapter_cfg in ap.platform_cfg.data['platform-adapters'] if adapter_cfg['enable'] - ]), + "msg_source": str([]), }, ) ap.ctr_mgr = center_v2_api @@ -99,26 +94,14 @@ class BuildAppStage(stage.BootingStage): await llm_session_mgr_inst.initialize() ap.sess_mgr = llm_session_mgr_inst - llm_prompt_mgr_inst = llm_prompt_mgr.PromptManager(ap) - await llm_prompt_mgr_inst.initialize() - ap.prompt_mgr = llm_prompt_mgr_inst - llm_tool_mgr_inst = llm_tool_mgr.ToolManager(ap) await llm_tool_mgr_inst.initialize() ap.tool_mgr = llm_tool_mgr_inst - runner_mgr_inst = runnermgr.RunnerManager(ap) - await runner_mgr_inst.initialize() - ap.runner_mgr = runner_mgr_inst - im_mgr_inst = im_mgr.PlatformManager(ap=ap) await im_mgr_inst.initialize() ap.platform_mgr = im_mgr_inst - stage_mgr = stagemgr.StageManager(ap) - await stage_mgr.initialize() - ap.stage_mgr = stage_mgr - pipeline_mgr = pipelinemgr.PipelineManager(ap) await pipeline_mgr.initialize() ap.pipeline_mgr = pipeline_mgr diff --git a/pkg/core/taskmgr.py b/pkg/core/taskmgr.py index 2c029c03..d5887019 100644 --- a/pkg/core/taskmgr.py +++ b/pkg/core/taskmgr.py @@ -233,3 +233,10 @@ class AsyncTaskManager: if not wrapper.task.done() and scope in wrapper.scopes: wrapper.task.cancel() + + def cancel_task(self, task_id: int): + for wrapper in self.tasks: + if wrapper.id == task_id: + if not wrapper.task.done(): + wrapper.task.cancel() + return diff --git a/pkg/pipeline/bansess/bansess.py b/pkg/pipeline/bansess/bansess.py index 9c041385..38fb9794 100644 --- a/pkg/pipeline/bansess/bansess.py +++ b/pkg/pipeline/bansess/bansess.py @@ -1,7 +1,7 @@ from __future__ import annotations import re -from .. import stage, entities, stagemgr +from .. import stage, entities from ...core import entities as core_entities from ...config import manager as cfg_mgr @@ -13,7 +13,7 @@ class BanSessionCheckStage(stage.PipelineStage): 仅检查query中群号或个人号是否在访问控制列表中。 """ - async def initialize(self): + async def initialize(self, pipeline_config: dict): pass async def process( @@ -24,9 +24,9 @@ class BanSessionCheckStage(stage.PipelineStage): found = False - mode = self.ap.pipeline_cfg.data['access-control']['mode'] + mode = query.pipeline_config['trigger']['access-control']['mode'] - sess_list = self.ap.pipeline_cfg.data['access-control'][mode] + sess_list = query.pipeline_config['trigger']['access-control'][mode] if (query.launcher_type.value == 'group' and 'group_*' in sess_list) \ or (query.launcher_type.value == 'person' and 'person_*' in sess_list): diff --git a/pkg/pipeline/cntfilter/cntfilter.py b/pkg/pipeline/cntfilter/cntfilter.py index f7376b61..dbf7c52e 100644 --- a/pkg/pipeline/cntfilter/cntfilter.py +++ b/pkg/pipeline/cntfilter/cntfilter.py @@ -2,7 +2,7 @@ from __future__ import annotations from ...core import app -from .. import stage, entities, stagemgr +from .. import stage, entities from ...core import entities as core_entities from ...config import manager as cfg_mgr from . import filter as filter_model, entities as filter_entities @@ -35,17 +35,18 @@ class ContentFilterStage(stage.PipelineStage): self.filter_chain = [] super().__init__(ap) - async def initialize(self): + async def initialize(self, pipeline_config: dict): filters_required = [ "content-ignore", ] - if self.ap.pipeline_cfg.data['check-sensitive-words']: + if pipeline_config['safety']['content-filter']['check-sensitive-words']: filters_required.append("ban-word-filter") - if self.ap.pipeline_cfg.data['baidu-cloud-examine']['enable']: - filters_required.append("baidu-cloud-examine") + # TODO revert it + # if self.ap.pipeline_cfg.data['baidu-cloud-examine']['enable']: + # filters_required.append("baidu-cloud-examine") for filter in filter_model.preregistered_filters: if filter.name in filters_required: @@ -65,7 +66,7 @@ class ContentFilterStage(stage.PipelineStage): 只要有一个不通过就不放行,只放行 PASS 的消息 """ - if not self.ap.pipeline_cfg.data['income-msg-check']: + if query.pipeline_config['safety']['content-filter']['scope'] == 'output-msg': return entities.StageProcessResult( result_type=entities.ResultType.CONTINUE, new_query=query @@ -73,7 +74,7 @@ class ContentFilterStage(stage.PipelineStage): else: for filter in self.filter_chain: if filter_entities.EnableStage.PRE in filter.enable_stages: - result = await filter.process(message) + result = await filter.process(query, message) if result.level in [ filter_entities.ResultLevel.BLOCK, @@ -105,7 +106,7 @@ class ContentFilterStage(stage.PipelineStage): """请求llm后处理响应 只要是 PASS 或者 MASKED 的就通过此 filter,将其 replacement 设置为message,进入下一个 filter """ - if message is None: + if query.pipeline_config['safety']['content-filter']['scope'] == 'income-msg': return entities.StageProcessResult( result_type=entities.ResultType.CONTINUE, new_query=query @@ -114,7 +115,7 @@ class ContentFilterStage(stage.PipelineStage): message = message.strip() for filter in self.filter_chain: if filter_entities.EnableStage.POST in filter.enable_stages: - result = await filter.process(message) + result = await filter.process(query, message) if result.level == filter_entities.ResultLevel.BLOCK: return entities.StageProcessResult( diff --git a/pkg/pipeline/cntfilter/filter.py b/pkg/pipeline/cntfilter/filter.py index 8eceb877..970e11f1 100644 --- a/pkg/pipeline/cntfilter/filter.py +++ b/pkg/pipeline/cntfilter/filter.py @@ -3,7 +3,7 @@ from __future__ import annotations import abc import typing -from ...core import app +from ...core import app, entities as core_entities from . import entities from ...provider import entities as llm_entities @@ -64,7 +64,7 @@ class ContentFilter(metaclass=abc.ABCMeta): pass @abc.abstractmethod - async def process(self, message: str=None, image_url=None) -> entities.FilterResult: + async def process(self, query: core_entities.Query, message: str=None, image_url=None) -> entities.FilterResult: """处理消息 分为前后阶段,具体取决于 enable_stages 的值。 diff --git a/pkg/pipeline/cntfilter/filters/baiduexamine.py b/pkg/pipeline/cntfilter/filters/baiduexamine.py index 8c5b77cd..800f0099 100644 --- a/pkg/pipeline/cntfilter/filters/baiduexamine.py +++ b/pkg/pipeline/cntfilter/filters/baiduexamine.py @@ -4,6 +4,7 @@ import aiohttp from .. import entities from .. import filter as filter_model +from ....core import entities as core_entities BAIDU_EXAMINE_URL = "https://aip.baidubce.com/rest/2.0/solution/v1/text_censor/v2/user_defined?access_token={}" @@ -26,7 +27,7 @@ class BaiduCloudExamine(filter_model.ContentFilter): ) as resp: return (await resp.json())['access_token'] - async def process(self, message: str) -> entities.FilterResult: + async def process(self, query: core_entities.Query, message: str) -> entities.FilterResult: async with aiohttp.ClientSession() as session: async with session.post( diff --git a/pkg/pipeline/cntfilter/filters/banwords.py b/pkg/pipeline/cntfilter/filters/banwords.py index 1430c2ed..cd3d412c 100644 --- a/pkg/pipeline/cntfilter/filters/banwords.py +++ b/pkg/pipeline/cntfilter/filters/banwords.py @@ -3,7 +3,7 @@ import re from .. import filter as filter_model from .. import entities -from ....config import manager as cfg_mgr +from ....core import entities as core_entities @filter_model.filter_class("ban-word-filter") @@ -13,7 +13,7 @@ class BanWordFilter(filter_model.ContentFilter): async def initialize(self): pass - async def process(self, message: str) -> entities.FilterResult: + async def process(self, query: core_entities.Query, message: str) -> entities.FilterResult: found = False for word in self.ap.sensitive_meta.data['words']: diff --git a/pkg/pipeline/cntfilter/filters/cntignore.py b/pkg/pipeline/cntfilter/filters/cntignore.py index 781f6397..381d5c51 100644 --- a/pkg/pipeline/cntfilter/filters/cntignore.py +++ b/pkg/pipeline/cntfilter/filters/cntignore.py @@ -3,6 +3,7 @@ import re from .. import entities from .. import filter as filter_model +from ....core import entities as core_entities @filter_model.filter_class("content-ignore") @@ -15,9 +16,9 @@ class ContentIgnore(filter_model.ContentFilter): entities.EnableStage.PRE, ] - async def process(self, message: str) -> entities.FilterResult: - if 'prefix' in self.ap.pipeline_cfg.data['ignore-rules']: - for rule in self.ap.pipeline_cfg.data['ignore-rules']['prefix']: + async def process(self, query: core_entities.Query, message: str) -> entities.FilterResult: + if 'prefix' in query.pipeline_config['trigger']['ignore-rules']: + for rule in query.pipeline_config['trigger']['ignore-rules']['prefix']: if message.startswith(rule): return entities.FilterResult( level=entities.ResultLevel.BLOCK, @@ -26,8 +27,8 @@ class ContentIgnore(filter_model.ContentFilter): console_notice='根据 ignore_rules 中的 prefix 规则,忽略消息' ) - if 'regexp' in self.ap.pipeline_cfg.data['ignore-rules']: - for rule in self.ap.pipeline_cfg.data['ignore-rules']['regexp']: + if 'regexp' in query.pipeline_config['trigger']['ignore-rules']: + for rule in query.pipeline_config['trigger']['ignore-rules']['regexp']: if re.search(rule, message): return entities.FilterResult( level=entities.ResultLevel.BLOCK, diff --git a/pkg/pipeline/controller.py b/pkg/pipeline/controller.py index 807bec05..64d4e8f4 100644 --- a/pkg/pipeline/controller.py +++ b/pkg/pipeline/controller.py @@ -50,14 +50,23 @@ class Controller: continue if selected_query: - async def _process_query(selected_query): + + async def _process_query(selected_query: entities.Query): async with self.semaphore: # 总并发上限 - await self.process_query(selected_query) + # find pipeline + # Here firstly find the bot, then find the pipeline, in case the bot adapter's config is not the latest one. + # Like aiocqhttp, once a client is connected, even the adapter was updated and restarted, the existing client connection will not be affected. + bot = await self.ap.platform_mgr.get_bot_by_uuid(selected_query.bot_uuid) + if bot: + pipeline = await self.ap.pipeline_mgr.get_pipeline_by_uuid(bot.bot_entity.use_pipeline_uuid) + if pipeline: + await pipeline.run(selected_query) async with self.ap.query_pool: (await self.ap.sess_mgr.get_session(selected_query)).semaphore.release() # 通知其他协程,有新的请求可以处理了 self.ap.query_pool.condition.notify_all() + self.ap.task_mgr.create_task( _process_query(selected_query), kind="query", @@ -70,127 +79,6 @@ class Controller: self.ap.logger.error(f"控制器循环出错: {e}") self.ap.logger.error(f"Traceback: {traceback.format_exc()}") - async def _check_output(self, query: entities.Query, result: pipeline_entities.StageProcessResult): - """检查输出 - """ - if result.user_notice: - # 处理str类型 - - if isinstance(result.user_notice, str): - result.user_notice = platform_message.MessageChain( - platform_message.Plain(result.user_notice) - ) - elif isinstance(result.user_notice, list): - result.user_notice = platform_message.MessageChain( - *result.user_notice - ) - - await self.ap.platform_mgr.send( - query.message_event, - result.user_notice, - query.adapter - ) - if result.debug_notice: - self.ap.logger.debug(result.debug_notice) - if result.console_notice: - self.ap.logger.info(result.console_notice) - if result.error_notice: - self.ap.logger.error(result.error_notice) - - async def _execute_from_stage( - self, - stage_index: int, - query: entities.Query, - ): - """从指定阶段开始执行,实现了责任链模式和基于生成器的阶段分叉功能。 - - 如何看懂这里为什么这么写? - 去问 GPT-4: - Q1: 现在有一个责任链,其中有多个stage,query对象在其中传递,stage.process可能返回Result也有可能返回typing.AsyncGenerator[Result, None], - 如果返回的是生成器,需要挨个生成result,检查是否result中是否要求继续,如果要求继续就进行下一个stage。如果此次生成器产生的result处理完了,就继续生成下一个result, - 调用后续的stage,直到该生成器全部生成完。责任链中可能有多个stage会返回生成器 - Q2: 不是这样的,你可能理解有误。如果我们责任链上有这些Stage: - - A B C D E F G - - 如果所有的stage都返回Result,且所有Result都要求继续,那么执行顺序是: - - A B C D E F G - - 现在假设C返回的是AsyncGenerator,那么执行顺序是: - - A B C D E F G C D E F G C D E F G ... - Q3: 但是如果不止一个stage会返回生成器呢? - """ - i = stage_index - - while i < len(self.ap.stage_mgr.stage_containers): - stage_container = self.ap.stage_mgr.stage_containers[i] - - query.current_stage = stage_container # 标记到 Query 对象里 - - result = stage_container.inst.process(query, stage_container.inst_name) - - if isinstance(result, typing.Coroutine): - result = await result - - if isinstance(result, pipeline_entities.StageProcessResult): # 直接返回结果 - self.ap.logger.debug(f"Stage {stage_container.inst_name} processed query {query} res {result}") - await self._check_output(query, result) - - if result.result_type == pipeline_entities.ResultType.INTERRUPT: - self.ap.logger.debug(f"Stage {stage_container.inst_name} interrupted query {query}") - break - elif result.result_type == pipeline_entities.ResultType.CONTINUE: - query = result.new_query - elif isinstance(result, typing.AsyncGenerator): # 生成器 - self.ap.logger.debug(f"Stage {stage_container.inst_name} processed query {query} gen") - - async for sub_result in result: - self.ap.logger.debug(f"Stage {stage_container.inst_name} processed query {query} res {sub_result}") - await self._check_output(query, sub_result) - - if sub_result.result_type == pipeline_entities.ResultType.INTERRUPT: - self.ap.logger.debug(f"Stage {stage_container.inst_name} interrupted query {query}") - break - elif sub_result.result_type == pipeline_entities.ResultType.CONTINUE: - query = sub_result.new_query - await self._execute_from_stage(i + 1, query) - break - - i += 1 - - async def process_query(self, query: entities.Query): - """处理请求 - """ - try: - - # ======== 触发 MessageReceived 事件 ======== - event_type = events.PersonMessageReceived if query.launcher_type == entities.LauncherTypes.PERSON else events.GroupMessageReceived - - event_ctx = await self.ap.plugin_mgr.emit_event( - event=event_type( - launcher_type=query.launcher_type.value, - launcher_id=query.launcher_id, - sender_id=query.sender_id, - message_chain=query.message_chain, - query=query - ) - ) - - if event_ctx.is_prevented_default(): - return - - self.ap.logger.debug(f"Processing query {query}") - - await self._execute_from_stage(0, query) - except Exception as e: - inst_name = query.current_stage.inst_name if query.current_stage else 'unknown' - self.ap.logger.error(f"处理请求时出错 query_id={query.query_id} stage={inst_name} : {e}") - self.ap.logger.debug(f"Traceback: {traceback.format_exc()}") - finally: - self.ap.logger.debug(f"Query {query} processed") - async def run(self): """运行控制器 """ diff --git a/pkg/pipeline/longtext/longtext.py b/pkg/pipeline/longtext/longtext.py index ecb745d0..ac03ad42 100644 --- a/pkg/pipeline/longtext/longtext.py +++ b/pkg/pipeline/longtext/longtext.py @@ -7,7 +7,7 @@ from PIL import Image, ImageDraw, ImageFont from ...core import app from . import strategy from .strategies import image, forward -from .. import stage, entities, stagemgr +from .. import stage, entities from ...core import entities as core_entities from ...config import manager as cfg_mgr from ...platform.types import message as platform_message @@ -23,8 +23,8 @@ class LongTextProcessStage(stage.PipelineStage): strategy_impl: strategy.LongTextStrategy - async def initialize(self): - config = self.ap.platform_cfg.data['long-text-process'] + async def initialize(self, pipeline_config: dict): + config = pipeline_config['output']['long-text-processing'] if config['strategy'] == 'image': use_font = config['font-path'] try: @@ -42,12 +42,12 @@ class LongTextProcessStage(stage.PipelineStage): else: self.ap.logger.warn("未找到字体文件,且无法使用系统自带字体,更换为转发消息组件以发送长消息,您可以在配置文件中调整相关设置。") - self.ap.platform_cfg.data['long-text-process']['strategy'] = "forward" + pipeline_config['output']['long-text-processing']['strategy'] = "forward" except: traceback.print_exc() self.ap.logger.error("加载字体文件失败({}),更换为转发消息组件以发送长消息,您可以在配置文件中调整相关设置。".format(use_font)) - self.ap.platform_cfg.data['long-text-process']['strategy'] = "forward" + pipeline_config['output']['long-text-processing']['strategy'] = "forward" for strategy_cls in strategy.preregistered_strategies: if strategy_cls.name == config['strategy']: @@ -69,7 +69,7 @@ class LongTextProcessStage(stage.PipelineStage): if contains_non_plain: self.ap.logger.debug("消息中包含非 Plain 组件,跳过长消息处理。") - elif len(str(query.resp_message_chain[-1])) > self.ap.platform_cfg.data['long-text-process']['threshold']: + elif len(str(query.resp_message_chain[-1])) > query.pipeline_config['output']['long-text-processing']['threshold']: query.resp_message_chain[-1] = platform_message.MessageChain(await self.strategy_impl.process(str(query.resp_message_chain[-1]), query)) return entities.StageProcessResult( diff --git a/pkg/pipeline/longtext/strategies/image.py b/pkg/pipeline/longtext/strategies/image.py index b9675074..b30d3a81 100644 --- a/pkg/pipeline/longtext/strategies/image.py +++ b/pkg/pipeline/longtext/strategies/image.py @@ -8,6 +8,7 @@ import re from PIL import Image, ImageDraw, ImageFont +import functools from ....platform.types import message as platform_message from .. import strategy as strategy_model @@ -17,15 +18,18 @@ from ....core import entities as core_entities @strategy_model.strategy_class("image") class Text2ImageStrategy(strategy_model.LongTextStrategy): - text_render_font: ImageFont.FreeTypeFont - async def initialize(self): - self.text_render_font = ImageFont.truetype(self.ap.platform_cfg.data['long-text-process']['font-path'], 32, encoding="utf-8") + pass + + @functools.lru_cache(maxsize=16) + def get_font(self, query: core_entities.Query): + return ImageFont.truetype(query.pipeline_config['output']['long-text-processing']['font-path'], 32, encoding="utf-8") async def process(self, message: str, query: core_entities.Query) -> list[platform_message.MessageComponent]: img_path = self.text_to_image( text_str=message, - save_as='temp/{}.png'.format(int(time.time())) + save_as='temp/{}.png'.format(int(time.time())), + query=query ) compressed_path, size = self.compress_image( @@ -127,7 +131,7 @@ class Text2ImageStrategy(strategy_model.LongTextStrategy): return outfile, self.get_size(outfile) - def text_to_image(self, text_str: str, save_as="temp.png", width=800): + def text_to_image(self, text_str: str, save_as="temp.png", width=800, query: core_entities.Query = None): text_str = text_str.replace("\t", " ") @@ -142,7 +146,7 @@ class Text2ImageStrategy(strategy_model.LongTextStrategy): self.ap.logger.debug("lines: {}, text_width: {}".format(lines, text_width)) for line in lines: # 如果长了就分割 - line_width = self.text_render_font.getlength(line) + line_width = self.get_font(query).getlength(line) self.ap.logger.debug("line_width: {}".format(line_width)) if line_width < text_width: final_lines.append(line) diff --git a/pkg/pipeline/msgtrun/msgtrun.py b/pkg/pipeline/msgtrun/msgtrun.py index e56c551f..b3fb593a 100644 --- a/pkg/pipeline/msgtrun/msgtrun.py +++ b/pkg/pipeline/msgtrun/msgtrun.py @@ -1,6 +1,6 @@ from __future__ import annotations -from .. import stage, entities, stagemgr +from .. import stage, entities from ...core import entities as core_entities from . import truncator from .truncators import round @@ -14,8 +14,8 @@ class ConversationMessageTruncator(stage.PipelineStage): """ trun: truncator.Truncator - async def initialize(self): - use_method = self.ap.pipeline_cfg.data['msg-truncate']['method'] + async def initialize(self, pipeline_config: dict): + use_method = "round" for trun in truncator.preregistered_truncators: if trun.name == use_method: diff --git a/pkg/pipeline/msgtrun/truncators/round.py b/pkg/pipeline/msgtrun/truncators/round.py index 646f2856..46fce5f3 100644 --- a/pkg/pipeline/msgtrun/truncators/round.py +++ b/pkg/pipeline/msgtrun/truncators/round.py @@ -12,7 +12,7 @@ class RoundTruncator(truncator.Truncator): async def truncate(self, query: core_entities.Query) -> core_entities.Query: """截断 """ - max_round = self.ap.pipeline_cfg.data['msg-truncate']['round']['max-round'] + max_round = query.pipeline_config['ai']['local-agent']['max-round'] temp_messages = [] diff --git a/pkg/pipeline/pipelinemgr.py b/pkg/pipeline/pipelinemgr.py index a805e5cd..b7eaaab4 100644 --- a/pkg/pipeline/pipelinemgr.py +++ b/pkg/pipeline/pipelinemgr.py @@ -1,12 +1,40 @@ from __future__ import annotations import typing +import traceback import sqlalchemy from ..core import app, entities +from . import entities as pipeline_entities from ..entity.persistence import pipeline as persistence_pipeline -from . import stagemgr, stage +from . import stage +from ..platform.types import message as platform_message, events as platform_events +from ..plugin import events + +from .resprule import resprule +from .bansess import bansess +from .cntfilter import cntfilter +from .process import process +from .longtext import longtext +from .respback import respback +from .wrapper import wrapper +from .preproc import preproc +from .ratelimit import ratelimit +from .msgtrun import msgtrun + + +class StageInstContainer(): + """阶段实例容器 + """ + + inst_name: str + + inst: stage.PipelineStage + + def __init__(self, inst_name: str, inst: stage.PipelineStage): + self.inst_name = inst_name + self.inst = inst class RuntimePipeline: @@ -17,16 +45,146 @@ class RuntimePipeline: pipeline_entity: persistence_pipeline.LegacyPipeline """流水线实体""" - stage_containers: list[stagemgr.StageInstContainer] + stage_containers: list[StageInstContainer] """阶段实例容器""" - def __init__(self, ap: app.Application, pipeline_entity: persistence_pipeline.LegacyPipeline, stage_containers: list[stagemgr.StageInstContainer]): + def __init__(self, ap: app.Application, pipeline_entity: persistence_pipeline.LegacyPipeline, stage_containers: list[StageInstContainer]): self.ap = ap self.pipeline_entity = pipeline_entity self.stage_containers = stage_containers - async def run(self): - pass + async def run(self, query: entities.Query): + query.pipeline_config = self.pipeline_entity.config + await self.process_query(query) + + async def _check_output(self, query: entities.Query, result: pipeline_entities.StageProcessResult): + """检查输出 + """ + if result.user_notice: + # 处理str类型 + + if isinstance(result.user_notice, str): + result.user_notice = platform_message.MessageChain( + platform_message.Plain(result.user_notice) + ) + elif isinstance(result.user_notice, list): + result.user_notice = platform_message.MessageChain( + *result.user_notice + ) + + if query.pipeline_config['output']['misc']['at-sender'] and isinstance(query.message_event, platform_events.GroupMessage): + result.user_notice.insert( + 0, + platform_message.At( + query.message_event.sender.id + ) + ) + + await query.adapter.reply_message( + message_source=query.message_event, + message=result.user_notice, + quote_origin=query.pipeline_config['output']['misc']['quote-origin'] + ) + if result.debug_notice: + self.ap.logger.debug(result.debug_notice) + if result.console_notice: + self.ap.logger.info(result.console_notice) + if result.error_notice: + self.ap.logger.error(result.error_notice) + + async def _execute_from_stage( + self, + stage_index: int, + query: entities.Query, + ): + """从指定阶段开始执行,实现了责任链模式和基于生成器的阶段分叉功能。 + + 如何看懂这里为什么这么写? + 去问 GPT-4: + Q1: 现在有一个责任链,其中有多个stage,query对象在其中传递,stage.process可能返回Result也有可能返回typing.AsyncGenerator[Result, None], + 如果返回的是生成器,需要挨个生成result,检查是否result中是否要求继续,如果要求继续就进行下一个stage。如果此次生成器产生的result处理完了,就继续生成下一个result, + 调用后续的stage,直到该生成器全部生成完。责任链中可能有多个stage会返回生成器 + Q2: 不是这样的,你可能理解有误。如果我们责任链上有这些Stage: + + A B C D E F G + + 如果所有的stage都返回Result,且所有Result都要求继续,那么执行顺序是: + + A B C D E F G + + 现在假设C返回的是AsyncGenerator,那么执行顺序是: + + A B C D E F G C D E F G C D E F G ... + Q3: 但是如果不止一个stage会返回生成器呢? + """ + i = stage_index + + while i < len(self.stage_containers): + stage_container = self.stage_containers[i] + + query.current_stage = stage_container # 标记到 Query 对象里 + + result = stage_container.inst.process(query, stage_container.inst_name) + + if isinstance(result, typing.Coroutine): + result = await result + + if isinstance(result, pipeline_entities.StageProcessResult): # 直接返回结果 + self.ap.logger.debug(f"Stage {stage_container.inst_name} processed query {query} res {result}") + await self._check_output(query, result) + + if result.result_type == pipeline_entities.ResultType.INTERRUPT: + self.ap.logger.debug(f"Stage {stage_container.inst_name} interrupted query {query}") + break + elif result.result_type == pipeline_entities.ResultType.CONTINUE: + query = result.new_query + elif isinstance(result, typing.AsyncGenerator): # 生成器 + self.ap.logger.debug(f"Stage {stage_container.inst_name} processed query {query} gen") + + async for sub_result in result: + self.ap.logger.debug(f"Stage {stage_container.inst_name} processed query {query} res {sub_result}") + await self._check_output(query, sub_result) + + if sub_result.result_type == pipeline_entities.ResultType.INTERRUPT: + self.ap.logger.debug(f"Stage {stage_container.inst_name} interrupted query {query}") + break + elif sub_result.result_type == pipeline_entities.ResultType.CONTINUE: + query = sub_result.new_query + await self._execute_from_stage(i + 1, query) + break + + i += 1 + + async def process_query(self, query: entities.Query): + """处理请求 + """ + try: + + # ======== 触发 MessageReceived 事件 ======== + event_type = events.PersonMessageReceived if query.launcher_type == entities.LauncherTypes.PERSON else events.GroupMessageReceived + + event_ctx = await self.ap.plugin_mgr.emit_event( + event=event_type( + launcher_type=query.launcher_type.value, + launcher_id=query.launcher_id, + sender_id=query.sender_id, + message_chain=query.message_chain, + query=query + ) + ) + + if event_ctx.is_prevented_default(): + return + + self.ap.logger.debug(f"Processing query {query}") + + await self._execute_from_stage(0, query) + except Exception as e: + inst_name = query.current_stage.inst_name if query.current_stage else 'unknown' + self.ap.logger.error(f"处理请求时出错 query_id={query.query_id} stage={inst_name} : {e}") + self.ap.logger.debug(f"Traceback: {traceback.format_exc()}") + finally: + self.ap.logger.debug(f"Query {query} processed") class PipelineManager: @@ -70,12 +228,15 @@ class PipelineManager: pipeline_entity = persistence_pipeline.LegacyPipeline(**pipeline_entity) # initialize stage containers according to pipeline_entity.stages - stage_containers = [] + stage_containers: list[StageInstContainer] = [] for stage_name in pipeline_entity.stages: - stage_containers.append(stagemgr.StageInstContainer( - stage_name=stage_name, - stage_class=self.stage_dict[stage_name] + stage_containers.append(StageInstContainer( + inst_name=stage_name, + inst=self.stage_dict[stage_name](self.ap) )) + + for stage_container in stage_containers: + await stage_container.inst.initialize(pipeline_entity.config) runtime_pipeline = RuntimePipeline(self.ap, pipeline_entity, stage_containers) self.pipelines.append(runtime_pipeline) diff --git a/pkg/pipeline/pool.py b/pkg/pipeline/pool.py index e358d249..df4d0741 100644 --- a/pkg/pipeline/pool.py +++ b/pkg/pipeline/pool.py @@ -28,15 +28,17 @@ class QueryPool: async def add_query( self, + bot_uuid: str, launcher_type: entities.LauncherTypes, launcher_id: typing.Union[int, str], sender_id: typing.Union[int, str], message_event: platform_events.MessageEvent, message_chain: platform_message.MessageChain, - adapter: msadapter.MessagePlatformAdapter + adapter: msadapter.MessagePlatformAdapter, ) -> entities.Query: async with self.condition: query = entities.Query( + bot_uuid=bot_uuid, query_id=self.query_id_counter, launcher_type=launcher_type, launcher_id=launcher_id, diff --git a/pkg/pipeline/preproc/preproc.py b/pkg/pipeline/preproc/preproc.py index 299aea5e..42bb5b4c 100644 --- a/pkg/pipeline/preproc/preproc.py +++ b/pkg/pipeline/preproc/preproc.py @@ -2,7 +2,7 @@ from __future__ import annotations import datetime -from .. import stage, entities, stagemgr +from .. import stage, entities from ...core import entities as core_entities from ...provider import entities as llm_entities from ...plugin import events @@ -33,16 +33,16 @@ class PreProcessor(stage.PipelineStage): """ session = await self.ap.sess_mgr.get_session(query) - conversation = await self.ap.sess_mgr.get_conversation(session) + conversation = await self.ap.sess_mgr.get_conversation(query, session, query.pipeline_config['ai']['local-agent']['prompt']) # 设置query query.session = session query.prompt = conversation.prompt.copy() query.messages = conversation.messages.copy() - query.use_model = conversation.use_model + query.use_llm_model = conversation.use_llm_model - query.use_funcs = conversation.use_funcs if query.use_model.tool_call_supported else None + query.use_funcs = conversation.use_funcs if query.use_llm_model.model_entity.abilities.__contains__('tool_call') else None query.variables = { "session_id": f"{query.session.launcher_type.value}_{query.session.launcher_id}", @@ -50,8 +50,9 @@ class PreProcessor(stage.PipelineStage): "msg_create_time": int(query.message_event.time) if query.message_event.time else int(datetime.datetime.now().timestamp()), } - # 检查vision是否启用,没启用就删除所有图片 - if not self.ap.provider_cfg.data['enable-vision'] or (self.ap.provider_cfg.data['runner'] == 'local-agent' and not query.use_model.vision_supported): + # Check if this model supports vision, if not, remove all images + # TODO this checking should be performed in runner, and in this stage, the image should be reserved + if query.pipeline_config['ai']['runner']['runner'] == 'local-agent' and not query.use_llm_model.model_entity.abilities.__contains__('vision'): for msg in query.messages: if isinstance(msg.content, list): for me in msg.content: @@ -69,7 +70,7 @@ class PreProcessor(stage.PipelineStage): ) plain_text += me.text elif isinstance(me, platform_message.Image): - if self.ap.provider_cfg.data['enable-vision'] and (self.ap.provider_cfg.data['runner'] != 'local-agent' or query.use_model.vision_supported): + if query.pipeline_config['ai']['runner']['runner'] != 'local-agent' or query.use_llm_model.model_entity.abilities.__contains__('vision'): if me.base64 is not None: content_list.append( llm_entities.ContentElement.from_image_base64(me.base64) diff --git a/pkg/pipeline/process/handlers/chat.py b/pkg/pipeline/process/handlers/chat.py index 83bb3335..9d231dda 100644 --- a/pkg/pipeline/process/handlers/chat.py +++ b/pkg/pipeline/process/handlers/chat.py @@ -9,7 +9,9 @@ import json from .. import handler from ... import entities from ....core import entities as core_entities -from ....provider import entities as llm_entities, runnermgr +from ....provider import entities as llm_entities +from ....provider import runner as runner_module +from ....provider.runners import localagent, difysvapi, dashscopeapi from ....plugin import events from ....platform.types import message as platform_message @@ -56,12 +58,6 @@ class ChatMessageHandler(handler.MessageHandler): ) else: - if not self.ap.provider_cfg.data['enable-chat']: - yield entities.StageProcessResult( - result_type=entities.ResultType.INTERRUPT, - new_query=query, - ) - if event_ctx.event.alter is not None: # if isinstance(event_ctx.event, str): # 现在暂时不考虑多模态alter query.user_message.content = event_ctx.event.alter @@ -72,7 +68,12 @@ class ChatMessageHandler(handler.MessageHandler): try: - runner = self.ap.runner_mgr.get_runner() + for r in runner_module.preregistered_runners: + if r.name == query.pipeline_config["ai"]["runner"]["runner"]: + runner = r(self.ap, query.pipeline_config) + break + else: + raise ValueError(f"未找到请求运行器: {query.pipeline_config['ai']['runner']['runner']}") async for result in runner.run(query): query.resp_messages.append(result) @@ -93,10 +94,12 @@ class ChatMessageHandler(handler.MessageHandler): self.ap.logger.error(f'对话({query.query_id})请求失败: {type(e).__name__} {str(e)}') + hide_exception_info = query.pipeline_config['output']['misc']['hide-exception'] + yield entities.StageProcessResult( result_type=entities.ResultType.INTERRUPT, new_query=query, - user_notice='请求失败' if self.ap.platform_cfg.data['hide-exception-info'] else f'{e}', + user_notice='请求失败' if hide_exception_info else f'{e}', error_notice=f'{e}', debug_notice=traceback.format_exc() ) diff --git a/pkg/pipeline/process/process.py b/pkg/pipeline/process/process.py index 362ece01..ea4d7e7f 100644 --- a/pkg/pipeline/process/process.py +++ b/pkg/pipeline/process/process.py @@ -4,7 +4,7 @@ from ...core import app, entities as core_entities from . import handler from .handlers import chat, command from .. import entities -from .. import stage, entities, stagemgr +from .. import stage, entities from ...core import entities as core_entities from ...config import manager as cfg_mgr @@ -23,7 +23,7 @@ class Processor(stage.PipelineStage): chat_handler: handler.MessageHandler - async def initialize(self): + async def initialize(self, pipeline_config: dict): self.cmd_handler = command.CommandHandler(self.ap) self.chat_handler = chat.ChatMessageHandler(self.ap) diff --git a/pkg/pipeline/ratelimit/algo.py b/pkg/pipeline/ratelimit/algo.py index 9b418dd2..d9baa801 100644 --- a/pkg/pipeline/ratelimit/algo.py +++ b/pkg/pipeline/ratelimit/algo.py @@ -2,7 +2,7 @@ from __future__ import annotations import abc import typing -from ...core import app +from ...core import app, entities as core_entities preregistered_algos: list[typing.Type[ReteLimitAlgo]] = [] @@ -31,7 +31,7 @@ class ReteLimitAlgo(metaclass=abc.ABCMeta): pass @abc.abstractmethod - async def require_access(self, launcher_type: str, launcher_id: typing.Union[int, str]) -> bool: + async def require_access(self, query: core_entities.Query, launcher_type: str, launcher_id: typing.Union[int, str]) -> bool: """进入处理流程 这个方法对等待是友好的,意味着算法可以实现在这里等待一段时间以控制速率。 @@ -46,7 +46,7 @@ class ReteLimitAlgo(metaclass=abc.ABCMeta): raise NotImplementedError @abc.abstractmethod - async def release_access(self, launcher_type: str, launcher_id: typing.Union[int, str]): + async def release_access(self, query: core_entities.Query, launcher_type: str, launcher_id: typing.Union[int, str]): """退出处理流程 Args: diff --git a/pkg/pipeline/ratelimit/algos/fixedwin.py b/pkg/pipeline/ratelimit/algos/fixedwin.py index 3cc1ab94..f17e93b8 100644 --- a/pkg/pipeline/ratelimit/algos/fixedwin.py +++ b/pkg/pipeline/ratelimit/algos/fixedwin.py @@ -3,6 +3,7 @@ import asyncio import time import typing from .. import algo +from ....core import entities as core_entities # 固定窗口算法 class SessionContainer: @@ -30,7 +31,7 @@ class FixedWindowAlgo(algo.ReteLimitAlgo): self.containers_lock = asyncio.Lock() self.containers = {} - async def require_access(self, launcher_type: str, launcher_id: typing.Union[int, str]) -> bool: + async def require_access(self, query: core_entities.Query, launcher_type: str, launcher_id: typing.Union[int, str]) -> bool: # 加锁,找容器 container: SessionContainer = None @@ -47,12 +48,13 @@ class FixedWindowAlgo(algo.ReteLimitAlgo): async with container.wait_lock: # 获取窗口大小和限制 - window_size = self.ap.pipeline_cfg.data['rate-limit']['fixwin']['default']['window-size'] - limitation = self.ap.pipeline_cfg.data['rate-limit']['fixwin']['default']['limit'] + window_size = query.pipeline_config['safety']['rate-limit']['window-length'] + limitation = query.pipeline_config['safety']['rate-limit']['limitation'] - if session_name in self.ap.pipeline_cfg.data['rate-limit']['fixwin']: - window_size = self.ap.pipeline_cfg.data['rate-limit']['fixwin'][session_name]['window-size'] - limitation = self.ap.pipeline_cfg.data['rate-limit']['fixwin'][session_name]['limit'] + # TODO revert it + # if session_name in self.ap.pipeline_cfg.data['rate-limit']['fixwin']: + # window_size = self.ap.pipeline_cfg.data['rate-limit']['fixwin'][session_name]['window-size'] + # limitation = self.ap.pipeline_cfg.data['rate-limit']['fixwin'][session_name]['limit'] # 获取当前时间戳 now = int(time.time()) @@ -65,9 +67,9 @@ class FixedWindowAlgo(algo.ReteLimitAlgo): # 如果访问次数超过了限制 if count >= limitation: - if self.ap.pipeline_cfg.data['rate-limit']['strategy'] == 'drop': + if query.pipeline_config['safety']['rate-limit']['strategy'] == 'drop': return False - elif self.ap.pipeline_cfg.data['rate-limit']['strategy'] == 'wait': + elif query.pipeline_config['safety']['rate-limit']['strategy'] == 'wait': # 等待下一窗口 await asyncio.sleep(window_size - time.time() % window_size) @@ -84,5 +86,5 @@ class FixedWindowAlgo(algo.ReteLimitAlgo): # 返回True return True - async def release_access(self, launcher_type: str, launcher_id: typing.Union[int, str]): + async def release_access(self, query: core_entities.Query, launcher_type: str, launcher_id: typing.Union[int, str]): pass diff --git a/pkg/pipeline/ratelimit/ratelimit.py b/pkg/pipeline/ratelimit/ratelimit.py index cd39b85c..c74db978 100644 --- a/pkg/pipeline/ratelimit/ratelimit.py +++ b/pkg/pipeline/ratelimit/ratelimit.py @@ -2,7 +2,7 @@ from __future__ import annotations import typing -from .. import entities, stagemgr, stage +from .. import entities, stage from . import algo from .algos import fixedwin from ...core import entities as core_entities @@ -18,9 +18,9 @@ class RateLimit(stage.PipelineStage): algo: algo.ReteLimitAlgo - async def initialize(self): + async def initialize(self, pipeline_config: dict): - algo_name = self.ap.pipeline_cfg.data['rate-limit']['algo'] + algo_name = 'fixwin' algo_class = None @@ -46,6 +46,7 @@ class RateLimit(stage.PipelineStage): """ if stage_inst_name == "RequireRateLimitOccupancy": if await self.algo.require_access( + query, query.launcher_type.value, query.launcher_id, ): @@ -62,6 +63,7 @@ class RateLimit(stage.PipelineStage): ) elif stage_inst_name == "ReleaseRateLimitOccupancy": await self.algo.release_access( + query, query.launcher_type.value, query.launcher_id, ) diff --git a/pkg/pipeline/respback/respback.py b/pkg/pipeline/respback/respback.py index 08b335d5..8c074d89 100644 --- a/pkg/pipeline/respback/respback.py +++ b/pkg/pipeline/respback/respback.py @@ -5,8 +5,10 @@ import asyncio from ...core import app +from ...platform.types import events as platform_events +from ...platform.types import message as platform_message -from .. import stage, entities, stagemgr +from .. import stage, entities from ...core import entities as core_entities from ...config import manager as cfg_mgr @@ -19,8 +21,8 @@ class SendResponseBackStage(stage.PipelineStage): async def process(self, query: core_entities.Query, stage_inst_name: str) -> entities.StageProcessResult: """处理 """ - - random_range = (self.ap.platform_cfg.data['force-delay']['min'], self.ap.platform_cfg.data['force-delay']['max']) + + random_range = (query.pipeline_config['output']['force-delay']['min'], query.pipeline_config['output']['force-delay']['max']) random_delay = random.uniform(*random_range) @@ -31,10 +33,20 @@ class SendResponseBackStage(stage.PipelineStage): await asyncio.sleep(random_delay) - await self.ap.platform_mgr.send( - query.message_event, - query.resp_message_chain[-1], - adapter=query.adapter + if query.pipeline_config['output']['misc']['at-sender'] and isinstance(query.message_event, platform_events.GroupMessage): + query.resp_message_chain[-1].insert( + 0, + platform_message.At( + query.message_event.sender.id + ) + ) + + quote_origin = query.pipeline_config['output']['misc']['quote-origin'] + + await query.adapter.reply_message( + message_source=query.message_event, + message=query.resp_message_chain[-1], + quote_origin=quote_origin ) return entities.StageProcessResult( diff --git a/pkg/pipeline/resprule/resprule.py b/pkg/pipeline/resprule/resprule.py index 77858f0d..08ba49e8 100644 --- a/pkg/pipeline/resprule/resprule.py +++ b/pkg/pipeline/resprule/resprule.py @@ -5,7 +5,7 @@ from ...core import app from . import entities as rule_entities, rule from .rules import atbot, prefix, regexp, random -from .. import stage, entities, stagemgr +from .. import stage, entities from ...core import entities as core_entities from ...config import manager as cfg_mgr @@ -20,7 +20,7 @@ class GroupRespondRuleCheckStage(stage.PipelineStage): rule_matchers: list[rule.GroupRespondRule] """检查器实例""" - async def initialize(self): + async def initialize(self, pipeline_config: dict): """初始化检查器 """ @@ -39,12 +39,13 @@ class GroupRespondRuleCheckStage(stage.PipelineStage): new_query=query ) - rules = self.ap.pipeline_cfg.data['respond-rules'] + rules = query.pipeline_config['trigger']['group-respond-rules'] - use_rule = rules['default'] + use_rule = rules - if str(query.launcher_id) in rules: - use_rule = rules[str(query.launcher_id)] + # TODO revert it + # if str(query.launcher_id) in rules: + # use_rule = rules[str(query.launcher_id)] for rule_matcher in self.rule_matchers: # 任意一个匹配就放行 res = await rule_matcher.match(str(query.message_chain), query.message_chain, use_rule, query) diff --git a/pkg/pipeline/stage.py b/pkg/pipeline/stage.py index 206f2bdf..859286d9 100644 --- a/pkg/pipeline/stage.py +++ b/pkg/pipeline/stage.py @@ -28,7 +28,7 @@ class PipelineStage(metaclass=abc.ABCMeta): def __init__(self, ap: app.Application): self.ap = ap - async def initialize(self): + async def initialize(self, pipeline_config: dict): """初始化 """ pass diff --git a/pkg/pipeline/stagemgr.py b/pkg/pipeline/stagemgr.py deleted file mode 100644 index 19fce2d6..00000000 --- a/pkg/pipeline/stagemgr.py +++ /dev/null @@ -1,71 +0,0 @@ -from __future__ import annotations - -from ..core import app -from . import stage -from .resprule import resprule -from .bansess import bansess -from .cntfilter import cntfilter -from .process import process -from .longtext import longtext -from .respback import respback -from .wrapper import wrapper -from .preproc import preproc -from .ratelimit import ratelimit -from .msgtrun import msgtrun - - -# 请求处理阶段顺序 -stage_order = [ - "GroupRespondRuleCheckStage", # 群响应规则检查 - "BanSessionCheckStage", # 封禁会话检查 - "PreContentFilterStage", # 内容过滤前置阶段 - "PreProcessor", # 预处理器 - "ConversationMessageTruncator", # 会话消息截断器 - "RequireRateLimitOccupancy", # 请求速率限制占用 - "MessageProcessor", # 处理器 - "ReleaseRateLimitOccupancy", # 释放速率限制占用 - "PostContentFilterStage", # 内容过滤后置阶段 - "ResponseWrapper", # 响应包装器 - "LongTextProcessStage", # 长文本处理 - "SendResponseBackStage", # 发送响应 -] - - -class StageInstContainer(): - """阶段实例容器 - """ - - inst_name: str - - inst: stage.PipelineStage - - def __init__(self, inst_name: str, inst: stage.PipelineStage): - self.inst_name = inst_name - self.inst = inst - - -class StageManager: - ap: app.Application - - stage_containers: list[StageInstContainer] - - def __init__(self, ap: app.Application): - self.ap = ap - - self.stage_containers = [] - - async def initialize(self): - """初始化 - """ - - for name, cls in stage.preregistered_stages.items(): - self.stage_containers.append(StageInstContainer( - inst_name=name, - inst=cls(self.ap) - )) - - for stage_containers in self.stage_containers: - await stage_containers.inst.initialize() - - # 按照 stage_order 排序 - self.stage_containers.sort(key=lambda x: stage_order.index(x.inst_name)) diff --git a/pkg/pipeline/wrapper/wrapper.py b/pkg/pipeline/wrapper/wrapper.py index a06e4a80..6b12ca65 100644 --- a/pkg/pipeline/wrapper/wrapper.py +++ b/pkg/pipeline/wrapper/wrapper.py @@ -5,7 +5,7 @@ import typing from ...core import app, entities as core_entities from .. import entities -from .. import stage, entities, stagemgr +from .. import stage, entities from ...core import entities as core_entities from ...config import manager as cfg_mgr from ...plugin import events @@ -22,7 +22,7 @@ class ResponseWrapper(stage.PipelineStage): - resp_message_chain """ - async def initialize(self): + async def initialize(self, pipeline_config: dict): pass async def process( @@ -110,7 +110,7 @@ class ResponseWrapper(stage.PipelineStage): query.resp_message_chain.append(platform_message.MessageChain([platform_message.Plain(reply_text)])) - if self.ap.platform_cfg.data['track-function-calls']: + if query.pipeline_config['output']['misc']['track-function-calls']: event_ctx = await self.ap.plugin_mgr.emit_event( event=events.NormalMessageResponded( diff --git a/pkg/platform/manager.py b/pkg/platform/manager.py index 210ee9ad..360f7588 100644 --- a/pkg/platform/manager.py +++ b/pkg/platform/manager.py @@ -50,6 +50,41 @@ class RuntimeBot: self.adapter = adapter self.task_context = taskmgr.TaskContext() + async def initialize(self): + + async def on_friend_message(event: platform_events.FriendMessage, adapter: msadapter.MessagePlatformAdapter): + + await self.ap.query_pool.add_query( + bot_uuid=self.bot_entity.uuid, + launcher_type=core_entities.LauncherTypes.PERSON, + launcher_id=event.sender.id, + sender_id=event.sender.id, + message_event=event, + message_chain=event.message_chain, + adapter=adapter, + ) + + async def on_group_message(event: platform_events.GroupMessage, adapter: msadapter.MessagePlatformAdapter): + + await self.ap.query_pool.add_query( + bot_uuid=self.bot_entity.uuid, + launcher_type=core_entities.LauncherTypes.GROUP, + launcher_id=event.group.id, + sender_id=event.sender.id, + message_event=event, + message_chain=event.message_chain, + adapter=adapter, + ) + + self.adapter.register_listener( + platform_events.FriendMessage, + on_friend_message + ) + self.adapter.register_listener( + platform_events.GroupMessage, + on_group_message + ) + async def run(self): async def exception_wrapper(): @@ -78,14 +113,16 @@ class RuntimeBot: async def shutdown(self): await self.adapter.kill() + self.ap.task_mgr.cancel_task(self.task_wrapper.id) + # 控制QQ消息输入输出的类 class PlatformManager: # adapter: msadapter.MessageSourceAdapter = None - adapters: list[msadapter.MessagePlatformAdapter] = [] + adapters: list[msadapter.MessagePlatformAdapter] = [] # deprecated - message_platform_adapter_components: list[engine.Component] = [] + message_platform_adapter_components: list[engine.Component] = [] # deprecated # ====== 4.0 ====== ap: app.Application = None @@ -135,49 +172,20 @@ class PlatformManager: bot_entity = persistence_bot.Bot(**bot_entity._mapping) elif isinstance(bot_entity, dict): bot_entity = persistence_bot.Bot(**bot_entity) - - async def on_friend_message(event: platform_events.FriendMessage, adapter: msadapter.MessagePlatformAdapter): - - await self.ap.query_pool.add_query( - launcher_type=core_entities.LauncherTypes.PERSON, - launcher_id=event.sender.id, - sender_id=event.sender.id, - message_event=event, - message_chain=event.message_chain, - adapter=adapter - ) - - async def on_group_message(event: platform_events.GroupMessage, adapter: msadapter.MessagePlatformAdapter): - - await self.ap.query_pool.add_query( - launcher_type=core_entities.LauncherTypes.GROUP, - launcher_id=event.group.id, - sender_id=event.sender.id, - message_event=event, - message_chain=event.message_chain, - adapter=adapter - ) adapter_inst = self.adapter_dict[bot_entity.adapter]( bot_entity.adapter_config, self.ap ) - adapter_inst.register_listener( - platform_events.FriendMessage, - on_friend_message - ) - adapter_inst.register_listener( - platform_events.GroupMessage, - on_group_message - ) - runtime_bot = RuntimeBot( ap=self.ap, bot_entity=bot_entity, adapter=adapter_inst ) + await runtime_bot.initialize() + self.bots.append(runtime_bot) return runtime_bot @@ -209,50 +217,36 @@ class PlatformManager: return None async def write_back_config(self, adapter_name: str, adapter_inst: msadapter.MessagePlatformAdapter, config: dict): - index = -2 + # index = -2 - for i, adapter in enumerate(self.adapters): - if adapter == adapter_inst: - index = i - break + # for i, adapter in enumerate(self.adapters): + # if adapter == adapter_inst: + # index = i + # break - if index == -2: - raise Exception('平台适配器未找到') + # if index == -2: + # raise Exception('平台适配器未找到') - # 只修改启用的适配器 - real_index = -1 + # # 只修改启用的适配器 + # real_index = -1 - for i, adapter in enumerate(self.ap.platform_cfg.data['platform-adapters']): - if adapter['enable']: - index -= 1 - if index == -1: - real_index = i - break + # for i, adapter in enumerate(self.ap.platform_cfg.data['platform-adapters']): + # if adapter['enable']: + # index -= 1 + # if index == -1: + # real_index = i + # break - new_cfg = { - 'adapter': adapter_name, - 'enable': True, - **config - } - self.ap.platform_cfg.data['platform-adapters'][real_index] = new_cfg - await self.ap.platform_cfg.dump_config() + # new_cfg = { + # 'adapter': adapter_name, + # 'enable': True, + # **config + # } + # self.ap.platform_cfg.data['platform-adapters'][real_index] = new_cfg + # await self.ap.platform_cfg.dump_config() - async def send(self, event: platform_events.MessageEvent, msg: platform_message.MessageChain, adapter: msadapter.MessagePlatformAdapter): - - if self.ap.platform_cfg.data['at-sender'] and isinstance(event, platform_events.GroupMessage): - - msg.insert( - 0, - platform_message.At( - event.sender.id - ) - ) - - await adapter.reply_message( - event, - msg, - quote_origin=True if self.ap.platform_cfg.data['quote-origin'] else False - ) + # TODO implement this + pass async def run(self): # This method will only be called when the application launching @@ -264,4 +258,4 @@ class PlatformManager: for bot in self.bots: if bot.enable: await bot.shutdown() - self.ap.task_mgr.cancel_by_scope(core_entities.LifecycleControlScope.PLATFORM) \ No newline at end of file + self.ap.task_mgr.cancel_by_scope(core_entities.LifecycleControlScope.PLATFORM) diff --git a/pkg/platform/sources/aiocqhttp.py b/pkg/platform/sources/aiocqhttp.py index af14372a..9149e427 100644 --- a/pkg/platform/sources/aiocqhttp.py +++ b/pkg/platform/sources/aiocqhttp.py @@ -238,4 +238,6 @@ class AiocqhttpAdapter(adapter.MessagePlatformAdapter): await self.bot._server_app.run_task(**self.config) async def kill(self) -> bool: + # Current issue: existing connection will not be closed + # self.should_shutdown = True return False diff --git a/pkg/platform/sources/qqbotpy.yaml b/pkg/platform/sources/qqbotpy.yaml index 79653194..a7913042 100644 --- a/pkg/platform/sources/qqbotpy.yaml +++ b/pkg/platform/sources/qqbotpy.yaml @@ -28,9 +28,11 @@ spec: label: en_US: Intents zh_CN: 权限 - type: array[string] + type: array required: true default: [] + items: + type: string execution: python: path: ./qqbotpy.py diff --git a/pkg/plugin/context.py b/pkg/plugin/context.py index 76a49bf4..7a9be2a1 100644 --- a/pkg/plugin/context.py +++ b/pkg/plugin/context.py @@ -222,10 +222,10 @@ class EventContext: Args: message_chain (platform.types.MessageChain): 源平台的消息链,若用户使用的不是源平台适配器,程序也能自动转换为目标平台消息链 """ - await self.host.ap.platform_mgr.send( - event=self.event.query.message_event, - msg=message_chain, - adapter=self.event.query.adapter, + # TODO 添加 at_sender 和 quote_origin 参数 + await self.event.query.adapter.reply_message( + message_source=self.event.query.message_event, + message=message_chain ) async def send_message( diff --git a/pkg/provider/entities.py b/pkg/provider/entities.py index dce55fd5..0fb75f80 100644 --- a/pkg/provider/entities.py +++ b/pkg/provider/entities.py @@ -4,6 +4,8 @@ import typing import enum import pydantic.v1 as pydantic +from pkg.provider import entities + from ..platform.types import message as platform_message @@ -124,3 +126,13 @@ class Message(pydantic.BaseModel): mc.insert(0, platform_message.Plain(prefix_text)) return platform_message.MessageChain(mc) + + +class Prompt(pydantic.BaseModel): + """供AI使用的Prompt""" + + name: str + """名称""" + + messages: list[entities.Message] + """消息列表""" diff --git a/pkg/provider/modelmgr/modelmgr.py b/pkg/provider/modelmgr/modelmgr.py index 7db7a040..41e97f3e 100644 --- a/pkg/provider/modelmgr/modelmgr.py +++ b/pkg/provider/modelmgr/modelmgr.py @@ -2,6 +2,7 @@ from __future__ import annotations import typing import sqlalchemy +import pydantic.v1 as pydantic from . import entities, requester from ...core import app @@ -16,23 +17,6 @@ from .requesters import bailianchatcmpl, chatcmpl, anthropicmsgs, moonshotchatcm FETCH_MODEL_LIST_URL = "https://api.qchatgpt.rockchin.top/api/v2/fetch/model_list" -class RuntimeLLMModel: - """运行时模型""" - - model_entity: persistence_model.LLMModel - """模型数据""" - - token_mgr: token.TokenManager - """api key管理器""" - - requester: requester.LLMAPIRequester - """请求器实例""" - - def __init__(self, model_entity: persistence_model.LLMModel, token_mgr: token.TokenManager, requester: requester.LLMAPIRequester): - self.model_entity = model_entity - self.token_mgr = token_mgr - self.requester = requester - class ModelManager: """模型管理器""" @@ -47,7 +31,7 @@ class ModelManager: ap: app.Application - llm_models: list[RuntimeLLMModel] + llm_models: list[requester.RuntimeLLMModel] requester_components: list[engine.Component] @@ -99,16 +83,20 @@ class ModelManager: elif isinstance(model_info, dict): model_info = persistence_model.LLMModel(**model_info) - runtime_llm_model = RuntimeLLMModel( + requester_inst = self.requester_dict[model_info.requester]( + ap=self.ap, + config=model_info.requester_config + ) + + await requester_inst.initialize() + + runtime_llm_model = requester.RuntimeLLMModel( model_entity=model_info, token_mgr=token.TokenManager( name=model_info.uuid, tokens=model_info.api_keys, ), - requester=self.requester_dict[model_info.requester]( - ap=self.ap, - config=model_info.requester_config - ) + requester=requester_inst ) self.llm_models.append(runtime_llm_model) diff --git a/pkg/provider/modelmgr/requester.py b/pkg/provider/modelmgr/requester.py index 7f13c58b..5ea8d23f 100644 --- a/pkg/provider/modelmgr/requester.py +++ b/pkg/provider/modelmgr/requester.py @@ -6,8 +6,27 @@ import typing from ...core import app from ...core import entities as core_entities from .. import entities as llm_entities -from . import entities as modelmgr_entities from ..tools import entities as tools_entities +from ...entity.persistence import model as persistence_model +from . import token + + +class RuntimeLLMModel: + """运行时模型""" + + model_entity: persistence_model.LLMModel + """模型数据""" + + token_mgr: token.TokenManager + """api key管理器""" + + requester: LLMAPIRequester + """请求器实例""" + + def __init__(self, model_entity: persistence_model.LLMModel, token_mgr: token.TokenManager, requester: LLMAPIRequester): + self.model_entity = model_entity + self.token_mgr = token_mgr + self.requester = requester class LLMAPIRequester(metaclass=abc.ABCMeta): @@ -31,21 +50,11 @@ class LLMAPIRequester(metaclass=abc.ABCMeta): async def initialize(self): pass - async def preprocess( - self, - query: core_entities.Query, - ): - """预处理 - - 在这里处理特定API对Query对象的兼容性问题。 - """ - pass - @abc.abstractmethod - async def call( + async def invoke_llm( self, query: core_entities.Query, - model: modelmgr_entities.LLMModelInfo, + model: RuntimeLLMModel, messages: typing.List[llm_entities.Message], funcs: typing.List[tools_entities.LLMFunction] = None, extra_args: dict[str, typing.Any] = {}, @@ -53,7 +62,7 @@ class LLMAPIRequester(metaclass=abc.ABCMeta): """调用API Args: - model (modelmgr_entities.LLMModelInfo): 使用的模型信息 + model (RuntimeLLMModel): 使用的模型信息 messages (typing.List[llm_entities.Message]): 消息对象列表 funcs (typing.List[tools_entities.LLMFunction], optional): 使用的工具函数列表. Defaults to None. extra_args (dict[str, typing.Any], optional): 额外的参数. Defaults to {}. diff --git a/pkg/provider/modelmgr/requesters/anthropicmsgs.py b/pkg/provider/modelmgr/requesters/anthropicmsgs.py index 937f5107..7edc4405 100644 --- a/pkg/provider/modelmgr/requesters/anthropicmsgs.py +++ b/pkg/provider/modelmgr/requesters/anthropicmsgs.py @@ -24,16 +24,16 @@ class AnthropicMessages(requester.LLMAPIRequester): client: anthropic.AsyncAnthropic default_config: dict[str, typing.Any] = { - 'base-url': 'https://api.anthropic.com/v1', + 'base_url': 'https://api.anthropic.com/v1', 'timeout': 120, } async def initialize(self): httpx_client = anthropic._base_client.AsyncHttpxClientWrapper( - base_url=self.ap.provider_cfg.data['requester']['anthropic-messages']['base-url'], + base_url=self.requester_cfg['base_url'], # cast to a valid type because mypy doesn't understand our type narrowing - timeout=typing.cast(httpx.Timeout, self.ap.provider_cfg.data['requester']['anthropic-messages']['timeout']), + timeout=typing.cast(httpx.Timeout, self.requester_cfg['timeout']), limits=anthropic._constants.DEFAULT_CONNECTION_LIMITS, follow_redirects=True, trust_env=True, @@ -44,17 +44,18 @@ class AnthropicMessages(requester.LLMAPIRequester): http_client=httpx_client, ) - async def call( + async def invoke_llm( self, query: core_entities.Query, - model: entities.LLMModelInfo, + model: requester.RuntimeLLMModel, messages: typing.List[llm_entities.Message], funcs: typing.List[tools_entities.LLMFunction] = None, + extra_args: dict[str, typing.Any] = {}, ) -> llm_entities.Message: self.client.api_key = model.token_mgr.get_token() - args = self.ap.provider_cfg.data['requester']['anthropic-messages']['args'].copy() - args["model"] = model.name if model.model_name is None else model.model_name + args = extra_args.copy() + args["model"] = model.model_entity.name # 处理消息 diff --git a/pkg/provider/modelmgr/requesters/anthropicmsgs.yaml b/pkg/provider/modelmgr/requesters/anthropicmsgs.yaml index 80380857..6d1a53cf 100644 --- a/pkg/provider/modelmgr/requesters/anthropicmsgs.yaml +++ b/pkg/provider/modelmgr/requesters/anthropicmsgs.yaml @@ -7,7 +7,7 @@ metadata: zh_CN: Anthropic spec: config: - - name: base-url + - name: base_url label: en_US: Base URL zh_CN: 基础 URL diff --git a/pkg/provider/modelmgr/requesters/bailianchatcmpl.py b/pkg/provider/modelmgr/requesters/bailianchatcmpl.py index ce718c4c..e20e3376 100644 --- a/pkg/provider/modelmgr/requesters/bailianchatcmpl.py +++ b/pkg/provider/modelmgr/requesters/bailianchatcmpl.py @@ -14,6 +14,6 @@ class BailianChatCompletions(chatcmpl.OpenAIChatCompletions): client: openai.AsyncClient default_config: dict[str, typing.Any] = { - 'base-url': 'https://dashscope.aliyuncs.com/compatible-mode/v1', + 'base_url': 'https://dashscope.aliyuncs.com/compatible-mode/v1', 'timeout': 120, } diff --git a/pkg/provider/modelmgr/requesters/bailianchatcmpl.yaml b/pkg/provider/modelmgr/requesters/bailianchatcmpl.yaml index 74d197ca..136e903f 100644 --- a/pkg/provider/modelmgr/requesters/bailianchatcmpl.yaml +++ b/pkg/provider/modelmgr/requesters/bailianchatcmpl.yaml @@ -7,7 +7,7 @@ metadata: zh_CN: 阿里云百炼 spec: config: - - name: base-url + - name: base_url label: en_US: Base URL zh_CN: 基础 URL diff --git a/pkg/provider/modelmgr/requesters/chatcmpl.py b/pkg/provider/modelmgr/requesters/chatcmpl.py index 7cf255c0..b59ab42d 100644 --- a/pkg/provider/modelmgr/requesters/chatcmpl.py +++ b/pkg/provider/modelmgr/requesters/chatcmpl.py @@ -26,7 +26,7 @@ class OpenAIChatCompletions(requester.LLMAPIRequester): client: openai.AsyncClient default_config: dict[str, typing.Any] = { - "base-url": "https://api.openai.com/v1", + "base_url": "https://api.openai.com/v1", "timeout": 120, } @@ -34,7 +34,7 @@ class OpenAIChatCompletions(requester.LLMAPIRequester): self.client = openai.AsyncClient( api_key="", - base_url=self.requester_cfg["base-url"], + base_url=self.requester_cfg["base_url"], timeout=self.requester_cfg["timeout"], http_client=httpx.AsyncClient( trust_env=True, timeout=self.requester_cfg["timeout"] @@ -51,7 +51,7 @@ class OpenAIChatCompletions(requester.LLMAPIRequester): self, chat_completion: chat_completion.ChatCompletion, ) -> llm_entities.Message: - chatcmpl_message = chat_completion.choices[0].message.dict() + chatcmpl_message = chat_completion.choices[0].message.model_dump() # 确保 role 字段存在且不为 None if "role" not in chatcmpl_message or chatcmpl_message["role"] is None: @@ -65,16 +65,14 @@ class OpenAIChatCompletions(requester.LLMAPIRequester): self, query: core_entities.Query, req_messages: list[dict], - use_model: entities.LLMModelInfo, + use_model: requester.RuntimeLLMModel, use_funcs: list[tools_entities.LLMFunction] = None, extra_args: dict[str, typing.Any] = {}, # TODO: 所有的args都改为从此参数读取 ) -> llm_entities.Message: self.client.api_key = use_model.token_mgr.get_token() - args = self.requester_cfg["args"].copy() - args["model"] = ( - use_model.name if use_model.model_name is None else use_model.model_name - ) + args = extra_args.copy() + args["model"] = use_model.model_entity.name if use_funcs: tools = await self.ap.tool_mgr.generate_tools_for_openai(use_funcs) @@ -104,10 +102,10 @@ class OpenAIChatCompletions(requester.LLMAPIRequester): return message - async def call( + async def invoke_llm( self, query: core_entities.Query, - model: entities.LLMModelInfo, + model: requester.RuntimeLLMModel, messages: typing.List[llm_entities.Message], funcs: typing.List[tools_entities.LLMFunction] = None, extra_args: dict[str, typing.Any] = {}, diff --git a/pkg/provider/modelmgr/requesters/chatcmpl.yaml b/pkg/provider/modelmgr/requesters/chatcmpl.yaml index fe4d3cb5..a67e23d1 100644 --- a/pkg/provider/modelmgr/requesters/chatcmpl.yaml +++ b/pkg/provider/modelmgr/requesters/chatcmpl.yaml @@ -7,7 +7,7 @@ metadata: zh_CN: OpenAI spec: config: - - name: base-url + - name: base_url label: en_US: Base URL zh_CN: 基础 URL diff --git a/pkg/provider/modelmgr/requesters/deepseekchatcmpl.py b/pkg/provider/modelmgr/requesters/deepseekchatcmpl.py index e82d0d81..ee17ac05 100644 --- a/pkg/provider/modelmgr/requesters/deepseekchatcmpl.py +++ b/pkg/provider/modelmgr/requesters/deepseekchatcmpl.py @@ -13,7 +13,7 @@ class DeepseekChatCompletions(chatcmpl.OpenAIChatCompletions): """Deepseek ChatCompletion API 请求器""" default_config: dict[str, typing.Any] = { - 'base-url': 'https://api.deepseek.com', + 'base_url': 'https://api.deepseek.com', 'timeout': 120, } @@ -21,14 +21,14 @@ class DeepseekChatCompletions(chatcmpl.OpenAIChatCompletions): self, query: core_entities.Query, req_messages: list[dict], - use_model: entities.LLMModelInfo, + use_model: requester.RuntimeLLMModel, use_funcs: list[tools_entities.LLMFunction] = None, extra_args: dict[str, typing.Any] = {}, ) -> llm_entities.Message: self.client.api_key = use_model.token_mgr.get_token() - args = self.requester_cfg['args'].copy() - args["model"] = use_model.name if use_model.model_name is None else use_model.model_name + args = extra_args.copy() + args["model"] = use_model.model_entity.name if use_funcs: tools = await self.ap.tool_mgr.generate_tools_for_openai(use_funcs) diff --git a/pkg/provider/modelmgr/requesters/deepseekchatcmpl.yaml b/pkg/provider/modelmgr/requesters/deepseekchatcmpl.yaml index 2ef91aa2..9890e21e 100644 --- a/pkg/provider/modelmgr/requesters/deepseekchatcmpl.yaml +++ b/pkg/provider/modelmgr/requesters/deepseekchatcmpl.yaml @@ -7,7 +7,7 @@ metadata: zh_CN: 深度求索 spec: config: - - name: base-url + - name: base_url label: en_US: Base URL zh_CN: 基础 URL diff --git a/pkg/provider/modelmgr/requesters/giteeaichatcmpl.py b/pkg/provider/modelmgr/requesters/giteeaichatcmpl.py index dec0b8d1..35052682 100644 --- a/pkg/provider/modelmgr/requesters/giteeaichatcmpl.py +++ b/pkg/provider/modelmgr/requesters/giteeaichatcmpl.py @@ -18,7 +18,7 @@ class GiteeAIChatCompletions(chatcmpl.OpenAIChatCompletions): """Gitee AI ChatCompletions API 请求器""" default_config: dict[str, typing.Any] = { - 'base-url': 'https://ai.gitee.com/v1', + 'base_url': 'https://ai.gitee.com/v1', 'timeout': 120, } @@ -26,14 +26,14 @@ class GiteeAIChatCompletions(chatcmpl.OpenAIChatCompletions): self, query: core_entities.Query, req_messages: list[dict], - use_model: entities.LLMModelInfo, + use_model: requester.RuntimeLLMModel, use_funcs: list[tools_entities.LLMFunction] = None, extra_args: dict[str, typing.Any] = {}, ) -> llm_entities.Message: self.client.api_key = use_model.token_mgr.get_token() - args = self.requester_cfg['args'].copy() - args["model"] = use_model.name if use_model.model_name is None else use_model.model_name + args = extra_args.copy() + args["model"] = use_model.model_entity.name if use_funcs: tools = await self.ap.tool_mgr.generate_tools_for_openai(use_funcs) diff --git a/pkg/provider/modelmgr/requesters/giteeaichatcmpl.yaml b/pkg/provider/modelmgr/requesters/giteeaichatcmpl.yaml index 11f7e06e..3e4efd61 100644 --- a/pkg/provider/modelmgr/requesters/giteeaichatcmpl.yaml +++ b/pkg/provider/modelmgr/requesters/giteeaichatcmpl.yaml @@ -7,7 +7,7 @@ metadata: zh_CN: Gitee AI spec: config: - - name: base-url + - name: base_url label: en_US: Base URL zh_CN: 基础 URL diff --git a/pkg/provider/modelmgr/requesters/lmstudiochatcmpl.py b/pkg/provider/modelmgr/requesters/lmstudiochatcmpl.py index c00be372..6be76051 100644 --- a/pkg/provider/modelmgr/requesters/lmstudiochatcmpl.py +++ b/pkg/provider/modelmgr/requesters/lmstudiochatcmpl.py @@ -14,6 +14,6 @@ class LmStudioChatCompletions(chatcmpl.OpenAIChatCompletions): client: openai.AsyncClient default_config: dict[str, typing.Any] = { - 'base-url': 'http://127.0.0.1:1234/v1', + 'base_url': 'http://127.0.0.1:1234/v1', 'timeout': 120, } diff --git a/pkg/provider/modelmgr/requesters/lmstudiochatcmpl.yaml b/pkg/provider/modelmgr/requesters/lmstudiochatcmpl.yaml index 959d4151..219e5839 100644 --- a/pkg/provider/modelmgr/requesters/lmstudiochatcmpl.yaml +++ b/pkg/provider/modelmgr/requesters/lmstudiochatcmpl.yaml @@ -7,7 +7,7 @@ metadata: zh_CN: LM Studio spec: config: - - name: base-url + - name: base_url label: en_US: Base URL zh_CN: 基础 URL diff --git a/pkg/provider/modelmgr/requesters/moonshotchatcmpl.py b/pkg/provider/modelmgr/requesters/moonshotchatcmpl.py index 3cbe8837..1ef7d9c9 100644 --- a/pkg/provider/modelmgr/requesters/moonshotchatcmpl.py +++ b/pkg/provider/modelmgr/requesters/moonshotchatcmpl.py @@ -15,7 +15,7 @@ class MoonshotChatCompletions(chatcmpl.OpenAIChatCompletions): """Moonshot ChatCompletion API 请求器""" default_config: dict[str, typing.Any] = { - 'base-url': 'https://api.moonshot.cn/v1', + 'base_url': 'https://api.moonshot.cn/v1', 'timeout': 120, } @@ -23,14 +23,14 @@ class MoonshotChatCompletions(chatcmpl.OpenAIChatCompletions): self, query: core_entities.Query, req_messages: list[dict], - use_model: entities.LLMModelInfo, + use_model: requester.RuntimeLLMModel, use_funcs: list[tools_entities.LLMFunction] = None, extra_args: dict[str, typing.Any] = {}, ) -> llm_entities.Message: self.client.api_key = use_model.token_mgr.get_token() - args = self.requester_cfg['args'].copy() - args["model"] = use_model.name if use_model.model_name is None else use_model.model_name + args = extra_args.copy() + args["model"] = use_model.model_entity.name if use_funcs: tools = await self.ap.tool_mgr.generate_tools_for_openai(use_funcs) diff --git a/pkg/provider/modelmgr/requesters/moonshotchatcmpl.yaml b/pkg/provider/modelmgr/requesters/moonshotchatcmpl.yaml index 56deb1df..7290784f 100644 --- a/pkg/provider/modelmgr/requesters/moonshotchatcmpl.yaml +++ b/pkg/provider/modelmgr/requesters/moonshotchatcmpl.yaml @@ -7,7 +7,7 @@ metadata: zh_CN: 月之暗面 spec: config: - - name: base-url + - name: base_url label: en_US: Base URL zh_CN: 基础 URL diff --git a/pkg/provider/modelmgr/requesters/ollamachat.py b/pkg/provider/modelmgr/requesters/ollamachat.py index fa99cfe5..ee331036 100644 --- a/pkg/provider/modelmgr/requesters/ollamachat.py +++ b/pkg/provider/modelmgr/requesters/ollamachat.py @@ -22,35 +22,38 @@ REQUESTER_NAME: str = "ollama-chat" class OllamaChatCompletions(requester.LLMAPIRequester): """Ollama平台 ChatCompletion API请求器""" + client: ollama.AsyncClient default_config: dict[str, typing.Any] = { - 'base-url': 'http://127.0.0.1:11434', - 'timeout': 120, + "base_url": "http://127.0.0.1:11434", + "timeout": 120, } async def initialize(self): - os.environ['OLLAMA_HOST'] = self.requester_cfg['base-url'] - self.client = ollama.AsyncClient( - timeout=self.requester_cfg['timeout'] - ) + os.environ["OLLAMA_HOST"] = self.requester_cfg["base_url"] + self.client = ollama.AsyncClient(timeout=self.requester_cfg["timeout"]) - async def _req(self, - args: dict, - ) -> Union[Mapping[str, Any], AsyncIterator[Mapping[str, Any]]]: - return await self.client.chat( - **args - ) + async def _req( + self, + args: dict, + ) -> Union[Mapping[str, Any], AsyncIterator[Mapping[str, Any]]]: + return await self.client.chat(**args) - async def _closure(self, query: core_entities.Query, req_messages: list[dict], use_model: entities.LLMModelInfo, - user_funcs: list[tools_entities.LLMFunction] = None, - extra_args: dict[str, typing.Any] = {}) -> llm_entities.Message: - args: Any = self.requester_cfg['args'].copy() - args["model"] = use_model.name if use_model.model_name is None else use_model.model_name + async def _closure( + self, + query: core_entities.Query, + req_messages: list[dict], + use_model: requester.RuntimeLLMModel, + user_funcs: list[tools_entities.LLMFunction] = None, + extra_args: dict[str, typing.Any] = {}, + ) -> llm_entities.Message: + args = extra_args.copy() + args["model"] = use_model.model_entity.name messages: list[dict] = req_messages.copy() for msg in messages: - if 'content' in msg and isinstance(msg["content"], list): + if "content" in msg and isinstance(msg["content"], list): text_content: list = [] image_urls: list = [] for me in msg["content"]: @@ -58,12 +61,16 @@ class OllamaChatCompletions(requester.LLMAPIRequester): text_content.append(me["text"]) elif me["type"] == "image_base64": image_urls.append(me["image_base64"]) - + msg["content"] = "\n".join(text_content) - msg["images"] = [url.split(',')[1] for url in image_urls] - if 'tool_calls' in msg: # LangBot 内部以 str 存储 tool_calls 的参数,这里需要转换为 dict - for tool_call in msg['tool_calls']: - tool_call['function']['arguments'] = json.loads(tool_call['function']['arguments']) + msg["images"] = [url.split(",")[1] for url in image_urls] + if ( + "tool_calls" in msg + ): # LangBot 内部以 str 存储 tool_calls 的参数,这里需要转换为 dict + for tool_call in msg["tool_calls"]: + tool_call["function"]["arguments"] = json.loads( + tool_call["function"]["arguments"] + ) args["messages"] = messages args["tools"] = [] @@ -77,8 +84,8 @@ class OllamaChatCompletions(requester.LLMAPIRequester): return message async def _make_msg( - self, - chat_completions: ollama.ChatResponse) -> llm_entities.Message: + self, chat_completions: ollama.ChatResponse + ) -> llm_entities.Message: message: ollama.Message = chat_completions.message if message is None: raise ValueError("chat_completions must contain a 'message' field") @@ -86,43 +93,51 @@ class OllamaChatCompletions(requester.LLMAPIRequester): ret_msg: llm_entities.Message = None if message.content is not None: - ret_msg = llm_entities.Message( - role="assistant", - content=message.content - ) + ret_msg = llm_entities.Message(role="assistant", content=message.content) if message.tool_calls is not None and len(message.tool_calls) > 0: tool_calls: list[llm_entities.ToolCall] = [] for tool_call in message.tool_calls: - tool_calls.append(llm_entities.ToolCall( - id=uuid.uuid4().hex, - type="function", - function=llm_entities.FunctionCall( - name=tool_call.function.name, - arguments=json.dumps(tool_call.function.arguments) + tool_calls.append( + llm_entities.ToolCall( + id=uuid.uuid4().hex, + type="function", + function=llm_entities.FunctionCall( + name=tool_call.function.name, + arguments=json.dumps(tool_call.function.arguments), + ), ) - )) + ) ret_msg.tool_calls = tool_calls return ret_msg - async def call( - self, - query: core_entities.Query, - model: entities.LLMModelInfo, - messages: typing.List[llm_entities.Message], - funcs: typing.List[tools_entities.LLMFunction] = None, - extra_args: dict[str, typing.Any] = {}, + async def invoke_llm( + self, + query: core_entities.Query, + model: requester.RuntimeLLMModel, + messages: typing.List[llm_entities.Message], + funcs: typing.List[tools_entities.LLMFunction] = None, + extra_args: dict[str, typing.Any] = {}, ) -> llm_entities.Message: req_messages: list = [] for m in messages: msg_dict: dict = m.dict(exclude_none=True) content: Any = msg_dict.get("content") if isinstance(content, list): - if all(isinstance(part, dict) and part.get('type') == 'text' for part in content): + if all( + isinstance(part, dict) and part.get("type") == "text" + for part in content + ): msg_dict["content"] = "\n".join(part["text"] for part in content) req_messages.append(msg_dict) try: - return await self._closure(query, req_messages, model, funcs, extra_args) + return await self._closure( + query=query, + req_messages=req_messages, + use_model=model, + use_funcs=funcs, + extra_args=extra_args, + ) except asyncio.TimeoutError: - raise errors.RequesterError('请求超时') + raise errors.RequesterError("请求超时") diff --git a/pkg/provider/modelmgr/requesters/ollamachat.yaml b/pkg/provider/modelmgr/requesters/ollamachat.yaml index b162e5db..ba915aeb 100644 --- a/pkg/provider/modelmgr/requesters/ollamachat.yaml +++ b/pkg/provider/modelmgr/requesters/ollamachat.yaml @@ -7,7 +7,7 @@ metadata: zh_CN: Ollama spec: config: - - name: base-url + - name: base_url label: en_US: Base URL zh_CN: 基础 URL diff --git a/pkg/provider/modelmgr/requesters/siliconflowchatcmpl.py b/pkg/provider/modelmgr/requesters/siliconflowchatcmpl.py index a990f809..dd5b9a14 100644 --- a/pkg/provider/modelmgr/requesters/siliconflowchatcmpl.py +++ b/pkg/provider/modelmgr/requesters/siliconflowchatcmpl.py @@ -14,6 +14,6 @@ class SiliconFlowChatCompletions(chatcmpl.OpenAIChatCompletions): client: openai.AsyncClient default_config: dict[str, typing.Any] = { - 'base-url': 'https://api.siliconflow.cn/v1', + 'base_url': 'https://api.siliconflow.cn/v1', 'timeout': 120, } diff --git a/pkg/provider/modelmgr/requesters/siliconflowchatcmpl.yaml b/pkg/provider/modelmgr/requesters/siliconflowchatcmpl.yaml index 28d534f6..c938b21c 100644 --- a/pkg/provider/modelmgr/requesters/siliconflowchatcmpl.yaml +++ b/pkg/provider/modelmgr/requesters/siliconflowchatcmpl.yaml @@ -7,7 +7,7 @@ metadata: zh_CN: 硅基流动 spec: config: - - name: base-url + - name: base_url label: en_US: Base URL zh_CN: 基础 URL diff --git a/pkg/provider/modelmgr/requesters/volcarkchatcmpl.py b/pkg/provider/modelmgr/requesters/volcarkchatcmpl.py index fbf88826..9b5505e1 100644 --- a/pkg/provider/modelmgr/requesters/volcarkchatcmpl.py +++ b/pkg/provider/modelmgr/requesters/volcarkchatcmpl.py @@ -14,6 +14,6 @@ class VolcArkChatCompletions(chatcmpl.OpenAIChatCompletions): client: openai.AsyncClient default_config: dict[str, typing.Any] = { - 'base-url': 'https://ark.cn-beijing.volces.com/api/v3', + 'base_url': 'https://ark.cn-beijing.volces.com/api/v3', 'timeout': 120, } diff --git a/pkg/provider/modelmgr/requesters/volcarkchatcmpl.yaml b/pkg/provider/modelmgr/requesters/volcarkchatcmpl.yaml index f18c7b2c..56347bc5 100644 --- a/pkg/provider/modelmgr/requesters/volcarkchatcmpl.yaml +++ b/pkg/provider/modelmgr/requesters/volcarkchatcmpl.yaml @@ -7,7 +7,7 @@ metadata: zh_CN: 火山方舟 spec: config: - - name: base-url + - name: base_url label: en_US: Base URL zh_CN: 基础 URL diff --git a/pkg/provider/modelmgr/requesters/xaichatcmpl.py b/pkg/provider/modelmgr/requesters/xaichatcmpl.py index 47c2939a..e08af875 100644 --- a/pkg/provider/modelmgr/requesters/xaichatcmpl.py +++ b/pkg/provider/modelmgr/requesters/xaichatcmpl.py @@ -14,6 +14,6 @@ class XaiChatCompletions(chatcmpl.OpenAIChatCompletions): client: openai.AsyncClient default_config: dict[str, typing.Any] = { - 'base-url': 'https://api.x.ai/v1', + 'base_url': 'https://api.x.ai/v1', 'timeout': 120, } diff --git a/pkg/provider/modelmgr/requesters/xaichatcmpl.yaml b/pkg/provider/modelmgr/requesters/xaichatcmpl.yaml index ceda8c0d..604b88c6 100644 --- a/pkg/provider/modelmgr/requesters/xaichatcmpl.yaml +++ b/pkg/provider/modelmgr/requesters/xaichatcmpl.yaml @@ -7,7 +7,7 @@ metadata: zh_CN: xAI spec: config: - - name: base-url + - name: base_url label: en_US: Base URL zh_CN: 基础 URL diff --git a/pkg/provider/modelmgr/requesters/zhipuaichatcmpl.py b/pkg/provider/modelmgr/requesters/zhipuaichatcmpl.py index 1e24a5ef..7bbca164 100644 --- a/pkg/provider/modelmgr/requesters/zhipuaichatcmpl.py +++ b/pkg/provider/modelmgr/requesters/zhipuaichatcmpl.py @@ -14,6 +14,6 @@ class ZhipuAIChatCompletions(chatcmpl.OpenAIChatCompletions): client: openai.AsyncClient default_config: dict[str, typing.Any] = { - 'base-url': 'https://open.bigmodel.cn/api/paas/v4', + 'base_url': 'https://open.bigmodel.cn/api/paas/v4', 'timeout': 120, } diff --git a/pkg/provider/modelmgr/requesters/zhipuaichatcmpl.yaml b/pkg/provider/modelmgr/requesters/zhipuaichatcmpl.yaml index 3d112ca1..20b8b496 100644 --- a/pkg/provider/modelmgr/requesters/zhipuaichatcmpl.yaml +++ b/pkg/provider/modelmgr/requesters/zhipuaichatcmpl.yaml @@ -7,7 +7,7 @@ metadata: zh_CN: 智谱 AI spec: config: - - name: base-url + - name: base_url label: en_US: Base URL zh_CN: 基础 URL diff --git a/pkg/provider/runner.py b/pkg/provider/runner.py index 5a5cf6ef..1762e546 100644 --- a/pkg/provider/runner.py +++ b/pkg/provider/runner.py @@ -27,11 +27,11 @@ class RequestRunner(abc.ABC): ap: app.Application - def __init__(self, ap: app.Application): - self.ap = ap + pipeline_config: dict - async def initialize(self): - pass + def __init__(self, ap: app.Application, pipeline_config: dict): + self.ap = ap + self.pipeline_config = pipeline_config @abc.abstractmethod async def run(self, query: core_entities.Query) -> typing.AsyncGenerator[llm_entities.Message, None]: diff --git a/pkg/provider/runnermgr.py b/pkg/provider/runnermgr.py deleted file mode 100644 index 52e1d8d2..00000000 --- a/pkg/provider/runnermgr.py +++ /dev/null @@ -1,30 +0,0 @@ -from __future__ import annotations - -from . import runner -from ..core import app - -from .runners import localagent -from .runners import difysvapi -from .runners import dashscopeapi - -class RunnerManager: - - ap: app.Application - - using_runner: runner.RequestRunner - - def __init__(self, ap: app.Application): - self.ap = ap - - async def initialize(self): - - for r in runner.preregistered_runners: - if r.name == self.ap.provider_cfg.data['runner']: - self.using_runner = r(self.ap) - await self.using_runner.initialize() - break - else: - raise ValueError(f"未找到请求运行器: {self.ap.provider_cfg.data['runner']}") - - def get_runner(self) -> runner.RequestRunner: - return self.using_runner diff --git a/pkg/provider/runners/dashscopeapi.py b/pkg/provider/runners/dashscopeapi.py index 0bb09822..d5e6a83d 100644 --- a/pkg/provider/runners/dashscopeapi.py +++ b/pkg/provider/runners/dashscopeapi.py @@ -8,7 +8,7 @@ import re import dashscope from .. import runner -from ...core import entities as core_entities +from ...core import app, entities as core_entities from .. import entities as llm_entities from ...utils import image @@ -29,12 +29,14 @@ class DashScopeAPIRunner(runner.RequestRunner): app_id: str # 应用ID api_key: str # API Key references_quote: str # 引用资料提示(当展示回答来源功能开启时,这个变量会作为引用资料名前的提示,可在provider.json中配置) - biz_params: dict = {} # 工作流应用参数(仅在工作流应用中生效) - async def initialize(self): + def __init__(self, ap: app.Application, pipeline_config: dict): """初始化""" + self.ap = ap + self.pipeline_config = pipeline_config + valid_app_types = ["agent", "workflow"] - self.app_type = self.ap.provider_cfg.data["dashscope-app-api"]["app-type"] + self.app_type = self.pipeline_config["ai"]["dashscope-app-api"]["app-type"] #检查配置文件中使用的应用类型是否支持 if (self.app_type not in valid_app_types): raise DashscopeAPIError( @@ -42,10 +44,9 @@ class DashScopeAPIRunner(runner.RequestRunner): ) #初始化Dashscope 参数配置 - self.app_id = self.ap.provider_cfg.data["dashscope-app-api"][self.app_type]["app-id"] - self.api_key = self.ap.provider_cfg.data["dashscope-app-api"]["api-key"] - self.references_quote = self.ap.provider_cfg.data["dashscope-app-api"][self.app_type]["references_quote"] - self.biz_params = self.ap.provider_cfg.data["dashscope-app-api"]["workflow"]["biz_params"] + self.app_id = self.pipeline_config["ai"]["dashscope-app-api"]["app-id"] + self.api_key = self.pipeline_config["ai"]["dashscope-app-api"]["api-key"] + self.references_quote = self.pipeline_config["ai"]["dashscope-app-api"]["references_quote"] def _replace_references(self, text, references_dict): """阿里云百炼平台的自定义应用支持资料引用,此函数可以将引用标签替换为参考资料""" @@ -169,7 +170,6 @@ class DashScopeAPIRunner(runner.RequestRunner): plain_text, image_ids = await self._preprocess_user_message(query) biz_params = {} - biz_params.update(self.biz_params) biz_params.update(query.variables) #发送对话请求 @@ -220,21 +220,19 @@ class DashScopeAPIRunner(runner.RequestRunner): content=pending_content, ) - - async def run( self, query: core_entities.Query ) -> typing.AsyncGenerator[llm_entities.Message, None]: """运行""" - if self.ap.provider_cfg.data["dashscope-app-api"]["app-type"] == "agent": + if self.app_type == "agent": async for msg in self._agent_messages(query): yield msg - elif self.ap.provider_cfg.data["dashscope-app-api"]["app-type"] == "workflow": + elif self.app_type == "workflow": async for msg in self._workflow_messages(query): yield msg else: raise DashscopeAPIError( - f"不支持的 Dashscope 应用类型: {self.ap.provider_cfg.data['dashscope-app-api']['app-type']}" + f"不支持的 Dashscope 应用类型: {self.app_type}" ) diff --git a/pkg/provider/runners/difysvapi.py b/pkg/provider/runners/difysvapi.py index 81ceddee..f48cbd57 100644 --- a/pkg/provider/runners/difysvapi.py +++ b/pkg/provider/runners/difysvapi.py @@ -10,7 +10,7 @@ import datetime import aiohttp from .. import runner -from ...core import entities as core_entities +from ...core import app, entities as core_entities from .. import entities as llm_entities from ...utils import image @@ -23,24 +23,24 @@ class DifyServiceAPIRunner(runner.RequestRunner): dify_client: client.AsyncDifyServiceClient - async def initialize(self): - """初始化""" + def __init__(self, ap: app.Application, pipeline_config: dict): + self.ap = ap + self.pipeline_config = pipeline_config + valid_app_types = ["chat", "agent", "workflow"] if ( - self.ap.provider_cfg.data["dify-service-api"]["app-type"] + self.pipeline_config["ai"]["dify-service-api"]["app-type"] not in valid_app_types ): raise errors.DifyAPIError( - f"不支持的 Dify 应用类型: {self.ap.provider_cfg.data['dify-service-api']['app-type']}" + f"不支持的 Dify 应用类型: {self.pipeline_config['ai']['dify-service-api']['app-type']}" ) - api_key = self.ap.provider_cfg.data["dify-service-api"][ - self.ap.provider_cfg.data["dify-service-api"]["app-type"] - ]["api-key"] + api_key = self.pipeline_config["ai"]["dify-service-api"]["api-key"] self.dify_client = client.AsyncDifyServiceClient( api_key=api_key, - base_url=self.ap.provider_cfg.data["dify-service-api"]["base-url"], + base_url=self.pipeline_config["ai"]["dify-service-api"]["base-url"], ) def _try_convert_thinking(self, resp_text: str) -> str: @@ -48,13 +48,13 @@ class DifyServiceAPIRunner(runner.RequestRunner): if not resp_text.startswith("
Thinking... "): return resp_text - if self.ap.provider_cfg.data["dify-service-api"]["options"]["convert-thinking-tips"] == "original": + if self.pipeline_config["ai"]["dify-service-api"]["thinking-convert"] == "original": return resp_text - if self.ap.provider_cfg.data["dify-service-api"]["options"]["convert-thinking-tips"] == "remove": + if self.pipeline_config["ai"]["dify-service-api"]["thinking-convert"] == "remove": return re.sub(r'
Thinking... .*?
', '', resp_text, flags=re.DOTALL) - if self.ap.provider_cfg.data["dify-service-api"]["options"]["convert-thinking-tips"] == "plain": + if self.pipeline_config["ai"]["dify-service-api"]["thinking-convert"] == "plain": pattern = r'
Thinking... (.*?)
' thinking_text = re.search(pattern, resp_text, flags=re.DOTALL) content_text = re.sub(pattern, '', resp_text, flags=re.DOTALL) @@ -121,7 +121,7 @@ class DifyServiceAPIRunner(runner.RequestRunner): user=f"{query.session.launcher_type.value}_{query.session.launcher_id}", conversation_id=cov_id, files=files, - timeout=self.ap.provider_cfg.data["dify-service-api"]["chat"]["timeout"], + timeout=self.pipeline_config["ai"]["dify-service-api"]["timeout"], ): self.ap.logger.debug("dify-chat-chunk: " + str(chunk)) @@ -177,7 +177,7 @@ class DifyServiceAPIRunner(runner.RequestRunner): response_mode="streaming", conversation_id=cov_id, files=files, - timeout=self.ap.provider_cfg.data["dify-service-api"]["chat"]["timeout"], + timeout=self.pipeline_config["ai"]["dify-service-api"]["timeout"], ): self.ap.logger.debug("dify-agent-chunk: " + str(chunk)) @@ -264,7 +264,7 @@ class DifyServiceAPIRunner(runner.RequestRunner): inputs=inputs, user=f"{query.session.launcher_type.value}_{query.session.launcher_id}", files=files, - timeout=self.ap.provider_cfg.data["dify-service-api"]["workflow"]["timeout"], + timeout=self.pipeline_config["ai"]["dify-service-api"]["timeout"], ): self.ap.logger.debug("dify-workflow-chunk: " + str(chunk)) if chunk["event"] in ignored_events: @@ -301,11 +301,7 @@ class DifyServiceAPIRunner(runner.RequestRunner): msg = llm_entities.Message( role="assistant", - content=chunk["data"]["outputs"][ - self.ap.provider_cfg.data["dify-service-api"]["workflow"][ - "output-key" - ] - ], + content=chunk["data"]["outputs"]["summary"], ) yield msg @@ -314,16 +310,16 @@ class DifyServiceAPIRunner(runner.RequestRunner): self, query: core_entities.Query ) -> typing.AsyncGenerator[llm_entities.Message, None]: """运行请求""" - if self.ap.provider_cfg.data["dify-service-api"]["app-type"] == "chat": + if self.pipeline_config["ai"]["dify-service-api"]["app-type"] == "chat": async for msg in self._chat_messages(query): yield msg - elif self.ap.provider_cfg.data["dify-service-api"]["app-type"] == "agent": + elif self.pipeline_config["ai"]["dify-service-api"]["app-type"] == "agent": async for msg in self._agent_chat_messages(query): yield msg - elif self.ap.provider_cfg.data["dify-service-api"]["app-type"] == "workflow": + elif self.pipeline_config["ai"]["dify-service-api"]["app-type"] == "workflow": async for msg in self._workflow_messages(query): yield msg else: raise errors.DifyAPIError( - f"不支持的 Dify 应用类型: {self.ap.provider_cfg.data['dify-service-api']['app-type']}" + f"不支持的 Dify 应用类型: {self.pipeline_config['ai']['dify-service-api']['app-type']}" ) diff --git a/pkg/provider/runners/localagent.py b/pkg/provider/runners/localagent.py index f05c82e3..68bb2b4f 100644 --- a/pkg/provider/runners/localagent.py +++ b/pkg/provider/runners/localagent.py @@ -16,14 +16,12 @@ class LocalAgentRunner(runner.RequestRunner): async def run(self, query: core_entities.Query) -> typing.AsyncGenerator[llm_entities.Message, None]: """运行请求 """ - await query.use_model.requester.preprocess(query) - pending_tool_calls = [] req_messages = query.prompt.messages.copy() + query.messages.copy() + [query.user_message] # 首次请求 - msg = await query.use_model.requester.call(query, query.use_model, req_messages, query.use_funcs) + msg = await query.use_llm_model.requester.invoke_llm(query, query.use_llm_model, req_messages, query.use_funcs) yield msg @@ -61,7 +59,7 @@ class LocalAgentRunner(runner.RequestRunner): req_messages.append(err_msg) # 处理完所有调用,再次请求 - msg = await query.use_model.requester.call(query, query.use_model, req_messages, query.use_funcs) + msg = await query.use_llm_model.requester.invoke_llm(query, query.use_llm_model, req_messages, query.use_funcs) yield msg diff --git a/pkg/provider/session/sessionmgr.py b/pkg/provider/session/sessionmgr.py index 00523472..5143f2bb 100644 --- a/pkg/provider/session/sessionmgr.py +++ b/pkg/provider/session/sessionmgr.py @@ -4,6 +4,7 @@ import asyncio from ...core import app, entities as core_entities from ...plugin import context as plugin_context +from ...provider import entities as provider_entities class SessionManager: @@ -41,17 +42,30 @@ class SessionManager: self.session_list.append(session) return session - async def get_conversation(self, session: core_entities.Session) -> core_entities.Conversation: + async def get_conversation(self, query: core_entities.Query, session: core_entities.Session, prompt_config: list[dict]) -> core_entities.Conversation: """获取对话或创建对话""" if not session.conversations: session.conversations = [] + # set prompt + prompt_messages = [] + + for prompt_message in prompt_config: + prompt_messages.append(provider_entities.Message(**prompt_message)) + + prompt = provider_entities.Prompt( + name="default", + messages=prompt_messages, + ) + if session.using_conversation is None: conversation = core_entities.Conversation( - prompt=await self.ap.prompt_mgr.get_prompt(session.use_prompt_name), + prompt=prompt, messages=[], - use_model=await self.ap.model_mgr.get_model_by_name(self.ap.provider_cfg.data['model']), + use_llm_model=await self.ap.model_mgr.get_model_by_uuid( + query.pipeline_config['ai']['local-agent']['model'] + ), use_funcs=await self.ap.tool_mgr.get_all_functions( plugin_enabled=True, ), diff --git a/pkg/provider/sysprompt/__init__.py b/pkg/provider/sysprompt/__init__.py deleted file mode 100644 index e69de29b..00000000 diff --git a/pkg/provider/sysprompt/entities.py b/pkg/provider/sysprompt/entities.py deleted file mode 100644 index 5442e809..00000000 --- a/pkg/provider/sysprompt/entities.py +++ /dev/null @@ -1,16 +0,0 @@ -from __future__ import annotations - -import typing -import pydantic.v1 as pydantic - -from ...provider import entities - - -class Prompt(pydantic.BaseModel): - """供AI使用的Prompt""" - - name: str - """名称""" - - messages: list[entities.Message] - """消息列表""" diff --git a/pkg/provider/sysprompt/loader.py b/pkg/provider/sysprompt/loader.py deleted file mode 100644 index 855728e2..00000000 --- a/pkg/provider/sysprompt/loader.py +++ /dev/null @@ -1,46 +0,0 @@ -from __future__ import annotations -import abc -import typing - -from ...core import app -from . import entities - - -preregistered_loaders: list[typing.Type[PromptLoader]] = [] - -def loader_class(name: str): - - def decorator(cls: typing.Type[PromptLoader]) -> typing.Type[PromptLoader]: - cls.name = name - preregistered_loaders.append(cls) - return cls - - return decorator - - -class PromptLoader(metaclass=abc.ABCMeta): - """Prompt加载器抽象类 - """ - name: str - - ap: app.Application - - prompts: list[entities.Prompt] - - def __init__(self, ap: app.Application): - self.ap = ap - self.prompts = [] - - async def initialize(self): - pass - - @abc.abstractmethod - async def load(self): - """加载Prompt,存放到prompts列表中 - """ - raise NotImplementedError - - def get_prompts(self) -> list[entities.Prompt]: - """获取Prompt列表 - """ - return self.prompts diff --git a/pkg/provider/sysprompt/loaders/__init__.py b/pkg/provider/sysprompt/loaders/__init__.py deleted file mode 100644 index e69de29b..00000000 diff --git a/pkg/provider/sysprompt/loaders/scenario.py b/pkg/provider/sysprompt/loaders/scenario.py deleted file mode 100644 index f907a51c..00000000 --- a/pkg/provider/sysprompt/loaders/scenario.py +++ /dev/null @@ -1,39 +0,0 @@ -from __future__ import annotations - -import json -import os - -from .. import loader -from .. import entities -from ....provider import entities as llm_entities - - -@loader.loader_class("full-scenario") -class ScenarioPromptLoader(loader.PromptLoader): - """加载scenario目录下的json""" - - async def load(self): - """加载Prompt - """ - for file in os.listdir("data/scenario"): - with open("data/scenario/{}".format(file), "r", encoding="utf-8") as f: - file_str = f.read() - file_name = file.split(".")[0] - file_json = json.loads(file_str) - messages = [] - for msg in file_json["prompt"]: - role = 'system' - if "role" in msg: - role = msg['role'] - messages.append( - llm_entities.Message( - role=role, - content=msg['content'], - ) - ) - prompt = entities.Prompt( - name=file_name, - messages=messages - ) - self.prompts.append(prompt) - \ No newline at end of file diff --git a/pkg/provider/sysprompt/loaders/single.py b/pkg/provider/sysprompt/loaders/single.py deleted file mode 100644 index 3ac9c262..00000000 --- a/pkg/provider/sysprompt/loaders/single.py +++ /dev/null @@ -1,43 +0,0 @@ -from __future__ import annotations -import os - -from .. import loader -from .. import entities -from ....provider import entities as llm_entities - - -@loader.loader_class("normal") -class SingleSystemPromptLoader(loader.PromptLoader): - """配置文件中的单条system prompt的prompt加载器 - """ - - async def load(self): - """加载Prompt - """ - - for name, cnt in self.ap.provider_cfg.data['prompt'].items(): - prompt = entities.Prompt( - name=name, - messages=[ - llm_entities.Message( - role='system', - content=cnt - ) - ] - ) - self.prompts.append(prompt) - - for file in os.listdir("data/prompts"): - with open("data/prompts/{}".format(file), "r", encoding="utf-8") as f: - file_str = f.read() - file_name = file.split(".")[0] - prompt = entities.Prompt( - name=file_name, - messages=[ - llm_entities.Message( - role='system', - content=file_str - ) - ] - ) - self.prompts.append(prompt) diff --git a/pkg/provider/sysprompt/sysprompt.py b/pkg/provider/sysprompt/sysprompt.py deleted file mode 100644 index c7695f5a..00000000 --- a/pkg/provider/sysprompt/sysprompt.py +++ /dev/null @@ -1,56 +0,0 @@ -from __future__ import annotations - -from ...core import app -from . import loader -from .loaders import single, scenario - - -class PromptManager: - """Prompt管理器 - """ - - ap: app.Application - - loader_inst: loader.PromptLoader - - default_prompt: str = 'default' - - def __init__(self, ap: app.Application): - self.ap = ap - - async def initialize(self): - - mode_name = self.ap.provider_cfg.data['prompt-mode'] - - loader_class = None - - for loader_cls in loader.preregistered_loaders: - if loader_cls.name == mode_name: - loader_class = loader_cls - break - else: - raise ValueError(f'未知的 Prompt 加载器: {mode_name}') - - self.loader_inst: loader.PromptLoader = loader_class(self.ap) - - await self.loader_inst.initialize() - await self.loader_inst.load() - - def get_all_prompts(self) -> list[loader.entities.Prompt]: - """获取所有Prompt - """ - return self.loader_inst.get_prompts() - - async def get_prompt(self, name: str) -> loader.entities.Prompt: - """获取Prompt - """ - for prompt in self.get_all_prompts(): - if prompt.name == name: - return prompt - - async def get_prompt_by_prefix(self, prefix: str) -> loader.entities.Prompt: - """通过前缀获取Prompt - """ - for prompt in self.get_all_prompts(): - if prompt.name.startswith(prefix): - return prompt diff --git a/templates/metadata/pipeline/ai.yaml b/templates/metadata/pipeline/ai.yaml index 38d579c0..8b7959f6 100644 --- a/templates/metadata/pipeline/ai.yaml +++ b/templates/metadata/pipeline/ai.yaml @@ -42,7 +42,7 @@ stages: zh_CN: 模型 type: select required: true - scope: llm-model + scope: /provider/models/llm - name: max-round label: en_US: Max Round @@ -56,9 +56,14 @@ stages: zh_CN: 提示词 type: array required: true - default: [] items: - type: string + type: object + properties: + role: + type: string + default: user + content: + type: string - name: dify-service-api label: en_US: Dify Service API diff --git a/templates/metadata/pipeline/safety.yaml b/templates/metadata/pipeline/safety.yaml index d19913af..ba59e067 100644 --- a/templates/metadata/pipeline/safety.yaml +++ b/templates/metadata/pipeline/safety.yaml @@ -46,7 +46,7 @@ stages: zh_CN: 窗口长度(秒) type: integer required: true - default: 10 + default: 60 - name: limitation label: en_US: Limitation @@ -54,3 +54,19 @@ stages: type: integer required: true default: 60 + - name: strategy + label: + en_US: Strategy + zh_CN: 策略 + type: select + required: true + default: drop + options: + - name: drop + label: + en_US: Drop + zh_CN: 丢弃 + - name: wait + label: + en_US: Wait + zh_CN: 等待 \ No newline at end of file